mirror of https://github.com/django/django.git
Fixed #28668 -- Allowed QuerySet.bulk_create() to ignore insert conflicts.
This commit is contained in:
parent
45086c294d
commit
f1fbef6cd1
|
@ -261,6 +261,10 @@ class BaseDatabaseFeatures:
|
|||
# Does the backend support the default parameter in lead() and lag()?
|
||||
supports_default_in_lead_lag = True
|
||||
|
||||
# Does the backend support ignoring constraint or uniqueness errors during
|
||||
# INSERT?
|
||||
supports_ignore_conflicts = True
|
||||
|
||||
def __init__(self, connection):
|
||||
self.connection = connection
|
||||
|
||||
|
|
|
@ -670,3 +670,9 @@ class BaseDatabaseOperations:
|
|||
if options:
|
||||
raise ValueError('Unknown options: %s' % ', '.join(sorted(options.keys())))
|
||||
return self.explain_prefix
|
||||
|
||||
def insert_statement(self, ignore_conflicts=False):
    """
    Return the SQL that opens an INSERT statement, up to but not including
    the table name.

    This base implementation returns the plain form whatever the value of
    ignore_conflicts.
    """
    return 'INSERT INTO'
|
||||
|
||||
def ignore_conflicts_suffix_sql(self, ignore_conflicts=None):
    """
    Return the SQL fragment to append to an INSERT statement so that
    conflicting rows are ignored.

    This base implementation returns an empty string (no suffix) whatever
    the value of ignore_conflicts.
    """
    return ''
|
||||
|
|
|
@ -308,3 +308,6 @@ class DatabaseOperations(BaseDatabaseOperations):
|
|||
|
||||
match_option = 'c' if lookup_type == 'regex' else 'i'
|
||||
return "REGEXP_LIKE(%%s, %%s, '%s')" % match_option
|
||||
|
||||
def insert_statement(self, ignore_conflicts=False):
    """
    Return the opening keywords of an INSERT statement.

    Use the ``INSERT IGNORE INTO`` form when ignore_conflicts is truthy;
    otherwise defer to the parent implementation.
    """
    if not ignore_conflicts:
        return super().insert_statement(ignore_conflicts)
    return 'INSERT IGNORE INTO'
|
||||
|
|
|
@ -53,4 +53,5 @@ class DatabaseFeatures(BaseDatabaseFeatures):
|
|||
"""
|
||||
supports_callproc_kwargs = True
|
||||
supports_over_clause = True
|
||||
supports_ignore_conflicts = False
|
||||
max_query_params = 2**16 - 1
|
||||
|
|
|
@ -66,3 +66,4 @@ class DatabaseFeatures(BaseDatabaseFeatures):
|
|||
has_jsonb_agg = is_postgresql_9_5
|
||||
has_brin_autosummarize = is_postgresql_10
|
||||
has_gin_pending_list_limit = is_postgresql_9_5
|
||||
supports_ignore_conflicts = is_postgresql_9_5
|
||||
|
|
|
@ -277,3 +277,6 @@ class DatabaseOperations(BaseDatabaseOperations):
|
|||
if extra:
|
||||
prefix += ' (%s)' % ', '.join('%s %s' % i for i in extra.items())
|
||||
return prefix
|
||||
|
||||
def ignore_conflicts_suffix_sql(self, ignore_conflicts=None):
    """
    Return the SQL fragment appended to an INSERT statement.

    Use ``ON CONFLICT DO NOTHING`` when ignore_conflicts is truthy;
    otherwise defer to the parent implementation.
    """
    if not ignore_conflicts:
        return super().ignore_conflicts_suffix_sql(ignore_conflicts)
    return 'ON CONFLICT DO NOTHING'
|
||||
|
|
|
@ -297,3 +297,6 @@ class DatabaseOperations(BaseDatabaseOperations):
|
|||
if internal_type == 'TimeField':
|
||||
return "django_time_diff(%s, %s)" % (lhs_sql, rhs_sql), lhs_params + rhs_params
|
||||
return "django_timestamp_diff(%s, %s)" % (lhs_sql, rhs_sql), lhs_params + rhs_params
|
||||
|
||||
def insert_statement(self, ignore_conflicts=False):
    """
    Return the opening keywords of an INSERT statement.

    Use the ``INSERT OR IGNORE INTO`` form when ignore_conflicts is truthy;
    otherwise defer to the parent implementation.
    """
    if not ignore_conflicts:
        return super().insert_statement(ignore_conflicts)
    return 'INSERT OR IGNORE INTO'
|
||||
|
|
|
@ -23,6 +23,7 @@ from django.db.models.fields import AutoField
|
|||
from django.db.models.functions import Trunc
|
||||
from django.db.models.query_utils import FilteredRelation, InvalidQuery, Q
|
||||
from django.db.models.sql.constants import CURSOR, GET_ITERATOR_CHUNK_SIZE
|
||||
from django.db.utils import NotSupportedError
|
||||
from django.utils import timezone
|
||||
from django.utils.deprecation import RemovedInDjango30Warning
|
||||
from django.utils.functional import cached_property, partition
|
||||
|
@ -418,7 +419,7 @@ class QuerySet:
|
|||
if obj.pk is None:
|
||||
obj.pk = obj._meta.pk.get_pk_value_on_save(obj)
|
||||
|
||||
def bulk_create(self, objs, batch_size=None):
|
||||
def bulk_create(self, objs, batch_size=None, ignore_conflicts=False):
|
||||
"""
|
||||
Insert each of the instances into the database. Do *not* call
|
||||
save() on each of the instances, do not send any pre/post_save
|
||||
|
@ -456,14 +457,14 @@ class QuerySet:
|
|||
with transaction.atomic(using=self.db, savepoint=False):
|
||||
objs_with_pk, objs_without_pk = partition(lambda o: o.pk is None, objs)
|
||||
if objs_with_pk:
|
||||
self._batched_insert(objs_with_pk, fields, batch_size)
|
||||
self._batched_insert(objs_with_pk, fields, batch_size, ignore_conflicts=ignore_conflicts)
|
||||
for obj_with_pk in objs_with_pk:
|
||||
obj_with_pk._state.adding = False
|
||||
obj_with_pk._state.db = self.db
|
||||
if objs_without_pk:
|
||||
fields = [f for f in fields if not isinstance(f, AutoField)]
|
||||
ids = self._batched_insert(objs_without_pk, fields, batch_size)
|
||||
if connection.features.can_return_ids_from_bulk_insert:
|
||||
ids = self._batched_insert(objs_without_pk, fields, batch_size, ignore_conflicts=ignore_conflicts)
|
||||
if connection.features.can_return_ids_from_bulk_insert and not ignore_conflicts:
|
||||
assert len(ids) == len(objs_without_pk)
|
||||
for obj_without_pk, pk in zip(objs_without_pk, ids):
|
||||
obj_without_pk.pk = pk
|
||||
|
@ -1120,7 +1121,7 @@ class QuerySet:
|
|||
# PRIVATE METHODS #
|
||||
###################
|
||||
|
||||
def _insert(self, objs, fields, return_id=False, raw=False, using=None):
|
||||
def _insert(self, objs, fields, return_id=False, raw=False, using=None, ignore_conflicts=False):
|
||||
"""
|
||||
Insert a new record for the given model. This provides an interface to
|
||||
the InsertQuery class and is how Model.save() is implemented.
|
||||
|
@ -1128,28 +1129,34 @@ class QuerySet:
|
|||
self._for_write = True
|
||||
if using is None:
|
||||
using = self.db
|
||||
query = sql.InsertQuery(self.model)
|
||||
query = sql.InsertQuery(self.model, ignore_conflicts=ignore_conflicts)
|
||||
query.insert_values(fields, objs, raw=raw)
|
||||
return query.get_compiler(using=using).execute_sql(return_id)
|
||||
_insert.alters_data = True
|
||||
_insert.queryset_only = False
|
||||
|
||||
# Insert objs in slices of at most batch_size objects by calling
# self._insert() once per slice, and return the list of collected ids.
# NOTE(review): this listing appears to interleave the pre-change and
# post-change versions of several lines (two def headers, two `if` headers,
# paired _insert() calls) with the diff markers stripped -- confirm which
# line of each pair is the live one before relying on it.
def _batched_insert(self, objs, fields, batch_size):
|
||||
def _batched_insert(self, objs, fields, batch_size, ignore_conflicts=False):
|
||||
"""
|
||||
Helper method for bulk_create() to insert objs one batch at a time.
|
||||
"""
|
||||
# Refuse up front when the backend's feature flags say it cannot ignore
# conflicts.
if ignore_conflicts and not connections[self.db].features.supports_ignore_conflicts:
|
||||
raise NotSupportedError('This database backend does not support ignoring conflicts.')
|
||||
ops = connections[self.db].ops
|
||||
# A falsy batch_size falls back to the size suggested by the backend ops,
# clamped to at least 1.
batch_size = (batch_size or max(ops.bulk_batch_size(fields, objs), 1))
|
||||
inserted_ids = []
|
||||
bulk_return = connections[self.db].features.can_return_ids_from_bulk_insert
|
||||
for item in [objs[i:i + batch_size] for i in range(0, len(objs), batch_size)]:
|
||||
if connections[self.db].features.can_return_ids_from_bulk_insert:
|
||||
inserted_id = self._insert(item, fields=fields, using=self.db, return_id=True)
|
||||
# Ids are requested only when the backend can return them from a bulk
# insert and conflicts are not being ignored.
if bulk_return and not ignore_conflicts:
|
||||
inserted_id = self._insert(
|
||||
item, fields=fields, using=self.db, return_id=True,
|
||||
ignore_conflicts=ignore_conflicts,
|
||||
)
|
||||
# _insert() may hand back either a list or a single value; both are
# folded into one flat list.
if isinstance(inserted_id, list):
|
||||
inserted_ids.extend(inserted_id)
|
||||
else:
|
||||
inserted_ids.append(inserted_id)
|
||||
else:
|
||||
self._insert(item, fields=fields, using=self.db)
|
||||
self._insert(item, fields=fields, using=self.db, ignore_conflicts=ignore_conflicts)
|
||||
return inserted_ids
|
||||
|
||||
def _chain(self, **kwargs):
|
||||
|
|
|
@ -1232,7 +1232,8 @@ class SQLInsertCompiler(SQLCompiler):
|
|||
# going to be column names (so we can avoid the extra overhead).
|
||||
qn = self.connection.ops.quote_name
|
||||
opts = self.query.get_meta()
|
||||
result = ['INSERT INTO %s' % qn(opts.db_table)]
|
||||
insert_statement = self.connection.ops.insert_statement(ignore_conflicts=self.query.ignore_conflicts)
|
||||
result = ['%s %s' % (insert_statement, qn(opts.db_table))]
|
||||
fields = self.query.fields or [opts.pk]
|
||||
result.append('(%s)' % ', '.join(qn(f.column) for f in fields))
|
||||
|
||||
|
@ -1254,6 +1255,9 @@ class SQLInsertCompiler(SQLCompiler):
|
|||
|
||||
placeholder_rows, param_rows = self.assemble_as_sql(fields, value_rows)
|
||||
|
||||
ignore_conflicts_suffix_sql = self.connection.ops.ignore_conflicts_suffix_sql(
|
||||
ignore_conflicts=self.query.ignore_conflicts
|
||||
)
|
||||
if self.return_id and self.connection.features.can_return_id_from_insert:
|
||||
if self.connection.features.can_return_ids_from_bulk_insert:
|
||||
result.append(self.connection.ops.bulk_insert_sql(fields, placeholder_rows))
|
||||
|
@ -1261,6 +1265,8 @@ class SQLInsertCompiler(SQLCompiler):
|
|||
else:
|
||||
result.append("VALUES (%s)" % ", ".join(placeholder_rows[0]))
|
||||
params = [param_rows[0]]
|
||||
if ignore_conflicts_suffix_sql:
|
||||
result.append(ignore_conflicts_suffix_sql)
|
||||
col = "%s.%s" % (qn(opts.db_table), qn(opts.pk.column))
|
||||
r_fmt, r_params = self.connection.ops.return_insert_id()
|
||||
# Skip empty r_fmt to allow subclasses to customize behavior for
|
||||
|
@ -1272,8 +1278,12 @@ class SQLInsertCompiler(SQLCompiler):
|
|||
|
||||
if can_bulk:
|
||||
result.append(self.connection.ops.bulk_insert_sql(fields, placeholder_rows))
|
||||
if ignore_conflicts_suffix_sql:
|
||||
result.append(ignore_conflicts_suffix_sql)
|
||||
return [(" ".join(result), tuple(p for ps in param_rows for p in ps))]
|
||||
else:
|
||||
if ignore_conflicts_suffix_sql:
|
||||
result.append(ignore_conflicts_suffix_sql)
|
||||
return [
|
||||
(" ".join(result + ["VALUES (%s)" % ", ".join(p)]), vals)
|
||||
for p, vals in zip(placeholder_rows, param_rows)
|
||||
|
|
|
@ -169,10 +169,11 @@ class UpdateQuery(Query):
|
|||
class InsertQuery(Query):
|
||||
compiler = 'SQLInsertCompiler'
|
||||
|
||||
# Set up an insert query: forward the positional and remaining keyword
# arguments to the parent initializer, then start with empty fields/objs
# lists.
# NOTE(review): two def headers appear back to back here; this looks like the
# pre-change and post-change signatures of the same method with the diff
# markers stripped -- confirm.
def __init__(self, *args, **kwargs):
|
||||
def __init__(self, *args, ignore_conflicts=False, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.fields = []
|
||||
self.objs = []
|
||||
# Keyword-only flag stored on the query; the insert compiler passes it to
# the backend's insert_statement() and ignore_conflicts_suffix_sql() hooks.
self.ignore_conflicts = ignore_conflicts
|
||||
|
||||
def insert_values(self, fields, objs, raw=False):
|
||||
self.fields = fields
|
||||
|
|
|
@ -2039,7 +2039,7 @@ exists in the database, an :exc:`~django.db.IntegrityError` is raised.
|
|||
``bulk_create()``
|
||||
~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. method:: bulk_create(objs, batch_size=None)
|
||||
.. method:: bulk_create(objs, batch_size=None, ignore_conflicts=False)
|
||||
|
||||
This method inserts the provided list of objects into the database in an
|
||||
efficient manner (generally only 1 query, no matter how many objects there
|
||||
|
@ -2079,6 +2079,16 @@ The ``batch_size`` parameter controls how many objects are created in a single
|
|||
query. The default is to create all objects in one batch, except for SQLite
|
||||
where the default is such that at most 999 variables per query are used.
|
||||
|
||||
On databases that support it (all except PostgreSQL < 9.5 and Oracle), setting
the ``ignore_conflicts`` parameter to ``True`` tells the database to silently
skip any rows that fail constraints, such as duplicate unique values, instead
of raising an error. Enabling this parameter disables setting the primary key
on each model instance (if the database normally supports it). On databases
that don't support it, passing ``ignore_conflicts=True`` raises
``NotSupportedError``.
|
||||
|
||||
.. versionchanged:: 2.2
|
||||
|
||||
The ``ignore_conflicts`` parameter was added.
|
||||
|
||||
``count()``
|
||||
~~~~~~~~~~~
|
||||
|
||||
|
|
|
@ -185,6 +185,10 @@ Models
|
|||
|
||||
* Added many :ref:`math database functions <math-functions>`.
|
||||
|
||||
* Setting the new ``ignore_conflicts`` parameter of
  :meth:`.QuerySet.bulk_create` to ``True`` tells the database to skip rows
  that fail uniqueness constraints or other checks instead of raising an
  error.
|
||||
|
||||
Requests and Responses
|
||||
~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
|
@ -237,6 +241,10 @@ Database backend API
|
|||
constraints or set ``DatabaseFeatures.supports_table_check_constraints`` to
|
||||
``False``.
|
||||
|
||||
* Third-party database backends must implement support for ignoring
  constraint or uniqueness errors while inserting, or set
  ``DatabaseFeatures.supports_ignore_conflicts`` to ``False``.
|
||||
|
||||
:mod:`django.contrib.gis`
|
||||
-------------------------
|
||||
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
from operator import attrgetter
|
||||
|
||||
from django.db import connection
|
||||
from django.db import IntegrityError, NotSupportedError, connection
|
||||
from django.db.models import FileField, Value
|
||||
from django.db.models.functions import Lower
|
||||
from django.test import (
|
||||
|
@ -261,3 +261,37 @@ class BulkCreateTests(TestCase):
|
|||
# Objects save via bulk_create() and save() should have equal state.
|
||||
self.assertEqual(state_ca._state.adding, state_ny._state.adding)
|
||||
self.assertEqual(state_ca._state.db, state_ny._state.db)
|
||||
|
||||
@skipIfDBFeature('supports_ignore_conflicts')
def test_ignore_conflicts_value_error(self):
    # Only runs on backends without the feature: asking bulk_create() to
    # ignore conflicts there must raise NotSupportedError with this message.
    expected = 'This database backend does not support ignoring conflicts.'
    manager = TwoFields.objects
    with self.assertRaisesMessage(NotSupportedError, expected):
        manager.bulk_create(self.data, ignore_conflicts=True)
|
||||
|
||||
@skipUnlessDBFeature('supports_ignore_conflicts')
def test_ignore_conflicts_ignore(self):
    manager = TwoFields.objects
    # Seed three rows.
    seed_rows = [TwoFields(f1=value, f2=value) for value in (1, 2, 3)]
    manager.bulk_create(seed_rows)
    self.assertEqual(manager.count(), 3)

    # Objects repeating values that were already inserted: with
    # ignore_conflicts=True nothing is added and no pk is assigned.
    first_duplicate = TwoFields(f1=2, f2=2)
    second_duplicate = TwoFields(f1=3, f2=3)
    conflicting_objects = [first_duplicate, second_duplicate]
    manager.bulk_create([first_duplicate], ignore_conflicts=True)
    manager.bulk_create(conflicting_objects, ignore_conflicts=True)
    self.assertEqual(manager.count(), 3)
    for duplicate in conflicting_objects:
        self.assertIsNone(duplicate.pk)

    # A batch mixing duplicates with a new object: only the new row is
    # added, and its pk is still left unset.
    new_object = TwoFields(f1=4, f2=4)
    mixed_batch = conflicting_objects + [new_object]
    manager.bulk_create(mixed_batch, ignore_conflicts=True)
    self.assertEqual(manager.count(), 4)
    self.assertIsNone(new_object.pk)

    # Without ignore_conflicts=True the same duplicates are an error.
    with self.assertRaises(IntegrityError):
        manager.bulk_create(conflicting_objects)
|
||||
|
|
Loading…
Reference in New Issue