Refs #29444 -- Allowed returning multiple fields from INSERT statements on PostgreSQL.

Thanks Florian Apolloner, Tim Graham, Simon Charette, Nick Pope, and
Mariusz Felisiak for reviews.
This commit is contained in:
Johannes Hoppe 2019-07-24 08:42:41 +02:00 committed by Mariusz Felisiak
parent 736e7d44de
commit 7254f1138d
16 changed files with 209 additions and 89 deletions

View File

@@ -23,6 +23,7 @@ class BaseDatabaseFeatures:
can_use_chunked_reads = True
can_return_columns_from_insert = False
can_return_multiple_columns_from_insert = False
can_return_rows_from_bulk_insert = False
has_bulk_insert = True
uses_savepoints = True

View File

@@ -176,13 +176,12 @@ class BaseDatabaseOperations:
else:
return ['DISTINCT'], []
def fetch_returned_insert_id(self, cursor):
def fetch_returned_insert_columns(self, cursor):
"""
Given a cursor object that has just performed an INSERT...RETURNING
statement into a table that has an auto-incrementing ID, return the
newly created ID.
statement into a table, return the newly created data.
"""
return cursor.fetchone()[0]
return cursor.fetchone()
def field_cast_sql(self, db_type, internal_type):
"""
@@ -314,12 +313,11 @@ class BaseDatabaseOperations:
"""
return value
def return_insert_id(self, field):
def return_insert_columns(self, fields):
"""
For backends that support returning the last insert ID as part of an
insert query, return the SQL and params to append to the INSERT query.
The returned fragment should contain a format string to hold the
appropriate column.
For backends that support returning columns as part of an insert query,
return the SQL and params to append to the INSERT query. The returned
fragment should contain a format string to hold the appropriate column.
"""
pass

View File

@@ -248,7 +248,7 @@ END;
def deferrable_sql(self):
return " DEFERRABLE INITIALLY DEFERRED"
def fetch_returned_insert_id(self, cursor):
def fetch_returned_insert_columns(self, cursor):
value = cursor._insert_id_var.getvalue()
if value is None or value == []:
# cx_Oracle < 6.3 returns None, >= 6.3 returns empty list.
@@ -258,7 +258,7 @@ END;
'Oracle OCI library (see https://code.djangoproject.com/ticket/28859).'
)
# cx_Oracle < 7 returns value, >= 7 returns list with single value.
return value[0] if isinstance(value, list) else value
return value if isinstance(value, list) else [value]
def field_cast_sql(self, db_type, internal_type):
if db_type and db_type.endswith('LOB'):
@@ -341,8 +341,14 @@ END;
match_option = "'i'"
return 'REGEXP_LIKE(%%s, %%s, %s)' % match_option
def return_insert_id(self, field):
return 'RETURNING %s INTO %%s', (InsertVar(field),)
def return_insert_columns(self, fields):
if not fields:
return '', ()
sql = 'RETURNING %s.%s INTO %%s' % (
self.quote_name(fields[0].model._meta.db_table),
self.quote_name(fields[0].column),
)
return sql, (InsertVar(fields[0]),)
def __foreign_key_constraints(self, table_name, recursive):
with self.connection.cursor() as cursor:

View File

@@ -8,6 +8,7 @@ from django.utils.functional import cached_property
class DatabaseFeatures(BaseDatabaseFeatures):
allows_group_by_selected_pks = True
can_return_columns_from_insert = True
can_return_multiple_columns_from_insert = True
can_return_rows_from_bulk_insert = True
has_real_datatype = True
has_native_uuid_field = True

View File

@@ -76,13 +76,12 @@ class DatabaseOperations(BaseDatabaseOperations):
def deferrable_sql(self):
return " DEFERRABLE INITIALLY DEFERRED"
def fetch_returned_insert_ids(self, cursor):
def fetch_returned_insert_rows(self, cursor):
"""
Given a cursor object that has just performed an INSERT...RETURNING
statement into a table that has an auto-incrementing ID, return the
list of newly created IDs.
statement into a table, return the tuple of returned data.
"""
return [item[0] for item in cursor.fetchall()]
return cursor.fetchall()
def lookup_cast(self, lookup_type, internal_type=None):
lookup = '%s'
@@ -236,8 +235,16 @@ class DatabaseOperations(BaseDatabaseOperations):
return cursor.query.decode()
return None
def return_insert_id(self, field):
return "RETURNING %s", ()
def return_insert_columns(self, fields):
if not fields:
return '', ()
columns = [
'%s.%s' % (
self.quote_name(field.model._meta.db_table),
self.quote_name(field.column),
) for field in fields
]
return 'RETURNING %s' % ', '.join(columns), ()
def bulk_insert_sql(self, fields, placeholder_rows):
placeholder_rows_sql = (", ".join(row) for row in placeholder_rows)

View File

@@ -876,10 +876,10 @@ class Model(metaclass=ModelBase):
if not pk_set:
fields = [f for f in fields if f is not meta.auto_field]
update_pk = meta.auto_field and not pk_set
result = self._do_insert(cls._base_manager, using, fields, update_pk, raw)
if update_pk:
setattr(self, meta.pk.attname, result)
returning_fields = meta.db_returning_fields
results = self._do_insert(cls._base_manager, using, fields, returning_fields, raw)
for result, field in zip(results, returning_fields):
setattr(self, field.attname, result)
return updated
def _do_update(self, base_qs, using, pk_val, values, update_fields, forced_update):
@@ -909,13 +909,15 @@
)
return filtered._update(values) > 0
def _do_insert(self, manager, using, fields, update_pk, raw):
def _do_insert(self, manager, using, fields, returning_fields, raw):
"""
Do an INSERT. If update_pk is defined then this method should return
the new pk for the model.
Do an INSERT. If returning_fields is defined then this method should
return the newly created data for the model.
"""
return manager._insert([self], fields=fields, return_id=update_pk,
using=using, raw=raw)
return manager._insert(
[self], fields=fields, returning_fields=returning_fields,
using=using, raw=raw,
)
def delete(self, using=None, keep_parents=False):
using = using or router.db_for_write(self.__class__, instance=self)

View File

@@ -735,6 +735,14 @@ class Field(RegisterLookupMixin):
def db_tablespace(self):
return self._db_tablespace or settings.DEFAULT_INDEX_TABLESPACE
@property
def db_returning(self):
"""
Private API intended only to be used by Django itself. Currently only
the PostgreSQL backend supports returning multiple fields on a model.
"""
return False
def set_attributes_from_name(self, name):
self.name = self.name or name
self.attname, self.column = self.get_attname_column()
@@ -2311,6 +2319,7 @@ class UUIDField(Field):
class AutoFieldMixin:
db_returning = True
def __init__(self, *args, **kwargs):
kwargs['blank'] = True

View File

@@ -842,3 +842,14 @@ class Options:
if isinstance(attr, property):
names.append(name)
return frozenset(names)
@cached_property
def db_returning_fields(self):
"""
Private API intended only to be used by Django itself.
Fields to be returned after a database insert.
"""
return [
field for field in self._get_fields(forward=True, reverse=False, include_parents=PROXY_PARENTS)
if getattr(field, 'db_returning', False)
]

View File

@@ -470,23 +470,33 @@ class QuerySet:
return objs
self._for_write = True
connection = connections[self.db]
fields = self.model._meta.concrete_fields
opts = self.model._meta
fields = opts.concrete_fields
objs = list(objs)
self._populate_pk_values(objs)
with transaction.atomic(using=self.db, savepoint=False):
objs_with_pk, objs_without_pk = partition(lambda o: o.pk is None, objs)
if objs_with_pk:
self._batched_insert(objs_with_pk, fields, batch_size, ignore_conflicts=ignore_conflicts)
returned_columns = self._batched_insert(
objs_with_pk, fields, batch_size, ignore_conflicts=ignore_conflicts,
)
for obj_with_pk, results in zip(objs_with_pk, returned_columns):
for result, field in zip(results, opts.db_returning_fields):
if field != opts.pk:
setattr(obj_with_pk, field.attname, result)
for obj_with_pk in objs_with_pk:
obj_with_pk._state.adding = False
obj_with_pk._state.db = self.db
if objs_without_pk:
fields = [f for f in fields if not isinstance(f, AutoField)]
ids = self._batched_insert(objs_without_pk, fields, batch_size, ignore_conflicts=ignore_conflicts)
returned_columns = self._batched_insert(
objs_without_pk, fields, batch_size, ignore_conflicts=ignore_conflicts,
)
if connection.features.can_return_rows_from_bulk_insert and not ignore_conflicts:
assert len(ids) == len(objs_without_pk)
for obj_without_pk, pk in zip(objs_without_pk, ids):
obj_without_pk.pk = pk
assert len(returned_columns) == len(objs_without_pk)
for obj_without_pk, results in zip(objs_without_pk, returned_columns):
for result, field in zip(results, opts.db_returning_fields):
setattr(obj_without_pk, field.attname, result)
obj_without_pk._state.adding = False
obj_without_pk._state.db = self.db
@@ -1181,7 +1191,7 @@ class QuerySet:
# PRIVATE METHODS #
###################
def _insert(self, objs, fields, return_id=False, raw=False, using=None, ignore_conflicts=False):
def _insert(self, objs, fields, returning_fields=None, raw=False, using=None, ignore_conflicts=False):
"""
Insert a new record for the given model. This provides an interface to
the InsertQuery class and is how Model.save() is implemented.
@@ -1191,7 +1201,7 @@
using = self.db
query = sql.InsertQuery(self.model, ignore_conflicts=ignore_conflicts)
query.insert_values(fields, objs, raw=raw)
return query.get_compiler(using=using).execute_sql(return_id)
return query.get_compiler(using=using).execute_sql(returning_fields)
_insert.alters_data = True
_insert.queryset_only = False
@@ -1203,21 +1213,22 @@
raise NotSupportedError('This database backend does not support ignoring conflicts.')
ops = connections[self.db].ops
batch_size = (batch_size or max(ops.bulk_batch_size(fields, objs), 1))
inserted_ids = []
inserted_rows = []
bulk_return = connections[self.db].features.can_return_rows_from_bulk_insert
for item in [objs[i:i + batch_size] for i in range(0, len(objs), batch_size)]:
if bulk_return and not ignore_conflicts:
inserted_id = self._insert(
item, fields=fields, using=self.db, return_id=True,
inserted_columns = self._insert(
item, fields=fields, using=self.db,
returning_fields=self.model._meta.db_returning_fields,
ignore_conflicts=ignore_conflicts,
)
if isinstance(inserted_id, list):
inserted_ids.extend(inserted_id)
if isinstance(inserted_columns, list):
inserted_rows.extend(inserted_columns)
else:
inserted_ids.append(inserted_id)
inserted_rows.append(inserted_columns)
else:
self._insert(item, fields=fields, using=self.db, ignore_conflicts=ignore_conflicts)
return inserted_ids
return inserted_rows
def _chain(self, **kwargs):
"""

View File

@@ -1159,7 +1159,7 @@ class SQLCompiler:
class SQLInsertCompiler(SQLCompiler):
return_id = False
returning_fields = None
def field_as_sql(self, field, val):
"""
@@ -1290,14 +1290,14 @@ class SQLInsertCompiler(SQLCompiler):
# queries and generate their own placeholders. Doing that isn't
# necessary and it should be possible to use placeholders and
# expressions in bulk inserts too.
can_bulk = (not self.return_id and self.connection.features.has_bulk_insert)
can_bulk = (not self.returning_fields and self.connection.features.has_bulk_insert)
placeholder_rows, param_rows = self.assemble_as_sql(fields, value_rows)
ignore_conflicts_suffix_sql = self.connection.ops.ignore_conflicts_suffix_sql(
ignore_conflicts=self.query.ignore_conflicts
)
if self.return_id and self.connection.features.can_return_columns_from_insert:
if self.returning_fields and self.connection.features.can_return_columns_from_insert:
if self.connection.features.can_return_rows_from_bulk_insert:
result.append(self.connection.ops.bulk_insert_sql(fields, placeholder_rows))
params = param_rows
@@ -1306,12 +1306,11 @@
params = [param_rows[0]]
if ignore_conflicts_suffix_sql:
result.append(ignore_conflicts_suffix_sql)
col = "%s.%s" % (qn(opts.db_table), qn(opts.pk.column))
r_fmt, r_params = self.connection.ops.return_insert_id(opts.pk)
# Skip empty r_fmt to allow subclasses to customize behavior for
# Skip empty r_sql to allow subclasses to customize behavior for
# 3rd party backends. Refs #19096.
if r_fmt:
result.append(r_fmt % col)
r_sql, r_params = self.connection.ops.return_insert_columns(self.returning_fields)
if r_sql:
result.append(r_sql)
params += [r_params]
return [(" ".join(result), tuple(chain.from_iterable(params)))]
@@ -1328,25 +1327,33 @@
for p, vals in zip(placeholder_rows, param_rows)
]
def execute_sql(self, return_id=False):
def execute_sql(self, returning_fields=None):
assert not (
return_id and len(self.query.objs) != 1 and
returning_fields and len(self.query.objs) != 1 and
not self.connection.features.can_return_rows_from_bulk_insert
)
self.return_id = return_id
self.returning_fields = returning_fields
with self.connection.cursor() as cursor:
for sql, params in self.as_sql():
cursor.execute(sql, params)
if not return_id:
return
if not self.returning_fields:
return []
if self.connection.features.can_return_rows_from_bulk_insert and len(self.query.objs) > 1:
return self.connection.ops.fetch_returned_insert_ids(cursor)
return self.connection.ops.fetch_returned_insert_rows(cursor)
if self.connection.features.can_return_columns_from_insert:
if (
len(self.returning_fields) > 1 and
not self.connection.features.can_return_multiple_columns_from_insert
):
raise NotSupportedError(
'Returning multiple columns from INSERT statements is '
'not supported on this database backend.'
)
assert len(self.query.objs) == 1
return self.connection.ops.fetch_returned_insert_id(cursor)
return self.connection.ops.last_insert_id(
return self.connection.ops.fetch_returned_insert_columns(cursor)
return [self.connection.ops.last_insert_id(
cursor, self.query.get_meta().db_table, self.query.get_meta().pk.column
)
)]
class SQLDeleteCompiler(SQLCompiler):

View File

@@ -448,14 +448,20 @@ backends.
:class:`~django.db.models.DateTimeField` in ``datetime_cast_date_sql()``,
``datetime_extract_sql()``, etc.
* ``DatabaseOperations.return_insert_id()`` now requires an additional
``field`` argument with the model field.
* Entries for ``AutoField``, ``BigAutoField``, and ``SmallAutoField`` are added
to ``DatabaseOperations.integer_field_ranges`` to support the integer range
validators on these field types. Third-party backends may need to customize
the default entries.
* ``DatabaseOperations.fetch_returned_insert_id()`` is replaced by
``fetch_returned_insert_columns()`` which returns a list of values returned
by the ``INSERT … RETURNING`` statement, instead of a single value.
* ``DatabaseOperations.return_insert_id()`` is replaced by
``return_insert_columns()`` that accepts a ``fields``
argument, which is an iterable of fields to be returned after insert. Usually
this is only the auto-generated primary key.
:mod:`django.contrib.admin`
---------------------------

View File

@@ -5,10 +5,6 @@ from django.contrib.contenttypes.models import ContentType
from django.db import models
class NonIntegerAutoField(models.Model):
creation_datetime = models.DateTimeField(primary_key=True)
class Square(models.Model):
root = models.IntegerField()
square = models.PositiveIntegerField()

View File

@@ -1,4 +1,3 @@
import datetime
import unittest
from django.db import connection
@@ -6,7 +5,7 @@ from django.db.models.fields import BooleanField, NullBooleanField
from django.db.utils import DatabaseError
from django.test import TransactionTestCase
from ..models import NonIntegerAutoField, Square
from ..models import Square
@unittest.skipUnless(connection.vendor == 'oracle', 'Oracle tests')
@@ -96,23 +95,3 @@ class TransactionalTests(TransactionTestCase):
self.assertIn('ORA-01017', context.exception.args[0].message)
finally:
connection.settings_dict['PASSWORD'] = old_password
def test_non_integer_auto_field(self):
with connection.cursor() as cursor:
# Create trigger that fill non-integer auto field.
cursor.execute("""
CREATE OR REPLACE TRIGGER "TRG_FILL_CREATION_DATETIME"
BEFORE INSERT ON "BACKENDS_NONINTEGERAUTOFIELD"
FOR EACH ROW
BEGIN
:NEW.CREATION_DATETIME := SYSTIMESTAMP;
END;
""")
try:
NonIntegerAutoField._meta.auto_field = NonIntegerAutoField.creation_datetime
obj = NonIntegerAutoField.objects.create()
self.assertIsNotNone(obj.creation_datetime)
self.assertIsInstance(obj.creation_datetime, datetime.datetime)
finally:
with connection.cursor() as cursor:
cursor.execute('DROP TRIGGER "TRG_FILL_CREATION_DATETIME"')

View File

@@ -279,3 +279,8 @@ class PropertyNamesTests(SimpleTestCase):
# Instance only descriptors don't appear in _property_names.
self.assertEqual(AbstractPerson().test_instance_only_descriptor, 1)
self.assertEqual(AbstractPerson._meta._property_names, frozenset(['pk', 'test_property']))
class ReturningFieldsTests(SimpleTestCase):
def test_pk(self):
self.assertEqual(Relation._meta.db_returning_fields, [Relation._meta.pk])

View File

@@ -4,6 +4,7 @@ Various complex queries that have been problematic in the past.
import threading
from django.db import models
from django.db.models.functions import Now
class DumbCategory(models.Model):
@@ -730,3 +731,19 @@ class RelatedIndividual(models.Model):
class CustomDbColumn(models.Model):
custom_column = models.IntegerField(db_column='custom_name', null=True)
ip_address = models.GenericIPAddressField(null=True)
class CreatedField(models.DateTimeField):
db_returning = True
def __init__(self, *args, **kwargs):
kwargs.setdefault('default', Now)
super().__init__(*args, **kwargs)
class ReturningModel(models.Model):
created = CreatedField(editable=False)
class NonIntegerPKReturningModel(models.Model):
created = CreatedField(editable=False, primary_key=True)

View File

@@ -0,0 +1,64 @@
import datetime
from django.db import NotSupportedError, connection
from django.test import TestCase, skipIfDBFeature, skipUnlessDBFeature
from django.test.utils import CaptureQueriesContext
from .models import DumbCategory, NonIntegerPKReturningModel, ReturningModel
@skipUnlessDBFeature('can_return_columns_from_insert')
class ReturningValuesTests(TestCase):
def test_insert_returning(self):
with CaptureQueriesContext(connection) as captured_queries:
DumbCategory.objects.create()
self.assertIn(
'RETURNING %s.%s' % (
connection.ops.quote_name(DumbCategory._meta.db_table),
connection.ops.quote_name(DumbCategory._meta.get_field('id').column),
),
captured_queries[-1]['sql'],
)
def test_insert_returning_non_integer(self):
obj = NonIntegerPKReturningModel.objects.create()
self.assertTrue(obj.created)
self.assertIsInstance(obj.created, datetime.datetime)
@skipUnlessDBFeature('can_return_multiple_columns_from_insert')
def test_insert_returning_multiple(self):
with CaptureQueriesContext(connection) as captured_queries:
obj = ReturningModel.objects.create()
table_name = connection.ops.quote_name(ReturningModel._meta.db_table)
self.assertIn(
'RETURNING %s.%s, %s.%s' % (
table_name,
connection.ops.quote_name(ReturningModel._meta.get_field('id').column),
table_name,
connection.ops.quote_name(ReturningModel._meta.get_field('created').column),
),
captured_queries[-1]['sql'],
)
self.assertTrue(obj.pk)
self.assertIsInstance(obj.created, datetime.datetime)
@skipIfDBFeature('can_return_multiple_columns_from_insert')
def test_insert_returning_multiple_not_supported(self):
msg = (
'Returning multiple columns from INSERT statements is '
'not supported on this database backend.'
)
with self.assertRaisesMessage(NotSupportedError, msg):
ReturningModel.objects.create()
@skipUnlessDBFeature(
'can_return_rows_from_bulk_insert',
'can_return_multiple_columns_from_insert',
)
def test_bulk_insert(self):
objs = [ReturningModel(), ReturningModel(pk=2 ** 11), ReturningModel()]
ReturningModel.objects.bulk_create(objs)
for obj in objs:
with self.subTest(obj=obj):
self.assertTrue(obj.pk)
self.assertIsInstance(obj.created, datetime.datetime)