diff --git a/django/db/backends/base/features.py b/django/db/backends/base/features.py index 8654293146..20d5c8f772 100644 --- a/django/db/backends/base/features.py +++ b/django/db/backends/base/features.py @@ -271,6 +271,10 @@ class BaseDatabaseFeatures: # Does the backend support ignoring constraint or uniqueness errors during # INSERT? supports_ignore_conflicts = True + # Does the backend support updating rows on constraint or uniqueness errors + # during INSERT? + supports_update_conflicts = False + supports_update_conflicts_with_target = False # Does this backend require casting the results of CASE expressions used # in UPDATE statements to ensure the expression has the correct type? diff --git a/django/db/backends/base/operations.py b/django/db/backends/base/operations.py index 66bb009175..7422137304 100644 --- a/django/db/backends/base/operations.py +++ b/django/db/backends/base/operations.py @@ -717,8 +717,8 @@ class BaseDatabaseOperations: raise ValueError('Unknown options: %s' % ', '.join(sorted(options.keys()))) return self.explain_prefix - def insert_statement(self, ignore_conflicts=False): + def insert_statement(self, on_conflict=None): return 'INSERT INTO' - def ignore_conflicts_suffix_sql(self, ignore_conflicts=None): + def on_conflict_suffix_sql(self, fields, on_conflict, update_fields, unique_fields): return '' diff --git a/django/db/backends/mysql/features.py b/django/db/backends/mysql/features.py index bbb8f6d79b..5d6c4afde0 100644 --- a/django/db/backends/mysql/features.py +++ b/django/db/backends/mysql/features.py @@ -24,6 +24,7 @@ class DatabaseFeatures(BaseDatabaseFeatures): supports_select_difference = False supports_slicing_ordering_in_compound = True supports_index_on_text_field = False + supports_update_conflicts = True create_test_procedure_without_params_sql = """ CREATE PROCEDURE test_procedure () BEGIN diff --git a/django/db/backends/mysql/operations.py b/django/db/backends/mysql/operations.py index c878664a5c..923e50a8d4 100644 --- a/django/db/backends/mysql/operations.py +++ b/django/db/backends/mysql/operations.py @@ -4,6 +4,7 @@ from django.conf import settings from django.db.backends.base.operations import BaseDatabaseOperations from django.db.backends.utils import split_tzname_delta from django.db.models import Exists, ExpressionWrapper, Lookup +from django.db.models.constants import OnConflict from django.utils import timezone from django.utils.encoding import force_str @@ -365,8 +366,10 @@ class DatabaseOperations(BaseDatabaseOperations): match_option = 'c' if lookup_type == 'regex' else 'i' return "REGEXP_LIKE(%%s, %%s, '%s')" % match_option - def insert_statement(self, ignore_conflicts=False): - return 'INSERT IGNORE INTO' if ignore_conflicts else super().insert_statement(ignore_conflicts) + def insert_statement(self, on_conflict=None): + if on_conflict == OnConflict.IGNORE: + return 'INSERT IGNORE INTO' + return super().insert_statement(on_conflict=on_conflict) def lookup_cast(self, lookup_type, internal_type=None): lookup = '%s' @@ -388,3 +391,27 @@ class DatabaseOperations(BaseDatabaseOperations): if getattr(expression, 'conditional', False): return False return super().conditional_expression_supported_in_where_clause(expression) + + def on_conflict_suffix_sql(self, fields, on_conflict, update_fields, unique_fields): + if on_conflict == OnConflict.UPDATE: + conflict_suffix_sql = 'ON DUPLICATE KEY UPDATE %(fields)s' + field_sql = '%(field)s = VALUES(%(field)s)' + # The use of VALUES() is deprecated in MySQL 8.0.20+. Instead, use + # aliases for the new row and its columns available in MySQL + # 8.0.19+. + if not self.connection.mysql_is_mariadb: + if self.connection.mysql_version >= (8, 0, 19): + conflict_suffix_sql = f'AS new {conflict_suffix_sql}' + field_sql = '%(field)s = new.%(field)s' + # VALUES() was renamed to VALUE() in MariaDB 10.3.3+. + elif self.connection.mysql_version >= (10, 3, 3): + field_sql = '%(field)s = VALUE(%(field)s)' + + fields = ', '.join([ + field_sql % {'field': field} + for field in map(self.quote_name, update_fields) + ]) + return conflict_suffix_sql % {'fields': fields} + return super().on_conflict_suffix_sql( + fields, on_conflict, update_fields, unique_fields, + ) diff --git a/django/db/backends/postgresql/features.py b/django/db/backends/postgresql/features.py index 097d41a45b..1ce73fb0a8 100644 --- a/django/db/backends/postgresql/features.py +++ b/django/db/backends/postgresql/features.py @@ -57,6 +57,8 @@ class DatabaseFeatures(BaseDatabaseFeatures): supports_deferrable_unique_constraints = True has_json_operators = True json_key_contains_list_matching_requires_list = True + supports_update_conflicts = True + supports_update_conflicts_with_target = True test_collations = { 'non_default': 'sv-x-icu', 'swedish_ci': 'sv-x-icu', diff --git a/django/db/backends/postgresql/operations.py b/django/db/backends/postgresql/operations.py index 399c1b24e7..762cd8d23e 100644 --- a/django/db/backends/postgresql/operations.py +++ b/django/db/backends/postgresql/operations.py @@ -3,6 +3,7 @@ from psycopg2.extras import Inet from django.conf import settings from django.db.backends.base.operations import BaseDatabaseOperations from django.db.backends.utils import split_tzname_delta +from django.db.models.constants import OnConflict class DatabaseOperations(BaseDatabaseOperations): @@ -272,5 +273,17 @@ class DatabaseOperations(BaseDatabaseOperations): prefix += ' (%s)' % ', '.join('%s %s' % i for i in extra.items()) return prefix - def ignore_conflicts_suffix_sql(self, ignore_conflicts=None): - return 'ON CONFLICT DO NOTHING' if ignore_conflicts else super().ignore_conflicts_suffix_sql(ignore_conflicts) + def on_conflict_suffix_sql(self, fields, on_conflict, update_fields, unique_fields): + if on_conflict == OnConflict.IGNORE: + return 'ON CONFLICT DO NOTHING' + if on_conflict == OnConflict.UPDATE: + return 'ON CONFLICT(%s) DO UPDATE SET %s' % ( + ', '.join(map(self.quote_name, unique_fields)), + ', '.join([ + f'{field} = EXCLUDED.{field}' + for field in map(self.quote_name, update_fields) + ]), + ) + return super().on_conflict_suffix_sql( + fields, on_conflict, update_fields, unique_fields, + ) diff --git a/django/db/backends/sqlite3/features.py b/django/db/backends/sqlite3/features.py index ad35574463..153ce8d1d1 100644 --- a/django/db/backends/sqlite3/features.py +++ b/django/db/backends/sqlite3/features.py @@ -40,6 +40,8 @@ class DatabaseFeatures(BaseDatabaseFeatures): supports_order_by_nulls_modifier = Database.sqlite_version_info >= (3, 30, 0) order_by_nulls_first = True supports_json_field_contains = False + supports_update_conflicts = Database.sqlite_version_info >= (3, 24, 0) + supports_update_conflicts_with_target = supports_update_conflicts test_collations = { 'ci': 'nocase', 'cs': 'binary', diff --git a/django/db/backends/sqlite3/operations.py b/django/db/backends/sqlite3/operations.py index 90a4241803..34a7251eea 100644 --- a/django/db/backends/sqlite3/operations.py +++ b/django/db/backends/sqlite3/operations.py @@ -8,6 +8,7 @@ from django.conf import settings from django.core.exceptions import FieldError from django.db import DatabaseError, NotSupportedError, models from django.db.backends.base.operations import BaseDatabaseOperations +from django.db.models.constants import OnConflict from django.db.models.expressions import Col from django.utils import timezone from django.utils.dateparse import parse_date, parse_datetime, parse_time @@ -370,8 +371,10 @@ class DatabaseOperations(BaseDatabaseOperations): return 'django_time_diff(%s, %s)' % (lhs_sql, rhs_sql), params return 'django_timestamp_diff(%s, %s)' % (lhs_sql, rhs_sql), params - def insert_statement(self, ignore_conflicts=False): - return 'INSERT OR IGNORE INTO' if ignore_conflicts else super().insert_statement(ignore_conflicts) + def insert_statement(self, on_conflict=None): + if on_conflict == OnConflict.IGNORE: + return 'INSERT OR IGNORE INTO' + return super().insert_statement(on_conflict=on_conflict) def return_insert_columns(self, fields): # SQLite < 3.35 doesn't support an INSERT...RETURNING statement. @@ -384,3 +387,19 @@ class DatabaseOperations(BaseDatabaseOperations): ) for field in fields ] return 'RETURNING %s' % ', '.join(columns), () + + def on_conflict_suffix_sql(self, fields, on_conflict, update_fields, unique_fields): + if ( + on_conflict == OnConflict.UPDATE and + self.connection.features.supports_update_conflicts_with_target + ): + return 'ON CONFLICT(%s) DO UPDATE SET %s' % ( + ', '.join(map(self.quote_name, unique_fields)), + ', '.join([ + f'{field} = EXCLUDED.{field}' + for field in map(self.quote_name, update_fields) + ]), + ) + return super().on_conflict_suffix_sql( + fields, on_conflict, update_fields, unique_fields, + ) diff --git a/django/db/models/constants.py b/django/db/models/constants.py index a7e6c252d9..95addd2ab0 100644 --- a/django/db/models/constants.py +++ b/django/db/models/constants.py @@ -1,6 +1,12 @@ """ Constants used across the ORM in general. """ +from enum import Enum # Separator used to split filter strings apart. LOOKUP_SEP = '__' + + +class OnConflict(Enum): + IGNORE = 'ignore' + UPDATE = 'update' diff --git a/django/db/models/query.py b/django/db/models/query.py index 86b1631f67..1874416928 100644 --- a/django/db/models/query.py +++ b/django/db/models/query.py @@ -15,7 +15,7 @@ from django.db import ( router, transaction, ) from django.db.models import AutoField, DateField, DateTimeField, sql -from django.db.models.constants import LOOKUP_SEP +from django.db.models.constants import LOOKUP_SEP, OnConflict from django.db.models.deletion import Collector from django.db.models.expressions import Case, Expression, F, Ref, Value, When from django.db.models.functions import Cast, Trunc @@ -466,7 +466,69 @@ class QuerySet: obj.pk = obj._meta.pk.get_pk_value_on_save(obj) obj._prepare_related_fields_for_save(operation_name='bulk_create') - def bulk_create(self, objs, batch_size=None, ignore_conflicts=False): + def _check_bulk_create_options(self, ignore_conflicts, update_conflicts, update_fields, unique_fields): + if ignore_conflicts and update_conflicts: + raise ValueError( + 'ignore_conflicts and update_conflicts are mutually exclusive.' + ) + db_features = connections[self.db].features + if ignore_conflicts: + if not db_features.supports_ignore_conflicts: + raise NotSupportedError( + 'This database backend does not support ignoring conflicts.' + ) + return OnConflict.IGNORE + elif update_conflicts: + if not db_features.supports_update_conflicts: + raise NotSupportedError( + 'This database backend does not support updating conflicts.' + ) + if not update_fields: + raise ValueError( + 'Fields that will be updated when a row insertion fails ' + 'on conflicts must be provided.' + ) + if unique_fields and not db_features.supports_update_conflicts_with_target: + raise NotSupportedError( + 'This database backend does not support updating ' + 'conflicts with specifying unique fields that can trigger ' + 'the upsert.' + ) + if not unique_fields and db_features.supports_update_conflicts_with_target: + raise ValueError( + 'Unique fields that can trigger the upsert must be ' + 'provided.' + ) + # Updating primary keys and non-concrete fields is forbidden. + update_fields = [self.model._meta.get_field(name) for name in update_fields] + if any(not f.concrete or f.many_to_many for f in update_fields): + raise ValueError( + 'bulk_create() can only be used with concrete fields in ' + 'update_fields.' + ) + if any(f.primary_key for f in update_fields): + raise ValueError( + 'bulk_create() cannot be used with primary keys in ' + 'update_fields.' + ) + if unique_fields: + # Primary key is allowed in unique_fields. + unique_fields = [ + self.model._meta.get_field(name) + for name in unique_fields if name != 'pk' + ] + if any(not f.concrete or f.many_to_many for f in unique_fields): + raise ValueError( + 'bulk_create() can only be used with concrete fields ' + 'in unique_fields.' + ) + return OnConflict.UPDATE + return None + + def bulk_create( + self, objs, batch_size=None, ignore_conflicts=False, + update_conflicts=False, update_fields=None, unique_fields=None, + ): """ Insert each of the instances into the database. Do *not* call save() on each of the instances, do not send any pre/post_save @@ -497,6 +559,12 @@ class QuerySet: raise ValueError("Can't bulk create a multi-table inherited model") if not objs: return objs + on_conflict = self._check_bulk_create_options( + ignore_conflicts, + update_conflicts, + update_fields, + unique_fields, + ) self._for_write = True opts = self.model._meta fields = opts.concrete_fields @@ -506,7 +574,12 @@ class QuerySet: objs_with_pk, objs_without_pk = partition(lambda o: o.pk is None, objs) if objs_with_pk: returned_columns = self._batched_insert( - objs_with_pk, fields, batch_size, ignore_conflicts=ignore_conflicts, + objs_with_pk, + fields, + batch_size, + on_conflict=on_conflict, + update_fields=update_fields, + unique_fields=unique_fields, ) for obj_with_pk, results in zip(objs_with_pk, returned_columns): for result, field in zip(results, opts.db_returning_fields): @@ -518,10 +591,15 @@ class QuerySet: if objs_without_pk: fields = [f for f in fields if not isinstance(f, AutoField)] returned_columns = self._batched_insert( - objs_without_pk, fields, batch_size, ignore_conflicts=ignore_conflicts, + objs_without_pk, + fields, + batch_size, + on_conflict=on_conflict, + update_fields=update_fields, + unique_fields=unique_fields, ) connection = connections[self.db] - if connection.features.can_return_rows_from_bulk_insert and not ignore_conflicts: + if connection.features.can_return_rows_from_bulk_insert and on_conflict is None: assert len(returned_columns) == len(objs_without_pk) for obj_without_pk, results in zip(objs_without_pk, returned_columns): for result, field in zip(results, opts.db_returning_fields): @@ -1293,7 +1371,10 @@ class QuerySet: # PRIVATE METHODS # ################### - def _insert(self, objs, fields, returning_fields=None, raw=False, using=None, ignore_conflicts=False): + def _insert( + self, objs, fields, returning_fields=None, raw=False, using=None, + on_conflict=None, update_fields=None, unique_fields=None, + ): """ Insert a new record for the given model. This provides an interface to the InsertQuery class and is how Model.save() is implemented. @@ -1301,33 +1382,45 @@ class QuerySet: self._for_write = True if using is None: using = self.db - query = sql.InsertQuery(self.model, ignore_conflicts=ignore_conflicts) + query = sql.InsertQuery( + self.model, + on_conflict=on_conflict, + update_fields=update_fields, + unique_fields=unique_fields, + ) query.insert_values(fields, objs, raw=raw) return query.get_compiler(using=using).execute_sql(returning_fields) _insert.alters_data = True _insert.queryset_only = False - def _batched_insert(self, objs, fields, batch_size, ignore_conflicts=False): + def _batched_insert( + self, objs, fields, batch_size, on_conflict=None, update_fields=None, + unique_fields=None, + ): """ Helper method for bulk_create() to insert objs one batch at a time. """ connection = connections[self.db] - if ignore_conflicts and not connection.features.supports_ignore_conflicts: - raise NotSupportedError('This database backend does not support ignoring conflicts.') ops = connection.ops max_batch_size = max(ops.bulk_batch_size(fields, objs), 1) batch_size = min(batch_size, max_batch_size) if batch_size else max_batch_size inserted_rows = [] bulk_return = connection.features.can_return_rows_from_bulk_insert for item in [objs[i:i + batch_size] for i in range(0, len(objs), batch_size)]: - if bulk_return and not ignore_conflicts: + if bulk_return and on_conflict is None: inserted_rows.extend(self._insert( item, fields=fields, using=self.db, returning_fields=self.model._meta.db_returning_fields, - ignore_conflicts=ignore_conflicts, )) else: - self._insert(item, fields=fields, using=self.db, ignore_conflicts=ignore_conflicts) + self._insert( + item, + fields=fields, + using=self.db, + on_conflict=on_conflict, + update_fields=update_fields, + unique_fields=unique_fields, + ) return inserted_rows def _chain(self): diff --git a/django/db/models/sql/compiler.py b/django/db/models/sql/compiler.py index 928ab40254..d405a203ee 100644 --- a/django/db/models/sql/compiler.py +++ b/django/db/models/sql/compiler.py @@ -1387,7 +1387,9 @@ class SQLInsertCompiler(SQLCompiler): # going to be column names (so we can avoid the extra overhead). qn = self.connection.ops.quote_name opts = self.query.get_meta() - insert_statement = self.connection.ops.insert_statement(ignore_conflicts=self.query.ignore_conflicts) + insert_statement = self.connection.ops.insert_statement( + on_conflict=self.query.on_conflict, + ) result = ['%s %s' % (insert_statement, qn(opts.db_table))] fields = self.query.fields or [opts.pk] result.append('(%s)' % ', '.join(qn(f.column) for f in fields)) @@ -1410,8 +1412,11 @@ class SQLInsertCompiler(SQLCompiler): placeholder_rows, param_rows = self.assemble_as_sql(fields, value_rows) - ignore_conflicts_suffix_sql = self.connection.ops.ignore_conflicts_suffix_sql( - ignore_conflicts=self.query.ignore_conflicts + on_conflict_suffix_sql = self.connection.ops.on_conflict_suffix_sql( + fields, + self.query.on_conflict, + self.query.update_fields, + self.query.unique_fields, ) if self.returning_fields and self.connection.features.can_return_columns_from_insert: if self.connection.features.can_return_rows_from_bulk_insert: @@ -1420,8 +1425,8 @@ class SQLInsertCompiler(SQLCompiler): else: result.append("VALUES (%s)" % ", ".join(placeholder_rows[0])) params = [param_rows[0]] - if ignore_conflicts_suffix_sql: - result.append(ignore_conflicts_suffix_sql) + if on_conflict_suffix_sql: + result.append(on_conflict_suffix_sql) # Skip empty r_sql to allow subclasses to customize behavior for # 3rd party backends. Refs #19096. r_sql, self.returning_params = self.connection.ops.return_insert_columns(self.returning_fields) @@ -1432,12 +1437,12 @@ class SQLInsertCompiler(SQLCompiler): if can_bulk: result.append(self.connection.ops.bulk_insert_sql(fields, placeholder_rows)) - if ignore_conflicts_suffix_sql: - result.append(ignore_conflicts_suffix_sql) + if on_conflict_suffix_sql: + result.append(on_conflict_suffix_sql) return [(" ".join(result), tuple(p for ps in param_rows for p in ps))] else: - if ignore_conflicts_suffix_sql: - result.append(ignore_conflicts_suffix_sql) + if on_conflict_suffix_sql: + result.append(on_conflict_suffix_sql) return [ (" ".join(result + ["VALUES (%s)" % ", ".join(p)]), vals) for p, vals in zip(placeholder_rows, param_rows) diff --git a/django/db/models/sql/subqueries.py b/django/db/models/sql/subqueries.py index b1e5d2f5b7..f6a371a925 100644 --- a/django/db/models/sql/subqueries.py +++ b/django/db/models/sql/subqueries.py @@ -138,11 +138,13 @@ class UpdateQuery(Query): class InsertQuery(Query): compiler = 'SQLInsertCompiler' - def __init__(self, *args, ignore_conflicts=False, **kwargs): + def __init__(self, *args, on_conflict=None, update_fields=None, unique_fields=None, **kwargs): super().__init__(*args, **kwargs) self.fields = [] self.objs = [] - self.ignore_conflicts = ignore_conflicts + self.on_conflict = on_conflict + self.update_fields = update_fields or [] + self.unique_fields = unique_fields or [] def insert_values(self, fields, objs, raw=False): self.fields = fields diff --git a/docs/ref/models/querysets.txt b/docs/ref/models/querysets.txt index e0d4311b15..60360ee163 100644 --- a/docs/ref/models/querysets.txt +++ b/docs/ref/models/querysets.txt @@ -2155,7 +2155,7 @@ exists in the database, an :exc:`~django.db.IntegrityError` is raised. ``bulk_create()`` ~~~~~~~~~~~~~~~~~ -.. method:: bulk_create(objs, batch_size=None, ignore_conflicts=False) +.. method:: bulk_create(objs, batch_size=None, ignore_conflicts=False, update_conflicts=False, update_fields=None, unique_fields=None) This method inserts the provided list of objects into the database in an efficient manner (generally only 1 query, no matter how many objects there @@ -2198,9 +2198,17 @@ where the default is such that at most 999 variables per query are used. On databases that support it (all but Oracle), setting the ``ignore_conflicts`` parameter to ``True`` tells the database to ignore failure to insert any rows -that fail constraints such as duplicate unique values. Enabling this parameter -disables setting the primary key on each model instance (if the database -normally supports it). +that fail constraints such as duplicate unique values. + +On databases that support it (all except Oracle and SQLite < 3.24), setting the +``update_conflicts`` parameter to ``True``, tells the database to update +``update_fields`` when a row insertion fails on conflicts. On PostgreSQL and +SQLite, in addition to ``update_fields``, a list of ``unique_fields`` that may +be in conflict must be provided. + +Enabling the ``ignore_conflicts`` or ``update_conflicts`` parameter disable +setting the primary key on each model instance (if the database normally +support it). .. warning:: @@ -2217,6 +2225,12 @@ normally supports it). Support for the fetching primary key attributes on SQLite 3.35+ was added. +.. versionchanged:: 4.1 + + The ``update_conflicts``, ``update_fields``, and ``unique_fields`` + parameters were added to support updating fields when a row insertion fails + on conflict. + ``bulk_update()`` ~~~~~~~~~~~~~~~~~ diff --git a/docs/releases/4.1.txt b/docs/releases/4.1.txt index b6112fe29e..8fc23e6141 100644 --- a/docs/releases/4.1.txt +++ b/docs/releases/4.1.txt @@ -232,6 +232,10 @@ Models in order to reduce the number of failed requests, e.g. after database server restart. +* :meth:`.QuerySet.bulk_create` now supports updating fields when a row + insertion fails uniqueness constraints. This is supported on MariaDB, MySQL, + PostgreSQL, and SQLite 3.24+. + Requests and Responses ~~~~~~~~~~~~~~~~~~~~~~ @@ -298,6 +302,14 @@ backends. * ``DatabaseIntrospection.get_key_columns()`` is removed. Use ``DatabaseIntrospection.get_relations()`` instead. +* ``DatabaseOperations.ignore_conflicts_suffix_sql()`` method is replaced by + ``DatabaseOperations.on_conflict_suffix_sql()`` that accepts the ``fields``, + ``on_conflict``, ``update_fields``, and ``unique_fields`` arguments. + +* The ``ignore_conflicts`` argument of the + ``DatabaseOperations.insert_statement()`` method is replaced by + ``on_conflict`` that accepts ``django.db.models.constants.OnConflict``. + Dropped support for MariaDB 10.2 -------------------------------- diff --git a/tests/bulk_create/models.py b/tests/bulk_create/models.py index 586457b192..f0db69932e 100644 --- a/tests/bulk_create/models.py +++ b/tests/bulk_create/models.py @@ -16,6 +16,14 @@ class Country(models.Model): iso_two_letter = models.CharField(max_length=2) description = models.TextField() + class Meta: + constraints = [ + models.UniqueConstraint( + fields=['iso_two_letter', 'name'], + name='country_name_iso_unique', + ), + ] + class ProxyCountry(Country): class Meta: @@ -58,6 +66,13 @@ class State(models.Model): class TwoFields(models.Model): f1 = models.IntegerField(unique=True) f2 = models.IntegerField(unique=True) + name = models.CharField(max_length=15, null=True) + + +class UpsertConflict(models.Model): + number = models.IntegerField(unique=True) + rank = models.IntegerField() + name = models.CharField(max_length=15) class NoFields(models.Model): @@ -103,3 +118,9 @@ class NullableFields(models.Model): text_field = models.TextField(null=True, default='text') url_field = models.URLField(null=True, default='/') uuid_field = models.UUIDField(null=True, default=uuid.uuid4) + + +class RelatedModel(models.Model): + name = models.CharField(max_length=15, null=True) + country = models.OneToOneField(Country, models.CASCADE, primary_key=True) + big_auto_fields = models.ManyToManyField(BigAutoFieldModel) diff --git a/tests/bulk_create/tests.py b/tests/bulk_create/tests.py index 2ee54c382f..7e5ff32380 100644 --- a/tests/bulk_create/tests.py +++ b/tests/bulk_create/tests.py @@ -1,7 +1,11 @@ from math import ceil from operator import attrgetter -from django.db import IntegrityError, NotSupportedError, connection +from django.core.exceptions import FieldDoesNotExist +from django.db import ( + IntegrityError, NotSupportedError, OperationalError, ProgrammingError, + connection, +) from django.db.models import FileField, Value from django.db.models.functions import Lower from django.test import ( @@ -11,7 +15,8 @@ from django.test import ( from .models import ( BigAutoFieldModel, Country, NoFields, NullableFields, Pizzeria, ProxyCountry, ProxyMultiCountry, ProxyMultiProxyCountry, ProxyProxyCountry, - Restaurant, SmallAutoFieldModel, State, TwoFields, + RelatedModel, Restaurant, SmallAutoFieldModel, State, TwoFields, + UpsertConflict, ) @@ -53,10 +58,10 @@ class BulkCreateTests(TestCase): @skipUnlessDBFeature('has_bulk_insert') def test_long_and_short_text(self): Country.objects.bulk_create([ - Country(description='a' * 4001), - Country(description='a'), - Country(description='Ж' * 2001), - Country(description='Ж'), + Country(description='a' * 4001, iso_two_letter='A'), + Country(description='a', iso_two_letter='B'), + Country(description='Ж' * 2001, iso_two_letter='C'), + Country(description='Ж', iso_two_letter='D'), ]) self.assertEqual(Country.objects.count(), 4) @@ -218,7 +223,7 @@ class BulkCreateTests(TestCase): @skipUnlessDBFeature('has_bulk_insert') def test_explicit_batch_size_respects_max_batch_size(self): - objs = [Country() for i in range(1000)] + objs = [Country(name=f'Country {i}') for i in range(1000)] fields = ['name', 'iso_two_letter', 'description'] max_batch_size = max(connection.ops.bulk_batch_size(fields, objs), 1) with self.assertNumQueries(ceil(len(objs) / max_batch_size)): @@ -352,3 +357,276 @@ class BulkCreateTests(TestCase): msg = 'Batch size must be a positive integer.' with self.assertRaisesMessage(ValueError, msg): Country.objects.bulk_create([], batch_size=-1) + + @skipIfDBFeature('supports_update_conflicts') + def test_update_conflicts_unsupported(self): + msg = 'This database backend does not support updating conflicts.' + with self.assertRaisesMessage(NotSupportedError, msg): + Country.objects.bulk_create(self.data, update_conflicts=True) + + @skipUnlessDBFeature('supports_ignore_conflicts', 'supports_update_conflicts') + def test_ignore_update_conflicts_exclusive(self): + msg = 'ignore_conflicts and update_conflicts are mutually exclusive' + with self.assertRaisesMessage(ValueError, msg): + Country.objects.bulk_create( + self.data, + ignore_conflicts=True, + update_conflicts=True, + ) + + @skipUnlessDBFeature('supports_update_conflicts') + def test_update_conflicts_no_update_fields(self): + msg = ( + 'Fields that will be updated when a row insertion fails on ' + 'conflicts must be provided.' + ) + with self.assertRaisesMessage(ValueError, msg): + Country.objects.bulk_create(self.data, update_conflicts=True) + + @skipUnlessDBFeature('supports_update_conflicts') + @skipIfDBFeature('supports_update_conflicts_with_target') + def test_update_conflicts_unique_field_unsupported(self): + msg = ( + 'This database backend does not support updating conflicts with ' + 'specifying unique fields that can trigger the upsert.' + ) + with self.assertRaisesMessage(NotSupportedError, msg): + TwoFields.objects.bulk_create( + [TwoFields(f1=1, f2=1), TwoFields(f1=2, f2=2)], + update_conflicts=True, + update_fields=['f2'], + unique_fields=['f1'], + ) + + @skipUnlessDBFeature('supports_update_conflicts') + def test_update_conflicts_nonexistent_update_fields(self): + unique_fields = None + if connection.features.supports_update_conflicts_with_target: + unique_fields = ['f1'] + msg = "TwoFields has no field named 'nonexistent'" + with self.assertRaisesMessage(FieldDoesNotExist, msg): + TwoFields.objects.bulk_create( + [TwoFields(f1=1, f2=1), TwoFields(f1=2, f2=2)], + update_conflicts=True, + update_fields=['nonexistent'], + unique_fields=unique_fields, + ) + + @skipUnlessDBFeature( + 'supports_update_conflicts', 'supports_update_conflicts_with_target', + ) + def test_update_conflicts_unique_fields_required(self): + msg = 'Unique fields that can trigger the upsert must be provided.' + with self.assertRaisesMessage(ValueError, msg): + TwoFields.objects.bulk_create( + [TwoFields(f1=1, f2=1), TwoFields(f1=2, f2=2)], + update_conflicts=True, + update_fields=['f1'], + ) + + @skipUnlessDBFeature( + 'supports_update_conflicts', 'supports_update_conflicts_with_target', + ) + def test_update_conflicts_invalid_update_fields(self): + msg = ( + 'bulk_create() can only be used with concrete fields in ' + 'update_fields.' + ) + # Reverse one-to-one relationship. + with self.assertRaisesMessage(ValueError, msg): + Country.objects.bulk_create( + self.data, + update_conflicts=True, + update_fields=['relatedmodel'], + unique_fields=['pk'], + ) + # Many-to-many relationship. + with self.assertRaisesMessage(ValueError, msg): + RelatedModel.objects.bulk_create( + [RelatedModel(country=self.data[0])], + update_conflicts=True, + update_fields=['big_auto_fields'], + unique_fields=['country'], + ) + + @skipUnlessDBFeature( + 'supports_update_conflicts', 'supports_update_conflicts_with_target', + ) + def test_update_conflicts_pk_in_update_fields(self): + msg = 'bulk_create() cannot be used with primary keys in update_fields.' + with self.assertRaisesMessage(ValueError, msg): + BigAutoFieldModel.objects.bulk_create( + [BigAutoFieldModel()], + update_conflicts=True, + update_fields=['id'], + unique_fields=['id'], + ) + + @skipUnlessDBFeature( + 'supports_update_conflicts', 'supports_update_conflicts_with_target', + ) + def test_update_conflicts_invalid_unique_fields(self): + msg = ( + 'bulk_create() can only be used with concrete fields in ' + 'unique_fields.' + ) + # Reverse one-to-one relationship. + with self.assertRaisesMessage(ValueError, msg): + Country.objects.bulk_create( + self.data, + update_conflicts=True, + update_fields=['name'], + unique_fields=['relatedmodel'], + ) + # Many-to-many relationship. + with self.assertRaisesMessage(ValueError, msg): + RelatedModel.objects.bulk_create( + [RelatedModel(country=self.data[0])], + update_conflicts=True, + update_fields=['name'], + unique_fields=['big_auto_fields'], + ) + + def _test_update_conflicts_two_fields(self, unique_fields): + TwoFields.objects.bulk_create([ + TwoFields(f1=1, f2=1, name='a'), + TwoFields(f1=2, f2=2, name='b'), + ]) + self.assertEqual(TwoFields.objects.count(), 2) + + conflicting_objects = [ + TwoFields(f1=1, f2=1, name='c'), + TwoFields(f1=2, f2=2, name='d'), + ] + TwoFields.objects.bulk_create( + conflicting_objects, + update_conflicts=True, + unique_fields=unique_fields, + update_fields=['name'], + ) + self.assertEqual(TwoFields.objects.count(), 2) + self.assertCountEqual(TwoFields.objects.values('f1', 'f2', 'name'), [ + {'f1': 1, 'f2': 1, 'name': 'c'}, + {'f1': 2, 'f2': 2, 'name': 'd'}, + ]) + + @skipUnlessDBFeature('supports_update_conflicts', 'supports_update_conflicts_with_target') + def test_update_conflicts_two_fields_unique_fields_first(self): + self._test_update_conflicts_two_fields(['f1']) + + @skipUnlessDBFeature('supports_update_conflicts', 'supports_update_conflicts_with_target') + def test_update_conflicts_two_fields_unique_fields_second(self): + self._test_update_conflicts_two_fields(['f2']) + + @skipUnlessDBFeature('supports_update_conflicts', 'supports_update_conflicts_with_target') + def test_update_conflicts_two_fields_unique_fields_both(self): + with self.assertRaises((OperationalError, ProgrammingError)): + self._test_update_conflicts_two_fields(['f1', 'f2']) + + @skipUnlessDBFeature('supports_update_conflicts') + @skipIfDBFeature('supports_update_conflicts_with_target') + def test_update_conflicts_two_fields_no_unique_fields(self): + self._test_update_conflicts_two_fields([]) + + def _test_update_conflicts_unique_two_fields(self, unique_fields): + Country.objects.bulk_create(self.data) + self.assertEqual(Country.objects.count(), 4) + + new_data = [ + # Conflicting countries. + Country(name='Germany', iso_two_letter='DE', description=( + 'Germany is a country in Central Europe.' + )), + Country(name='Czech Republic', iso_two_letter='CZ', description=( + 'The Czech Republic is a landlocked country in Central Europe.' + )), + # New countries. + Country(name='Australia', iso_two_letter='AU'), + Country(name='Japan', iso_two_letter='JP', description=( + 'Japan is an island country in East Asia.' + )), + ] + Country.objects.bulk_create( + new_data, + update_conflicts=True, + update_fields=['description'], + unique_fields=unique_fields, + ) + self.assertEqual(Country.objects.count(), 6) + self.assertCountEqual(Country.objects.values('iso_two_letter', 'description'), [ + {'iso_two_letter': 'US', 'description': ''}, + {'iso_two_letter': 'NL', 'description': ''}, + {'iso_two_letter': 'DE', 'description': ( + 'Germany is a country in Central Europe.' + )}, + {'iso_two_letter': 'CZ', 'description': ( + 'The Czech Republic is a landlocked country in Central Europe.' + )}, + {'iso_two_letter': 'AU', 'description': ''}, + {'iso_two_letter': 'JP', 'description': ( + 'Japan is an island country in East Asia.' + )}, + ]) + + @skipUnlessDBFeature('supports_update_conflicts', 'supports_update_conflicts_with_target') + def test_update_conflicts_unique_two_fields_unique_fields_both(self): + self._test_update_conflicts_unique_two_fields(['iso_two_letter', 'name']) + + @skipUnlessDBFeature('supports_update_conflicts', 'supports_update_conflicts_with_target') + def test_update_conflicts_unique_two_fields_unique_fields_one(self): + with self.assertRaises((OperationalError, ProgrammingError)): + self._test_update_conflicts_unique_two_fields(['iso_two_letter']) + + @skipUnlessDBFeature('supports_update_conflicts') + @skipIfDBFeature('supports_update_conflicts_with_target') + def test_update_conflicts_unique_two_fields_unique_no_unique_fields(self): + self._test_update_conflicts_unique_two_fields([]) + + def _test_update_conflicts(self, unique_fields): + UpsertConflict.objects.bulk_create([ + UpsertConflict(number=1, rank=1, name='John'), + UpsertConflict(number=2, rank=2, name='Mary'), + UpsertConflict(number=3, rank=3, name='Hannah'), + ]) + self.assertEqual(UpsertConflict.objects.count(), 3) + + conflicting_objects = [ + UpsertConflict(number=1, rank=4, name='Steve'), + UpsertConflict(number=2, rank=2, name='Olivia'), + UpsertConflict(number=3, rank=1, name='Hannah'), + ] + UpsertConflict.objects.bulk_create( + conflicting_objects, + update_conflicts=True, + update_fields=['name', 'rank'], + unique_fields=unique_fields, + ) + self.assertEqual(UpsertConflict.objects.count(), 3) + self.assertCountEqual(UpsertConflict.objects.values('number', 'rank', 'name'), [ + {'number': 1, 'rank': 4, 'name': 'Steve'}, + {'number': 2, 'rank': 2, 'name': 'Olivia'}, + {'number': 3, 'rank': 1, 'name': 'Hannah'}, + ]) + + UpsertConflict.objects.bulk_create( + conflicting_objects + [UpsertConflict(number=4, rank=4, name='Mark')], + update_conflicts=True, + update_fields=['name', 'rank'], + unique_fields=unique_fields, + ) + self.assertEqual(UpsertConflict.objects.count(), 4) + self.assertCountEqual(UpsertConflict.objects.values('number', 'rank', 'name'), [ + {'number': 1, 'rank': 4, 'name': 'Steve'}, + {'number': 2, 'rank': 2, 'name': 'Olivia'}, + {'number': 3, 'rank': 1, 'name': 'Hannah'}, + {'number': 4, 'rank': 4, 'name': 'Mark'}, + ]) + + @skipUnlessDBFeature('supports_update_conflicts', 'supports_update_conflicts_with_target') + def test_update_conflicts_unique_fields(self): + self._test_update_conflicts(unique_fields=['number']) + + @skipUnlessDBFeature('supports_update_conflicts') + @skipIfDBFeature('supports_update_conflicts_with_target') + def test_update_conflicts_no_unique_fields(self): + self._test_update_conflicts([])