Refs #19527 -- Allowed QuerySet.bulk_create() to set the primary key of its objects.
PostgreSQL support only. Thanks Vladislav Manchev and alesasnouski for working on the patch.
This commit is contained in:
parent
60633ef3de
commit
04240b2365
|
@ -24,6 +24,7 @@ class BaseDatabaseFeatures(object):
|
|||
|
||||
can_use_chunked_reads = True
|
||||
can_return_id_from_insert = False
|
||||
can_return_ids_from_bulk_insert = False
|
||||
has_bulk_insert = False
|
||||
uses_savepoints = False
|
||||
can_release_savepoints = False
|
||||
|
|
|
@ -5,6 +5,7 @@ from django.db.utils import InterfaceError
|
|||
class DatabaseFeatures(BaseDatabaseFeatures):
|
||||
allows_group_by_selected_pks = True
|
||||
can_return_id_from_insert = True
|
||||
can_return_ids_from_bulk_insert = True
|
||||
has_real_datatype = True
|
||||
has_native_uuid_field = True
|
||||
has_native_duration_field = True
|
||||
|
|
|
@ -59,6 +59,14 @@ class DatabaseOperations(BaseDatabaseOperations):
|
|||
def deferrable_sql(self):
|
||||
return " DEFERRABLE INITIALLY DEFERRED"
|
||||
|
||||
def fetch_returned_insert_ids(self, cursor):
|
||||
"""
|
||||
Given a cursor object that has just performed an INSERT...RETURNING
|
||||
statement into a table that has an auto-incrementing ID, return the
|
||||
list of newly created IDs.
|
||||
"""
|
||||
return [item[0] for item in cursor.fetchall()]
|
||||
|
||||
def lookup_cast(self, lookup_type, internal_type=None):
|
||||
lookup = '%s'
|
||||
|
||||
|
|
|
@ -411,17 +411,21 @@ class QuerySet(object):
|
|||
Inserts each of the instances into the database. This does *not* call
|
||||
save() on each of the instances, does not send any pre/post save
|
||||
signals, and does not set the primary key attribute if it is an
|
||||
autoincrement field. Multi-table models are not supported.
|
||||
autoincrement field (except if features.can_return_ids_from_bulk_insert=True).
|
||||
Multi-table models are not supported.
|
||||
"""
|
||||
# So this case is fun. When you bulk insert you don't get the primary
|
||||
# keys back (if it's an autoincrement), so you can't insert into the
|
||||
# child tables which references this. There are two workarounds, 1)
|
||||
# this could be implemented if you didn't have an autoincrement pk,
|
||||
# and 2) you could do it by doing O(n) normal inserts into the parent
|
||||
# tables to get the primary keys back, and then doing a single bulk
|
||||
# insert into the childmost table. Some databases might allow doing
|
||||
# this by using RETURNING clause for the insert query. We're punting
|
||||
# on these for now because they are relatively rare cases.
|
||||
# When you bulk insert you don't get the primary keys back (if it's an
|
||||
# autoincrement, except if can_return_ids_from_bulk_insert=True), so
|
||||
# you can't insert into the child tables which references this. There
|
||||
# are two workarounds:
|
||||
# 1) This could be implemented if you didn't have an autoincrement pk
|
||||
# 2) You could do it by doing O(n) normal inserts into the parent
|
||||
# tables to get the primary keys back and then doing a single bulk
|
||||
# insert into the childmost table.
|
||||
# We currently set the primary keys on the objects when using
|
||||
# PostgreSQL via the RETURNING ID clause. It should be possible for
|
||||
# Oracle as well, but the semantics for extracting the primary keys is
|
||||
# trickier so it's not done yet.
|
||||
assert batch_size is None or batch_size > 0
|
||||
# Check that the parents share the same concrete model with the our
|
||||
# model to detect the inheritance pattern ConcreteGrandParent ->
|
||||
|
@ -447,7 +451,11 @@ class QuerySet(object):
|
|||
self._batched_insert(objs_with_pk, fields, batch_size)
|
||||
if objs_without_pk:
|
||||
fields = [f for f in fields if not isinstance(f, AutoField)]
|
||||
self._batched_insert(objs_without_pk, fields, batch_size)
|
||||
ids = self._batched_insert(objs_without_pk, fields, batch_size)
|
||||
if connection.features.can_return_ids_from_bulk_insert:
|
||||
assert len(ids) == len(objs_without_pk)
|
||||
for i in range(len(ids)):
|
||||
objs_without_pk[i].pk = ids[i]
|
||||
|
||||
return objs
|
||||
|
||||
|
@ -1051,10 +1059,19 @@ class QuerySet(object):
|
|||
return
|
||||
ops = connections[self.db].ops
|
||||
batch_size = (batch_size or max(ops.bulk_batch_size(fields, objs), 1))
|
||||
for batch in [objs[i:i + batch_size]
|
||||
for i in range(0, len(objs), batch_size)]:
|
||||
self.model._base_manager._insert(batch, fields=fields,
|
||||
using=self.db)
|
||||
inserted_ids = []
|
||||
for item in [objs[i:i + batch_size] for i in range(0, len(objs), batch_size)]:
|
||||
if connections[self.db].features.can_return_ids_from_bulk_insert:
|
||||
inserted_id = self.model._base_manager._insert(
|
||||
item, fields=fields, using=self.db, return_id=True
|
||||
)
|
||||
if len(objs) > 1:
|
||||
inserted_ids.extend(inserted_id)
|
||||
if len(objs) == 1:
|
||||
inserted_ids.append(inserted_id)
|
||||
else:
|
||||
self.model._base_manager._insert(item, fields=fields, using=self.db)
|
||||
return inserted_ids
|
||||
|
||||
def _clone(self, **kwargs):
|
||||
query = self.query.clone()
|
||||
|
|
|
@ -1019,16 +1019,20 @@ class SQLInsertCompiler(SQLCompiler):
|
|||
placeholder_rows, param_rows = self.assemble_as_sql(fields, value_rows)
|
||||
|
||||
if self.return_id and self.connection.features.can_return_id_from_insert:
|
||||
if self.connection.features.can_return_ids_from_bulk_insert:
|
||||
result.append(self.connection.ops.bulk_insert_sql(fields, placeholder_rows))
|
||||
params = param_rows
|
||||
else:
|
||||
result.append("VALUES (%s)" % ", ".join(placeholder_rows[0]))
|
||||
params = param_rows[0]
|
||||
col = "%s.%s" % (qn(opts.db_table), qn(opts.pk.column))
|
||||
result.append("VALUES (%s)" % ", ".join(placeholder_rows[0]))
|
||||
r_fmt, r_params = self.connection.ops.return_insert_id()
|
||||
# Skip empty r_fmt to allow subclasses to customize behavior for
|
||||
# 3rd party backends. Refs #19096.
|
||||
if r_fmt:
|
||||
result.append(r_fmt % col)
|
||||
params += r_params
|
||||
return [(" ".join(result), tuple(params))]
|
||||
return [(" ".join(result), tuple(chain.from_iterable(params)))]
|
||||
|
||||
if can_bulk:
|
||||
result.append(self.connection.ops.bulk_insert_sql(fields, placeholder_rows))
|
||||
|
@ -1040,14 +1044,20 @@ class SQLInsertCompiler(SQLCompiler):
|
|||
]
|
||||
|
||||
def execute_sql(self, return_id=False):
|
||||
assert not (return_id and len(self.query.objs) != 1)
|
||||
assert not (
|
||||
return_id and len(self.query.objs) != 1 and
|
||||
not self.connection.features.can_return_ids_from_bulk_insert
|
||||
)
|
||||
self.return_id = return_id
|
||||
with self.connection.cursor() as cursor:
|
||||
for sql, params in self.as_sql():
|
||||
cursor.execute(sql, params)
|
||||
if not (return_id and cursor):
|
||||
return
|
||||
if self.connection.features.can_return_ids_from_bulk_insert and len(self.query.objs) > 1:
|
||||
return self.connection.ops.fetch_returned_insert_ids(cursor)
|
||||
if self.connection.features.can_return_id_from_insert:
|
||||
assert len(self.query.objs) == 1
|
||||
return self.connection.ops.fetch_returned_insert_id(cursor)
|
||||
return self.connection.ops.last_insert_id(cursor,
|
||||
self.query.get_meta().db_table, self.query.get_meta().pk.column)
|
||||
|
|
|
@ -1794,13 +1794,19 @@ This has a number of caveats though:
|
|||
``post_save`` signals will not be sent.
|
||||
* It does not work with child models in a multi-table inheritance scenario.
|
||||
* If the model's primary key is an :class:`~django.db.models.AutoField` it
|
||||
does not retrieve and set the primary key attribute, as ``save()`` does.
|
||||
does not retrieve and set the primary key attribute, as ``save()`` does,
|
||||
unless the database backend supports it (currently PostgreSQL).
|
||||
* It does not work with many-to-many relationships.
|
||||
|
||||
.. versionchanged:: 1.9
|
||||
|
||||
Support for using ``bulk_create()`` with proxy models was added.
|
||||
|
||||
.. versionchanged:: 1.0
|
||||
|
||||
Support for setting primary keys on objects created using ``bulk_create()``
|
||||
when using PostgreSQL was added.
|
||||
|
||||
The ``batch_size`` parameter controls how many objects are created in single
|
||||
query. The default is to create all objects in one batch, except for SQLite
|
||||
where the default is such that at most 999 variables per query are used.
|
||||
|
|
|
@ -203,6 +203,11 @@ Database backends
|
|||
|
||||
* Temporal data subtraction was unified on all backends.
|
||||
|
||||
* If the database supports it, backends can set
|
||||
``DatabaseFeatures.can_return_ids_from_bulk_insert=True`` and implement
|
||||
``DatabaseOperations.fetch_returned_insert_ids()`` to set primary keys
|
||||
on objects created using ``QuerySet.bulk_create()``.
|
||||
|
||||
Email
|
||||
~~~~~
|
||||
|
||||
|
@ -315,6 +320,9 @@ Models
|
|||
* The :func:`~django.db.models.prefetch_related_objects` function is now a
|
||||
public API.
|
||||
|
||||
* :meth:`QuerySet.bulk_create() <django.db.models.query.QuerySet.bulk_create>`
|
||||
sets the primary key on objects when using PostgreSQL.
|
||||
|
||||
Requests and Responses
|
||||
~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
|
|
|
@ -198,3 +198,22 @@ class BulkCreateTests(TestCase):
|
|||
])
|
||||
bbb = Restaurant.objects.filter(name="betty's beetroot bar")
|
||||
self.assertEqual(bbb.count(), 1)
|
||||
|
||||
@skipUnlessDBFeature('can_return_ids_from_bulk_insert')
|
||||
def test_set_pk_and_insert_single_item(self):
|
||||
countries = []
|
||||
with self.assertNumQueries(1):
|
||||
countries = Country.objects.bulk_create([self.data[0]])
|
||||
self.assertEqual(len(countries), 1)
|
||||
self.assertEqual(Country.objects.get(pk=countries[0].pk), countries[0])
|
||||
|
||||
@skipUnlessDBFeature('can_return_ids_from_bulk_insert')
|
||||
def test_set_pk_and_query_efficiency(self):
|
||||
countries = []
|
||||
with self.assertNumQueries(1):
|
||||
countries = Country.objects.bulk_create(self.data)
|
||||
self.assertEqual(len(countries), 4)
|
||||
self.assertEqual(Country.objects.get(pk=countries[0].pk), countries[0])
|
||||
self.assertEqual(Country.objects.get(pk=countries[1].pk), countries[1])
|
||||
self.assertEqual(Country.objects.get(pk=countries[2].pk), countries[2])
|
||||
self.assertEqual(Country.objects.get(pk=countries[3].pk), countries[3])
|
||||
|
|
Loading…
Reference in New Issue