Fixed #31685 -- Added support for updating conflicts to QuerySet.bulk_create().

Thanks Florian Apolloner, Chris Jerdonek, Hannes Ljungberg, Nick Pope,
and Mariusz Felisiak for reviews.
sean_c_hsu 2020-06-15 00:58:06 +08:00 committed by Mariusz Felisiak
parent ba9de2e74e
commit 0f6946495a
16 changed files with 542 additions and 43 deletions

View File

@ -271,6 +271,10 @@ class BaseDatabaseFeatures:
# Does the backend support ignoring constraint or uniqueness errors during
# INSERT?
supports_ignore_conflicts = True
# Does the backend support updating rows on constraint or uniqueness errors
# during INSERT?
supports_update_conflicts = False
supports_update_conflicts_with_target = False
# Does this backend require casting the results of CASE expressions used
# in UPDATE statements to ensure the expression has the correct type?

View File

@ -717,8 +717,8 @@ class BaseDatabaseOperations:
raise ValueError('Unknown options: %s' % ', '.join(sorted(options.keys())))
return self.explain_prefix
def insert_statement(self, ignore_conflicts=False):
def insert_statement(self, on_conflict=None):
return 'INSERT INTO'
def ignore_conflicts_suffix_sql(self, ignore_conflicts=None):
def on_conflict_suffix_sql(self, fields, on_conflict, update_fields, unique_fields):
return ''

View File

@ -24,6 +24,7 @@ class DatabaseFeatures(BaseDatabaseFeatures):
supports_select_difference = False
supports_slicing_ordering_in_compound = True
supports_index_on_text_field = False
supports_update_conflicts = True
create_test_procedure_without_params_sql = """
CREATE PROCEDURE test_procedure ()
BEGIN

View File

@ -4,6 +4,7 @@ from django.conf import settings
from django.db.backends.base.operations import BaseDatabaseOperations
from django.db.backends.utils import split_tzname_delta
from django.db.models import Exists, ExpressionWrapper, Lookup
from django.db.models.constants import OnConflict
from django.utils import timezone
from django.utils.encoding import force_str
@ -365,8 +366,10 @@ class DatabaseOperations(BaseDatabaseOperations):
match_option = 'c' if lookup_type == 'regex' else 'i'
return "REGEXP_LIKE(%%s, %%s, '%s')" % match_option
def insert_statement(self, ignore_conflicts=False):
return 'INSERT IGNORE INTO' if ignore_conflicts else super().insert_statement(ignore_conflicts)
def insert_statement(self, on_conflict=None):
if on_conflict == OnConflict.IGNORE:
return 'INSERT IGNORE INTO'
return super().insert_statement(on_conflict=on_conflict)
def lookup_cast(self, lookup_type, internal_type=None):
lookup = '%s'
@ -388,3 +391,27 @@ class DatabaseOperations(BaseDatabaseOperations):
if getattr(expression, 'conditional', False):
return False
return super().conditional_expression_supported_in_where_clause(expression)
def on_conflict_suffix_sql(self, fields, on_conflict, update_fields, unique_fields):
if on_conflict == OnConflict.UPDATE:
conflict_suffix_sql = 'ON DUPLICATE KEY UPDATE %(fields)s'
field_sql = '%(field)s = VALUES(%(field)s)'
# The use of VALUES() is deprecated in MySQL 8.0.20+. Instead, use
# aliases for the new row and its columns available in MySQL
# 8.0.19+.
if not self.connection.mysql_is_mariadb:
if self.connection.mysql_version >= (8, 0, 19):
conflict_suffix_sql = f'AS new {conflict_suffix_sql}'
field_sql = '%(field)s = new.%(field)s'
# VALUES() was renamed to VALUE() in MariaDB 10.3.3+.
elif self.connection.mysql_version >= (10, 3, 3):
field_sql = '%(field)s = VALUE(%(field)s)'
fields = ', '.join([
field_sql % {'field': field}
for field in map(self.quote_name, update_fields)
])
return conflict_suffix_sql % {'fields': fields}
return super().on_conflict_suffix_sql(
fields, on_conflict, update_fields, unique_fields,
)
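A minimal sketch (not part of the diff) of the suffix this method produces; it assumes a configured MySQL or MariaDB connection and uses illustrative field names:

from django.db import connection
from django.db.models.constants import OnConflict

suffix = connection.ops.on_conflict_suffix_sql(
    fields=[],  # unused by the MySQL/MariaDB implementation
    on_conflict=OnConflict.UPDATE,
    update_fields=['rank', 'name'],
    unique_fields=[],  # MySQL/MariaDB cannot target specific unique fields
)
# MySQL >= 8.0.19:   AS new ON DUPLICATE KEY UPDATE `rank` = new.`rank`, `name` = new.`name`
# Older MySQL:       ON DUPLICATE KEY UPDATE `rank` = VALUES(`rank`), `name` = VALUES(`name`)
# MariaDB >= 10.3.3: ON DUPLICATE KEY UPDATE `rank` = VALUE(`rank`), `name` = VALUE(`name`)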

View File

@ -57,6 +57,8 @@ class DatabaseFeatures(BaseDatabaseFeatures):
supports_deferrable_unique_constraints = True
has_json_operators = True
json_key_contains_list_matching_requires_list = True
supports_update_conflicts = True
supports_update_conflicts_with_target = True
test_collations = {
'non_default': 'sv-x-icu',
'swedish_ci': 'sv-x-icu',

View File

@ -3,6 +3,7 @@ from psycopg2.extras import Inet
from django.conf import settings
from django.db.backends.base.operations import BaseDatabaseOperations
from django.db.backends.utils import split_tzname_delta
from django.db.models.constants import OnConflict
class DatabaseOperations(BaseDatabaseOperations):
@ -272,5 +273,17 @@ class DatabaseOperations(BaseDatabaseOperations):
prefix += ' (%s)' % ', '.join('%s %s' % i for i in extra.items())
return prefix
def ignore_conflicts_suffix_sql(self, ignore_conflicts=None):
return 'ON CONFLICT DO NOTHING' if ignore_conflicts else super().ignore_conflicts_suffix_sql(ignore_conflicts)
def on_conflict_suffix_sql(self, fields, on_conflict, update_fields, unique_fields):
if on_conflict == OnConflict.IGNORE:
return 'ON CONFLICT DO NOTHING'
if on_conflict == OnConflict.UPDATE:
return 'ON CONFLICT(%s) DO UPDATE SET %s' % (
', '.join(map(self.quote_name, unique_fields)),
', '.join([
f'{field} = EXCLUDED.{field}'
for field in map(self.quote_name, update_fields)
]),
)
return super().on_conflict_suffix_sql(
fields, on_conflict, update_fields, unique_fields,
)
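A minimal sketch (not part of the diff) of the clause this method produces on PostgreSQL; SQLite 3.24+ (below) generates the same syntax. Field names are illustrative:

from django.db import connection
from django.db.models.constants import OnConflict

suffix = connection.ops.on_conflict_suffix_sql(
    fields=[],  # unused when building the ON CONFLICT clause
    on_conflict=OnConflict.UPDATE,
    update_fields=['rank', 'name'],
    unique_fields=['number'],  # the conflict target is required here
)
# ON CONFLICT("number") DO UPDATE SET "rank" = EXCLUDED."rank", "name" = EXCLUDED."name"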

View File

@ -40,6 +40,8 @@ class DatabaseFeatures(BaseDatabaseFeatures):
supports_order_by_nulls_modifier = Database.sqlite_version_info >= (3, 30, 0)
order_by_nulls_first = True
supports_json_field_contains = False
supports_update_conflicts = Database.sqlite_version_info >= (3, 24, 0)
supports_update_conflicts_with_target = supports_update_conflicts
test_collations = {
'ci': 'nocase',
'cs': 'binary',

View File

@ -8,6 +8,7 @@ from django.conf import settings
from django.core.exceptions import FieldError
from django.db import DatabaseError, NotSupportedError, models
from django.db.backends.base.operations import BaseDatabaseOperations
from django.db.models.constants import OnConflict
from django.db.models.expressions import Col
from django.utils import timezone
from django.utils.dateparse import parse_date, parse_datetime, parse_time
@ -370,8 +371,10 @@ class DatabaseOperations(BaseDatabaseOperations):
return 'django_time_diff(%s, %s)' % (lhs_sql, rhs_sql), params
return 'django_timestamp_diff(%s, %s)' % (lhs_sql, rhs_sql), params
def insert_statement(self, ignore_conflicts=False):
return 'INSERT OR IGNORE INTO' if ignore_conflicts else super().insert_statement(ignore_conflicts)
def insert_statement(self, on_conflict=None):
if on_conflict == OnConflict.IGNORE:
return 'INSERT OR IGNORE INTO'
return super().insert_statement(on_conflict=on_conflict)
def return_insert_columns(self, fields):
# SQLite < 3.35 doesn't support an INSERT...RETURNING statement.
@ -384,3 +387,19 @@ class DatabaseOperations(BaseDatabaseOperations):
) for field in fields
]
return 'RETURNING %s' % ', '.join(columns), ()
def on_conflict_suffix_sql(self, fields, on_conflict, update_fields, unique_fields):
if (
on_conflict == OnConflict.UPDATE and
self.connection.features.supports_update_conflicts_with_target
):
return 'ON CONFLICT(%s) DO UPDATE SET %s' % (
', '.join(map(self.quote_name, unique_fields)),
', '.join([
f'{field} = EXCLUDED.{field}'
for field in map(self.quote_name, update_fields)
]),
)
return super().on_conflict_suffix_sql(
fields, on_conflict, update_fields, unique_fields,
)

View File

@ -1,6 +1,12 @@
"""
Constants used across the ORM in general.
"""
from enum import Enum
# Separator used to split filter strings apart.
LOOKUP_SEP = '__'
class OnConflict(Enum):
IGNORE = 'ignore'
UPDATE = 'update'

View File

@ -15,7 +15,7 @@ from django.db import (
router, transaction,
)
from django.db.models import AutoField, DateField, DateTimeField, sql
from django.db.models.constants import LOOKUP_SEP
from django.db.models.constants import LOOKUP_SEP, OnConflict
from django.db.models.deletion import Collector
from django.db.models.expressions import Case, Expression, F, Ref, Value, When
from django.db.models.functions import Cast, Trunc
@ -466,7 +466,69 @@ class QuerySet:
obj.pk = obj._meta.pk.get_pk_value_on_save(obj)
obj._prepare_related_fields_for_save(operation_name='bulk_create')
def bulk_create(self, objs, batch_size=None, ignore_conflicts=False):
def _check_bulk_create_options(self, ignore_conflicts, update_conflicts, update_fields, unique_fields):
if ignore_conflicts and update_conflicts:
raise ValueError(
'ignore_conflicts and update_conflicts are mutually exclusive.'
)
db_features = connections[self.db].features
if ignore_conflicts:
if not db_features.supports_ignore_conflicts:
raise NotSupportedError(
'This database backend does not support ignoring conflicts.'
)
return OnConflict.IGNORE
elif update_conflicts:
if not db_features.supports_update_conflicts:
raise NotSupportedError(
'This database backend does not support updating conflicts.'
)
if not update_fields:
raise ValueError(
'Fields that will be updated when a row insertion fails '
'on conflicts must be provided.'
)
if unique_fields and not db_features.supports_update_conflicts_with_target:
raise NotSupportedError(
'This database backend does not support updating '
'conflicts with specifying unique fields that can trigger '
'the upsert.'
)
if not unique_fields and db_features.supports_update_conflicts_with_target:
raise ValueError(
'Unique fields that can trigger the upsert must be '
'provided.'
)
# Updating primary keys and non-concrete fields is forbidden.
update_fields = [self.model._meta.get_field(name) for name in update_fields]
if any(not f.concrete or f.many_to_many for f in update_fields):
raise ValueError(
'bulk_create() can only be used with concrete fields in '
'update_fields.'
)
if any(f.primary_key for f in update_fields):
raise ValueError(
'bulk_create() cannot be used with primary keys in '
'update_fields.'
)
if unique_fields:
# Primary key is allowed in unique_fields.
unique_fields = [
self.model._meta.get_field(name)
for name in unique_fields if name != 'pk'
]
if any(not f.concrete or f.many_to_many for f in unique_fields):
raise ValueError(
'bulk_create() can only be used with concrete fields '
'in unique_fields.'
)
return OnConflict.UPDATE
return None
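A brief sketch (not part of the diff) of how these checks surface to callers; each call raises, with the messages taken from the code above. TwoFields is a test model added later in this commit:

objs = [TwoFields(f1=1, f2=1)]

TwoFields.objects.bulk_create(objs, ignore_conflicts=True, update_conflicts=True)
# ValueError: ignore_conflicts and update_conflicts are mutually exclusive.

TwoFields.objects.bulk_create(objs, update_conflicts=True)
# ValueError: Fields that will be updated when a row insertion fails on
# conflicts must be provided.

TwoFields.objects.bulk_create(objs, update_conflicts=True, update_fields=['f2'])
# On PostgreSQL and SQLite (supports_update_conflicts_with_target):
# ValueError: Unique fields that can trigger the upsert must be provided.
# On MySQL and MariaDB this call succeeds without unique_fields.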
def bulk_create(
self, objs, batch_size=None, ignore_conflicts=False,
update_conflicts=False, update_fields=None, unique_fields=None,
):
"""
Insert each of the instances into the database. Do *not* call
save() on each of the instances, do not send any pre/post_save
@ -497,6 +559,12 @@ class QuerySet:
raise ValueError("Can't bulk create a multi-table inherited model")
if not objs:
return objs
on_conflict = self._check_bulk_create_options(
ignore_conflicts,
update_conflicts,
update_fields,
unique_fields,
)
self._for_write = True
opts = self.model._meta
fields = opts.concrete_fields
@ -506,7 +574,12 @@ class QuerySet:
objs_with_pk, objs_without_pk = partition(lambda o: o.pk is None, objs)
if objs_with_pk:
returned_columns = self._batched_insert(
objs_with_pk, fields, batch_size, ignore_conflicts=ignore_conflicts,
objs_with_pk,
fields,
batch_size,
on_conflict=on_conflict,
update_fields=update_fields,
unique_fields=unique_fields,
)
for obj_with_pk, results in zip(objs_with_pk, returned_columns):
for result, field in zip(results, opts.db_returning_fields):
@ -518,10 +591,15 @@ class QuerySet:
if objs_without_pk:
fields = [f for f in fields if not isinstance(f, AutoField)]
returned_columns = self._batched_insert(
objs_without_pk, fields, batch_size, ignore_conflicts=ignore_conflicts,
objs_without_pk,
fields,
batch_size,
on_conflict=on_conflict,
update_fields=update_fields,
unique_fields=unique_fields,
)
connection = connections[self.db]
if connection.features.can_return_rows_from_bulk_insert and not ignore_conflicts:
if connection.features.can_return_rows_from_bulk_insert and on_conflict is None:
assert len(returned_columns) == len(objs_without_pk)
for obj_without_pk, results in zip(objs_without_pk, returned_columns):
for result, field in zip(results, opts.db_returning_fields):
@ -1293,7 +1371,10 @@ class QuerySet:
# PRIVATE METHODS #
###################
def _insert(self, objs, fields, returning_fields=None, raw=False, using=None, ignore_conflicts=False):
def _insert(
self, objs, fields, returning_fields=None, raw=False, using=None,
on_conflict=None, update_fields=None, unique_fields=None,
):
"""
Insert a new record for the given model. This provides an interface to
the InsertQuery class and is how Model.save() is implemented.
@ -1301,33 +1382,45 @@ class QuerySet:
self._for_write = True
if using is None:
using = self.db
query = sql.InsertQuery(self.model, ignore_conflicts=ignore_conflicts)
query = sql.InsertQuery(
self.model,
on_conflict=on_conflict,
update_fields=update_fields,
unique_fields=unique_fields,
)
query.insert_values(fields, objs, raw=raw)
return query.get_compiler(using=using).execute_sql(returning_fields)
_insert.alters_data = True
_insert.queryset_only = False
def _batched_insert(self, objs, fields, batch_size, ignore_conflicts=False):
def _batched_insert(
self, objs, fields, batch_size, on_conflict=None, update_fields=None,
unique_fields=None,
):
"""
Helper method for bulk_create() to insert objs one batch at a time.
"""
connection = connections[self.db]
if ignore_conflicts and not connection.features.supports_ignore_conflicts:
raise NotSupportedError('This database backend does not support ignoring conflicts.')
ops = connection.ops
max_batch_size = max(ops.bulk_batch_size(fields, objs), 1)
batch_size = min(batch_size, max_batch_size) if batch_size else max_batch_size
inserted_rows = []
bulk_return = connection.features.can_return_rows_from_bulk_insert
for item in [objs[i:i + batch_size] for i in range(0, len(objs), batch_size)]:
if bulk_return and not ignore_conflicts:
if bulk_return and on_conflict is None:
inserted_rows.extend(self._insert(
item, fields=fields, using=self.db,
returning_fields=self.model._meta.db_returning_fields,
ignore_conflicts=ignore_conflicts,
))
else:
self._insert(item, fields=fields, using=self.db, ignore_conflicts=ignore_conflicts)
self._insert(
item,
fields=fields,
using=self.db,
on_conflict=on_conflict,
update_fields=update_fields,
unique_fields=unique_fields,
)
return inserted_rows
def _chain(self):

View File

@ -1387,7 +1387,9 @@ class SQLInsertCompiler(SQLCompiler):
# going to be column names (so we can avoid the extra overhead).
qn = self.connection.ops.quote_name
opts = self.query.get_meta()
insert_statement = self.connection.ops.insert_statement(ignore_conflicts=self.query.ignore_conflicts)
insert_statement = self.connection.ops.insert_statement(
on_conflict=self.query.on_conflict,
)
result = ['%s %s' % (insert_statement, qn(opts.db_table))]
fields = self.query.fields or [opts.pk]
result.append('(%s)' % ', '.join(qn(f.column) for f in fields))
@ -1410,8 +1412,11 @@ class SQLInsertCompiler(SQLCompiler):
placeholder_rows, param_rows = self.assemble_as_sql(fields, value_rows)
ignore_conflicts_suffix_sql = self.connection.ops.ignore_conflicts_suffix_sql(
ignore_conflicts=self.query.ignore_conflicts
on_conflict_suffix_sql = self.connection.ops.on_conflict_suffix_sql(
fields,
self.query.on_conflict,
self.query.update_fields,
self.query.unique_fields,
)
if self.returning_fields and self.connection.features.can_return_columns_from_insert:
if self.connection.features.can_return_rows_from_bulk_insert:
@ -1420,8 +1425,8 @@ class SQLInsertCompiler(SQLCompiler):
else:
result.append("VALUES (%s)" % ", ".join(placeholder_rows[0]))
params = [param_rows[0]]
if ignore_conflicts_suffix_sql:
result.append(ignore_conflicts_suffix_sql)
if on_conflict_suffix_sql:
result.append(on_conflict_suffix_sql)
# Skip empty r_sql to allow subclasses to customize behavior for
# 3rd party backends. Refs #19096.
r_sql, self.returning_params = self.connection.ops.return_insert_columns(self.returning_fields)
@ -1432,12 +1437,12 @@ class SQLInsertCompiler(SQLCompiler):
if can_bulk:
result.append(self.connection.ops.bulk_insert_sql(fields, placeholder_rows))
if ignore_conflicts_suffix_sql:
result.append(ignore_conflicts_suffix_sql)
if on_conflict_suffix_sql:
result.append(on_conflict_suffix_sql)
return [(" ".join(result), tuple(p for ps in param_rows for p in ps))]
else:
if ignore_conflicts_suffix_sql:
result.append(ignore_conflicts_suffix_sql)
if on_conflict_suffix_sql:
result.append(on_conflict_suffix_sql)
return [
(" ".join(result + ["VALUES (%s)" % ", ".join(p)]), vals)
for p, vals in zip(placeholder_rows, param_rows)

View File

@ -138,11 +138,13 @@ class UpdateQuery(Query):
class InsertQuery(Query):
compiler = 'SQLInsertCompiler'
def __init__(self, *args, ignore_conflicts=False, **kwargs):
def __init__(self, *args, on_conflict=None, update_fields=None, unique_fields=None, **kwargs):
super().__init__(*args, **kwargs)
self.fields = []
self.objs = []
self.ignore_conflicts = ignore_conflicts
self.on_conflict = on_conflict
self.update_fields = update_fields or []
self.unique_fields = unique_fields or []
def insert_values(self, fields, objs, raw=False):
self.fields = fields

View File

@ -2155,7 +2155,7 @@ exists in the database, an :exc:`~django.db.IntegrityError` is raised.
``bulk_create()``
~~~~~~~~~~~~~~~~~
.. method:: bulk_create(objs, batch_size=None, ignore_conflicts=False)
.. method:: bulk_create(objs, batch_size=None, ignore_conflicts=False, update_conflicts=False, update_fields=None, unique_fields=None)
This method inserts the provided list of objects into the database in an
efficient manner (generally only 1 query, no matter how many objects there
@ -2198,9 +2198,17 @@ where the default is such that at most 999 variables per query are used.
On databases that support it (all but Oracle), setting the ``ignore_conflicts``
parameter to ``True`` tells the database to ignore failure to insert any rows
that fail constraints such as duplicate unique values. Enabling this parameter
disables setting the primary key on each model instance (if the database
normally supports it).
that fail constraints such as duplicate unique values.
On databases that support it (all except Oracle and SQLite < 3.24), setting the
``update_conflicts`` parameter to ``True`` tells the database to update
``update_fields`` when a row insertion fails on conflicts. On PostgreSQL and
SQLite, in addition to ``update_fields``, a list of ``unique_fields`` that may
be in conflict must be provided.
Enabling the ``ignore_conflicts`` or ``update_conflicts`` parameter disables
setting the primary key on each model instance (if the database normally
supports it).
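For example, a minimal sketch using the ``UpsertConflict`` test model added in
this commit, where ``number`` is unique::

    # An UpsertConflict row with number=1, rank=1, name='John' already exists.
    UpsertConflict.objects.bulk_create(
        [UpsertConflict(number=1, rank=4, name='Steve')],
        update_conflicts=True,
        update_fields=['rank', 'name'],
        unique_fields=['number'],  # Required on PostgreSQL and SQLite only.
    )
    # The conflicting row is updated in place (rank=4, name='Steve'); no new
    # row is inserted.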
.. warning::
@ -2217,6 +2225,12 @@ normally supports it).
Support for fetching primary key attributes on SQLite 3.35+ was added.
.. versionchanged:: 4.1
The ``update_conflicts``, ``update_fields``, and ``unique_fields``
parameters were added to support updating fields when a row insertion fails
on conflict.
``bulk_update()``
~~~~~~~~~~~~~~~~~

View File

@ -232,6 +232,10 @@ Models
in order to reduce the number of failed requests, e.g. after database server
restart.
* :meth:`.QuerySet.bulk_create` now supports updating fields when a row
insertion violates uniqueness constraints. This is supported on MariaDB, MySQL,
PostgreSQL, and SQLite 3.24+.
Requests and Responses
~~~~~~~~~~~~~~~~~~~~~~
@ -298,6 +302,14 @@ backends.
* ``DatabaseIntrospection.get_key_columns()`` is removed. Use
``DatabaseIntrospection.get_relations()`` instead.
* The ``DatabaseOperations.ignore_conflicts_suffix_sql()`` method is replaced
by ``DatabaseOperations.on_conflict_suffix_sql()``, which accepts the
``fields``, ``on_conflict``, ``update_fields``, and ``unique_fields``
arguments.
* The ``ignore_conflicts`` argument of the
``DatabaseOperations.insert_statement()`` method is replaced by
``on_conflict``, which accepts a ``django.db.models.constants.OnConflict``
value.
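A minimal override sketch (not part of the diff) for third-party backend authors; ``MyDatabaseOperations`` and the ``INSERT OR IGNORE INTO`` syntax are illustrative only, not a prescribed implementation:

from django.db.backends.base.operations import BaseDatabaseOperations
from django.db.models.constants import OnConflict

class MyDatabaseOperations(BaseDatabaseOperations):
    def insert_statement(self, on_conflict=None):
        if on_conflict == OnConflict.IGNORE:
            # Return the backend's conflict-ignoring INSERT keyword(s).
            return 'INSERT OR IGNORE INTO'
        return super().insert_statement(on_conflict=on_conflict)

    def on_conflict_suffix_sql(self, fields, on_conflict, update_fields, unique_fields):
        if on_conflict == OnConflict.UPDATE:
            # Build the backend's upsert clause from update_fields and
            # unique_fields here.
            raise NotImplementedError('Backend-specific upsert clause goes here.')
        return super().on_conflict_suffix_sql(
            fields, on_conflict, update_fields, unique_fields,
        )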
Dropped support for MariaDB 10.2
--------------------------------

View File

@ -16,6 +16,14 @@ class Country(models.Model):
iso_two_letter = models.CharField(max_length=2)
description = models.TextField()
class Meta:
constraints = [
models.UniqueConstraint(
fields=['iso_two_letter', 'name'],
name='country_name_iso_unique',
),
]
class ProxyCountry(Country):
class Meta:
@ -58,6 +66,13 @@ class State(models.Model):
class TwoFields(models.Model):
f1 = models.IntegerField(unique=True)
f2 = models.IntegerField(unique=True)
name = models.CharField(max_length=15, null=True)
class UpsertConflict(models.Model):
number = models.IntegerField(unique=True)
rank = models.IntegerField()
name = models.CharField(max_length=15)
class NoFields(models.Model):
@ -103,3 +118,9 @@ class NullableFields(models.Model):
text_field = models.TextField(null=True, default='text')
url_field = models.URLField(null=True, default='/')
uuid_field = models.UUIDField(null=True, default=uuid.uuid4)
class RelatedModel(models.Model):
name = models.CharField(max_length=15, null=True)
country = models.OneToOneField(Country, models.CASCADE, primary_key=True)
big_auto_fields = models.ManyToManyField(BigAutoFieldModel)

View File

@ -1,7 +1,11 @@
from math import ceil
from operator import attrgetter
from django.db import IntegrityError, NotSupportedError, connection
from django.core.exceptions import FieldDoesNotExist
from django.db import (
IntegrityError, NotSupportedError, OperationalError, ProgrammingError,
connection,
)
from django.db.models import FileField, Value
from django.db.models.functions import Lower
from django.test import (
@ -11,7 +15,8 @@ from django.test import (
from .models import (
BigAutoFieldModel, Country, NoFields, NullableFields, Pizzeria,
ProxyCountry, ProxyMultiCountry, ProxyMultiProxyCountry, ProxyProxyCountry,
Restaurant, SmallAutoFieldModel, State, TwoFields,
RelatedModel, Restaurant, SmallAutoFieldModel, State, TwoFields,
UpsertConflict,
)
@ -53,10 +58,10 @@ class BulkCreateTests(TestCase):
@skipUnlessDBFeature('has_bulk_insert')
def test_long_and_short_text(self):
Country.objects.bulk_create([
Country(description='a' * 4001),
Country(description='a'),
Country(description='Ж' * 2001),
Country(description='Ж'),
Country(description='a' * 4001, iso_two_letter='A'),
Country(description='a', iso_two_letter='B'),
Country(description='Ж' * 2001, iso_two_letter='C'),
Country(description='Ж', iso_two_letter='D'),
])
self.assertEqual(Country.objects.count(), 4)
@ -218,7 +223,7 @@ class BulkCreateTests(TestCase):
@skipUnlessDBFeature('has_bulk_insert')
def test_explicit_batch_size_respects_max_batch_size(self):
objs = [Country() for i in range(1000)]
objs = [Country(name=f'Country {i}') for i in range(1000)]
fields = ['name', 'iso_two_letter', 'description']
max_batch_size = max(connection.ops.bulk_batch_size(fields, objs), 1)
with self.assertNumQueries(ceil(len(objs) / max_batch_size)):
@ -352,3 +357,276 @@ class BulkCreateTests(TestCase):
msg = 'Batch size must be a positive integer.'
with self.assertRaisesMessage(ValueError, msg):
Country.objects.bulk_create([], batch_size=-1)
@skipIfDBFeature('supports_update_conflicts')
def test_update_conflicts_unsupported(self):
msg = 'This database backend does not support updating conflicts.'
with self.assertRaisesMessage(NotSupportedError, msg):
Country.objects.bulk_create(self.data, update_conflicts=True)
@skipUnlessDBFeature('supports_ignore_conflicts', 'supports_update_conflicts')
def test_ignore_update_conflicts_exclusive(self):
msg = 'ignore_conflicts and update_conflicts are mutually exclusive'
with self.assertRaisesMessage(ValueError, msg):
Country.objects.bulk_create(
self.data,
ignore_conflicts=True,
update_conflicts=True,
)
@skipUnlessDBFeature('supports_update_conflicts')
def test_update_conflicts_no_update_fields(self):
msg = (
'Fields that will be updated when a row insertion fails on '
'conflicts must be provided.'
)
with self.assertRaisesMessage(ValueError, msg):
Country.objects.bulk_create(self.data, update_conflicts=True)
@skipUnlessDBFeature('supports_update_conflicts')
@skipIfDBFeature('supports_update_conflicts_with_target')
def test_update_conflicts_unique_field_unsupported(self):
msg = (
'This database backend does not support updating conflicts with '
'specifying unique fields that can trigger the upsert.'
)
with self.assertRaisesMessage(NotSupportedError, msg):
TwoFields.objects.bulk_create(
[TwoFields(f1=1, f2=1), TwoFields(f1=2, f2=2)],
update_conflicts=True,
update_fields=['f2'],
unique_fields=['f1'],
)
@skipUnlessDBFeature('supports_update_conflicts')
def test_update_conflicts_nonexistent_update_fields(self):
unique_fields = None
if connection.features.supports_update_conflicts_with_target:
unique_fields = ['f1']
msg = "TwoFields has no field named 'nonexistent'"
with self.assertRaisesMessage(FieldDoesNotExist, msg):
TwoFields.objects.bulk_create(
[TwoFields(f1=1, f2=1), TwoFields(f1=2, f2=2)],
update_conflicts=True,
update_fields=['nonexistent'],
unique_fields=unique_fields,
)
@skipUnlessDBFeature(
'supports_update_conflicts', 'supports_update_conflicts_with_target',
)
def test_update_conflicts_unique_fields_required(self):
msg = 'Unique fields that can trigger the upsert must be provided.'
with self.assertRaisesMessage(ValueError, msg):
TwoFields.objects.bulk_create(
[TwoFields(f1=1, f2=1), TwoFields(f1=2, f2=2)],
update_conflicts=True,
update_fields=['f1'],
)
@skipUnlessDBFeature(
'supports_update_conflicts', 'supports_update_conflicts_with_target',
)
def test_update_conflicts_invalid_update_fields(self):
msg = (
'bulk_create() can only be used with concrete fields in '
'update_fields.'
)
# Reverse one-to-one relationship.
with self.assertRaisesMessage(ValueError, msg):
Country.objects.bulk_create(
self.data,
update_conflicts=True,
update_fields=['relatedmodel'],
unique_fields=['pk'],
)
# Many-to-many relationship.
with self.assertRaisesMessage(ValueError, msg):
RelatedModel.objects.bulk_create(
[RelatedModel(country=self.data[0])],
update_conflicts=True,
update_fields=['big_auto_fields'],
unique_fields=['country'],
)
@skipUnlessDBFeature(
'supports_update_conflicts', 'supports_update_conflicts_with_target',
)
def test_update_conflicts_pk_in_update_fields(self):
msg = 'bulk_create() cannot be used with primary keys in update_fields.'
with self.assertRaisesMessage(ValueError, msg):
BigAutoFieldModel.objects.bulk_create(
[BigAutoFieldModel()],
update_conflicts=True,
update_fields=['id'],
unique_fields=['id'],
)
@skipUnlessDBFeature(
'supports_update_conflicts', 'supports_update_conflicts_with_target',
)
def test_update_conflicts_invalid_unique_fields(self):
msg = (
'bulk_create() can only be used with concrete fields in '
'unique_fields.'
)
# Reverse one-to-one relationship.
with self.assertRaisesMessage(ValueError, msg):
Country.objects.bulk_create(
self.data,
update_conflicts=True,
update_fields=['name'],
unique_fields=['relatedmodel'],
)
# Many-to-many relationship.
with self.assertRaisesMessage(ValueError, msg):
RelatedModel.objects.bulk_create(
[RelatedModel(country=self.data[0])],
update_conflicts=True,
update_fields=['name'],
unique_fields=['big_auto_fields'],
)
def _test_update_conflicts_two_fields(self, unique_fields):
TwoFields.objects.bulk_create([
TwoFields(f1=1, f2=1, name='a'),
TwoFields(f1=2, f2=2, name='b'),
])
self.assertEqual(TwoFields.objects.count(), 2)
conflicting_objects = [
TwoFields(f1=1, f2=1, name='c'),
TwoFields(f1=2, f2=2, name='d'),
]
TwoFields.objects.bulk_create(
conflicting_objects,
update_conflicts=True,
unique_fields=unique_fields,
update_fields=['name'],
)
self.assertEqual(TwoFields.objects.count(), 2)
self.assertCountEqual(TwoFields.objects.values('f1', 'f2', 'name'), [
{'f1': 1, 'f2': 1, 'name': 'c'},
{'f1': 2, 'f2': 2, 'name': 'd'},
])
@skipUnlessDBFeature('supports_update_conflicts', 'supports_update_conflicts_with_target')
def test_update_conflicts_two_fields_unique_fields_first(self):
self._test_update_conflicts_two_fields(['f1'])
@skipUnlessDBFeature('supports_update_conflicts', 'supports_update_conflicts_with_target')
def test_update_conflicts_two_fields_unique_fields_second(self):
self._test_update_conflicts_two_fields(['f2'])
@skipUnlessDBFeature('supports_update_conflicts', 'supports_update_conflicts_with_target')
def test_update_conflicts_two_fields_unique_fields_both(self):
with self.assertRaises((OperationalError, ProgrammingError)):
self._test_update_conflicts_two_fields(['f1', 'f2'])
@skipUnlessDBFeature('supports_update_conflicts')
@skipIfDBFeature('supports_update_conflicts_with_target')
def test_update_conflicts_two_fields_no_unique_fields(self):
self._test_update_conflicts_two_fields([])
def _test_update_conflicts_unique_two_fields(self, unique_fields):
Country.objects.bulk_create(self.data)
self.assertEqual(Country.objects.count(), 4)
new_data = [
# Conflicting countries.
Country(name='Germany', iso_two_letter='DE', description=(
'Germany is a country in Central Europe.'
)),
Country(name='Czech Republic', iso_two_letter='CZ', description=(
'The Czech Republic is a landlocked country in Central Europe.'
)),
# New countries.
Country(name='Australia', iso_two_letter='AU'),
Country(name='Japan', iso_two_letter='JP', description=(
'Japan is an island country in East Asia.'
)),
]
Country.objects.bulk_create(
new_data,
update_conflicts=True,
update_fields=['description'],
unique_fields=unique_fields,
)
self.assertEqual(Country.objects.count(), 6)
self.assertCountEqual(Country.objects.values('iso_two_letter', 'description'), [
{'iso_two_letter': 'US', 'description': ''},
{'iso_two_letter': 'NL', 'description': ''},
{'iso_two_letter': 'DE', 'description': (
'Germany is a country in Central Europe.'
)},
{'iso_two_letter': 'CZ', 'description': (
'The Czech Republic is a landlocked country in Central Europe.'
)},
{'iso_two_letter': 'AU', 'description': ''},
{'iso_two_letter': 'JP', 'description': (
'Japan is an island country in East Asia.'
)},
])
@skipUnlessDBFeature('supports_update_conflicts', 'supports_update_conflicts_with_target')
def test_update_conflicts_unique_two_fields_unique_fields_both(self):
self._test_update_conflicts_unique_two_fields(['iso_two_letter', 'name'])
@skipUnlessDBFeature('supports_update_conflicts', 'supports_update_conflicts_with_target')
def test_update_conflicts_unique_two_fields_unique_fields_one(self):
with self.assertRaises((OperationalError, ProgrammingError)):
self._test_update_conflicts_unique_two_fields(['iso_two_letter'])
@skipUnlessDBFeature('supports_update_conflicts')
@skipIfDBFeature('supports_update_conflicts_with_target')
def test_update_conflicts_unique_two_fields_unique_no_unique_fields(self):
self._test_update_conflicts_unique_two_fields([])
def _test_update_conflicts(self, unique_fields):
UpsertConflict.objects.bulk_create([
UpsertConflict(number=1, rank=1, name='John'),
UpsertConflict(number=2, rank=2, name='Mary'),
UpsertConflict(number=3, rank=3, name='Hannah'),
])
self.assertEqual(UpsertConflict.objects.count(), 3)
conflicting_objects = [
UpsertConflict(number=1, rank=4, name='Steve'),
UpsertConflict(number=2, rank=2, name='Olivia'),
UpsertConflict(number=3, rank=1, name='Hannah'),
]
UpsertConflict.objects.bulk_create(
conflicting_objects,
update_conflicts=True,
update_fields=['name', 'rank'],
unique_fields=unique_fields,
)
self.assertEqual(UpsertConflict.objects.count(), 3)
self.assertCountEqual(UpsertConflict.objects.values('number', 'rank', 'name'), [
{'number': 1, 'rank': 4, 'name': 'Steve'},
{'number': 2, 'rank': 2, 'name': 'Olivia'},
{'number': 3, 'rank': 1, 'name': 'Hannah'},
])
UpsertConflict.objects.bulk_create(
conflicting_objects + [UpsertConflict(number=4, rank=4, name='Mark')],
update_conflicts=True,
update_fields=['name', 'rank'],
unique_fields=unique_fields,
)
self.assertEqual(UpsertConflict.objects.count(), 4)
self.assertCountEqual(UpsertConflict.objects.values('number', 'rank', 'name'), [
{'number': 1, 'rank': 4, 'name': 'Steve'},
{'number': 2, 'rank': 2, 'name': 'Olivia'},
{'number': 3, 'rank': 1, 'name': 'Hannah'},
{'number': 4, 'rank': 4, 'name': 'Mark'},
])
@skipUnlessDBFeature('supports_update_conflicts', 'supports_update_conflicts_with_target')
def test_update_conflicts_unique_fields(self):
self._test_update_conflicts(unique_fields=['number'])
@skipUnlessDBFeature('supports_update_conflicts')
@skipIfDBFeature('supports_update_conflicts_with_target')
def test_update_conflicts_no_unique_fields(self):
self._test_update_conflicts([])