Fixed #29444 -- Allowed returning multiple fields from INSERT statements on Oracle.

This commit is contained in:
Johannes Hoppe 2019-09-14 14:50:14 -07:00 committed by Mariusz Felisiak
parent d71497bb24
commit b31e63879e
6 changed files with 39 additions and 24 deletions

View File

@ -176,7 +176,7 @@ class BaseDatabaseOperations:
else:
return ['DISTINCT'], []
def fetch_returned_insert_columns(self, cursor):
def fetch_returned_insert_columns(self, cursor, returning_params):
"""
Given a cursor object that has just performed an INSERT...RETURNING
statement into a table, return the newly created data.

View File

@ -10,6 +10,7 @@ class DatabaseFeatures(BaseDatabaseFeatures):
has_select_for_update_of = True
select_for_update_of_column = True
can_return_columns_from_insert = True
can_return_multiple_columns_from_insert = True
can_introspect_autofield = True
supports_subqueries_in_group_by = False
supports_transactions = True

View File

@ -248,17 +248,19 @@ END;
def deferrable_sql(self):
return " DEFERRABLE INITIALLY DEFERRED"
def fetch_returned_insert_columns(self, cursor):
value = cursor._insert_id_var.getvalue()
if value is None or value == []:
# cx_Oracle < 6.3 returns None, >= 6.3 returns empty list.
raise DatabaseError(
'The database did not return a new row id. Probably "ORA-1403: '
'no data found" was raised internally but was hidden by the '
'Oracle OCI library (see https://code.djangoproject.com/ticket/28859).'
)
# cx_Oracle < 7 returns value, >= 7 returns list with single value.
return value if isinstance(value, list) else [value]
def fetch_returned_insert_columns(self, cursor, returning_params):
for param in returning_params:
value = param.get_value()
if value is None or value == []:
# cx_Oracle < 6.3 returns None, >= 6.3 returns empty list.
raise DatabaseError(
'The database did not return a new row id. Probably '
'"ORA-1403: no data found" was raised internally but was '
'hidden by the Oracle OCI library (see '
'https://code.djangoproject.com/ticket/28859).'
)
# cx_Oracle < 7 returns value, >= 7 returns list with single value.
yield value[0] if isinstance(value, list) else value
def field_cast_sql(self, db_type, internal_type):
if db_type and db_type.endswith('LOB'):
@ -344,11 +346,18 @@ END;
def return_insert_columns(self, fields):
if not fields:
return '', ()
sql = 'RETURNING %s.%s INTO %%s' % (
self.quote_name(fields[0].model._meta.db_table),
self.quote_name(fields[0].column),
)
return sql, (InsertVar(fields[0]),)
field_names = []
params = []
for field in fields:
field_names.append('%s.%s' % (
self.quote_name(field.model._meta.db_table),
self.quote_name(field.column),
))
params.append(InsertVar(field))
return 'RETURNING %s INTO %s' % (
', '.join(field_names),
', '.join(['%s'] * len(params)),
), tuple(params)
def __foreign_key_constraints(self, table_name, recursive):
with self.connection.cursor() as cursor:

View File

@ -27,11 +27,14 @@ class InsertVar:
def __init__(self, field):
internal_type = getattr(field, 'target_field', field).get_internal_type()
self.db_type = self.types.get(internal_type, str)
self.bound_param = None
def bind_parameter(self, cursor):
param = cursor.cursor.var(self.db_type)
cursor._insert_id_var = param
return param
self.bound_param = cursor.cursor.var(self.db_type)
return self.bound_param
def get_value(self):
return self.bound_param.getvalue()
class Oracle_datetime(datetime.datetime):

View File

@ -1152,6 +1152,7 @@ class SQLCompiler:
class SQLInsertCompiler(SQLCompiler):
returning_fields = None
returning_params = tuple()
def field_as_sql(self, field, val):
"""
@ -1300,10 +1301,10 @@ class SQLInsertCompiler(SQLCompiler):
result.append(ignore_conflicts_suffix_sql)
# Skip empty r_sql to allow subclasses to customize behavior for
# 3rd party backends. Refs #19096.
r_sql, r_params = self.connection.ops.return_insert_columns(self.returning_fields)
r_sql, self.returning_params = self.connection.ops.return_insert_columns(self.returning_fields)
if r_sql:
result.append(r_sql)
params += [r_params]
params += [self.returning_params]
return [(" ".join(result), tuple(chain.from_iterable(params)))]
if can_bulk:
@ -1342,7 +1343,7 @@ class SQLInsertCompiler(SQLCompiler):
'not supported on this database backend.'
)
assert len(self.query.objs) == 1
return self.connection.ops.fetch_returned_insert_columns(cursor)
return self.connection.ops.fetch_returned_insert_columns(cursor, self.returning_params)
return [self.connection.ops.last_insert_id(
cursor, self.query.get_meta().db_table, self.query.get_meta().pk.column
)]

View File

@ -208,7 +208,8 @@ Database backend API
This section describes changes that may be needed in third-party database
backends.
* ...
* ``DatabaseOperations.fetch_returned_insert_columns()`` now requires an
additional ``returning_params`` argument.
Miscellaneous
-------------