mirror of https://github.com/django/django.git
Fixed #29444 -- Allowed returning multiple fields from INSERT statements on Oracle.
This commit is contained in:
parent
d71497bb24
commit
b31e63879e
|
@ -176,7 +176,7 @@ class BaseDatabaseOperations:
|
|||
else:
|
||||
return ['DISTINCT'], []
|
||||
|
||||
def fetch_returned_insert_columns(self, cursor):
|
||||
def fetch_returned_insert_columns(self, cursor, returning_params):
|
||||
"""
|
||||
Given a cursor object that has just performed an INSERT...RETURNING
|
||||
statement into a table, return the newly created data.
|
||||
|
|
|
@ -10,6 +10,7 @@ class DatabaseFeatures(BaseDatabaseFeatures):
|
|||
has_select_for_update_of = True
|
||||
select_for_update_of_column = True
|
||||
can_return_columns_from_insert = True
|
||||
can_return_multiple_columns_from_insert = True
|
||||
can_introspect_autofield = True
|
||||
supports_subqueries_in_group_by = False
|
||||
supports_transactions = True
|
||||
|
|
|
@ -248,17 +248,19 @@ END;
|
|||
def deferrable_sql(self):
|
||||
return " DEFERRABLE INITIALLY DEFERRED"
|
||||
|
||||
def fetch_returned_insert_columns(self, cursor):
|
||||
value = cursor._insert_id_var.getvalue()
|
||||
if value is None or value == []:
|
||||
# cx_Oracle < 6.3 returns None, >= 6.3 returns empty list.
|
||||
raise DatabaseError(
|
||||
'The database did not return a new row id. Probably "ORA-1403: '
|
||||
'no data found" was raised internally but was hidden by the '
|
||||
'Oracle OCI library (see https://code.djangoproject.com/ticket/28859).'
|
||||
)
|
||||
# cx_Oracle < 7 returns value, >= 7 returns list with single value.
|
||||
return value if isinstance(value, list) else [value]
|
||||
def fetch_returned_insert_columns(self, cursor, returning_params):
|
||||
for param in returning_params:
|
||||
value = param.get_value()
|
||||
if value is None or value == []:
|
||||
# cx_Oracle < 6.3 returns None, >= 6.3 returns empty list.
|
||||
raise DatabaseError(
|
||||
'The database did not return a new row id. Probably '
|
||||
'"ORA-1403: no data found" was raised internally but was '
|
||||
'hidden by the Oracle OCI library (see '
|
||||
'https://code.djangoproject.com/ticket/28859).'
|
||||
)
|
||||
# cx_Oracle < 7 returns value, >= 7 returns list with single value.
|
||||
yield value[0] if isinstance(value, list) else value
|
||||
|
||||
def field_cast_sql(self, db_type, internal_type):
|
||||
if db_type and db_type.endswith('LOB'):
|
||||
|
@ -344,11 +346,18 @@ END;
|
|||
def return_insert_columns(self, fields):
|
||||
if not fields:
|
||||
return '', ()
|
||||
sql = 'RETURNING %s.%s INTO %%s' % (
|
||||
self.quote_name(fields[0].model._meta.db_table),
|
||||
self.quote_name(fields[0].column),
|
||||
)
|
||||
return sql, (InsertVar(fields[0]),)
|
||||
field_names = []
|
||||
params = []
|
||||
for field in fields:
|
||||
field_names.append('%s.%s' % (
|
||||
self.quote_name(field.model._meta.db_table),
|
||||
self.quote_name(field.column),
|
||||
))
|
||||
params.append(InsertVar(field))
|
||||
return 'RETURNING %s INTO %s' % (
|
||||
', '.join(field_names),
|
||||
', '.join(['%s'] * len(params)),
|
||||
), tuple(params)
|
||||
|
||||
def __foreign_key_constraints(self, table_name, recursive):
|
||||
with self.connection.cursor() as cursor:
|
||||
|
|
|
@ -27,11 +27,14 @@ class InsertVar:
|
|||
def __init__(self, field):
|
||||
internal_type = getattr(field, 'target_field', field).get_internal_type()
|
||||
self.db_type = self.types.get(internal_type, str)
|
||||
self.bound_param = None
|
||||
|
||||
def bind_parameter(self, cursor):
|
||||
param = cursor.cursor.var(self.db_type)
|
||||
cursor._insert_id_var = param
|
||||
return param
|
||||
self.bound_param = cursor.cursor.var(self.db_type)
|
||||
return self.bound_param
|
||||
|
||||
def get_value(self):
|
||||
return self.bound_param.getvalue()
|
||||
|
||||
|
||||
class Oracle_datetime(datetime.datetime):
|
||||
|
|
|
@ -1152,6 +1152,7 @@ class SQLCompiler:
|
|||
|
||||
class SQLInsertCompiler(SQLCompiler):
|
||||
returning_fields = None
|
||||
returning_params = tuple()
|
||||
|
||||
def field_as_sql(self, field, val):
|
||||
"""
|
||||
|
@ -1300,10 +1301,10 @@ class SQLInsertCompiler(SQLCompiler):
|
|||
result.append(ignore_conflicts_suffix_sql)
|
||||
# Skip empty r_sql to allow subclasses to customize behavior for
|
||||
# 3rd party backends. Refs #19096.
|
||||
r_sql, r_params = self.connection.ops.return_insert_columns(self.returning_fields)
|
||||
r_sql, self.returning_params = self.connection.ops.return_insert_columns(self.returning_fields)
|
||||
if r_sql:
|
||||
result.append(r_sql)
|
||||
params += [r_params]
|
||||
params += [self.returning_params]
|
||||
return [(" ".join(result), tuple(chain.from_iterable(params)))]
|
||||
|
||||
if can_bulk:
|
||||
|
@ -1342,7 +1343,7 @@ class SQLInsertCompiler(SQLCompiler):
|
|||
'not supported on this database backend.'
|
||||
)
|
||||
assert len(self.query.objs) == 1
|
||||
return self.connection.ops.fetch_returned_insert_columns(cursor)
|
||||
return self.connection.ops.fetch_returned_insert_columns(cursor, self.returning_params)
|
||||
return [self.connection.ops.last_insert_id(
|
||||
cursor, self.query.get_meta().db_table, self.query.get_meta().pk.column
|
||||
)]
|
||||
|
|
|
@ -208,7 +208,8 @@ Database backend API
|
|||
This section describes changes that may be needed in third-party database
|
||||
backends.
|
||||
|
||||
* ...
|
||||
* ``DatabaseOperations.fetch_returned_insert_columns()`` now requires an
|
||||
additional ``returning_params`` argument.
|
||||
|
||||
Miscellaneous
|
||||
-------------
|
||||
|
|
Loading…
Reference in New Issue