Fixed #10473: Added Oracle support for "RETURNING" ids from insert statements.
git-svn-id: http://code.djangoproject.com/svn/django/trunk@10044 bcc190cf-cafb-0310-a4f2-bffc1f526a37
This commit is contained in:
parent
6d17020c1a
commit
c3dc837950
|
@ -162,6 +162,14 @@ class BaseDatabaseOperations(object):
|
||||||
"""
|
"""
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
def fetch_returned_insert_id(self, cursor):
|
||||||
|
"""
|
||||||
|
Given a cursor object that has just performed an INSERT...RETURNING
|
||||||
|
statement into a table that has an auto-incrementing ID, returns the
|
||||||
|
newly created ID.
|
||||||
|
"""
|
||||||
|
return cursor.fetchone()[0]
|
||||||
|
|
||||||
def field_cast_sql(self, db_type):
|
def field_cast_sql(self, db_type):
|
||||||
"""
|
"""
|
||||||
Given a column type (e.g. 'BLOB', 'VARCHAR'), returns the SQL necessary
|
Given a column type (e.g. 'BLOB', 'VARCHAR'), returns the SQL necessary
|
||||||
|
@ -249,10 +257,10 @@ class BaseDatabaseOperations(object):
|
||||||
|
|
||||||
def return_insert_id(self):
|
def return_insert_id(self):
|
||||||
"""
|
"""
|
||||||
For backends that support returning the last insert ID as part of an
|
For backends that support returning the last insert ID as part
|
||||||
insert query, this method returns the SQL to append to the INSERT
|
of an insert query, this method returns the SQL and params to
|
||||||
query. The returned fragment should contain a format string to hold
|
append to the INSERT query. The returned fragment should
|
||||||
hold the appropriate column.
|
contain a format string to hold the appropriate column.
|
||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
|
@ -37,6 +37,7 @@ class DatabaseFeatures(BaseDatabaseFeatures):
|
||||||
uses_custom_query_class = True
|
uses_custom_query_class = True
|
||||||
interprets_empty_strings_as_nulls = True
|
interprets_empty_strings_as_nulls = True
|
||||||
uses_savepoints = True
|
uses_savepoints = True
|
||||||
|
can_return_id_from_insert = True
|
||||||
|
|
||||||
|
|
||||||
class DatabaseOperations(BaseDatabaseOperations):
|
class DatabaseOperations(BaseDatabaseOperations):
|
||||||
|
@ -97,6 +98,9 @@ WHEN (new.%(col_name)s IS NULL)
|
||||||
def drop_sequence_sql(self, table):
|
def drop_sequence_sql(self, table):
|
||||||
return "DROP SEQUENCE %s;" % self.quote_name(get_sequence_name(table))
|
return "DROP SEQUENCE %s;" % self.quote_name(get_sequence_name(table))
|
||||||
|
|
||||||
|
def fetch_returned_insert_id(self, cursor):
|
||||||
|
return long(cursor._insert_id_var.getvalue())
|
||||||
|
|
||||||
def field_cast_sql(self, db_type):
|
def field_cast_sql(self, db_type):
|
||||||
if db_type and db_type.endswith('LOB'):
|
if db_type and db_type.endswith('LOB'):
|
||||||
return "DBMS_LOB.SUBSTR(%s)"
|
return "DBMS_LOB.SUBSTR(%s)"
|
||||||
|
@ -152,6 +156,9 @@ WHEN (new.%(col_name)s IS NULL)
|
||||||
connection.cursor()
|
connection.cursor()
|
||||||
return connection.ops.regex_lookup(lookup_type)
|
return connection.ops.regex_lookup(lookup_type)
|
||||||
|
|
||||||
|
def return_insert_id(self):
|
||||||
|
return "RETURNING %s INTO %%s", (InsertIdVar(),)
|
||||||
|
|
||||||
def savepoint_create_sql(self, sid):
|
def savepoint_create_sql(self, sid):
|
||||||
return "SAVEPOINT " + self.quote_name(sid)
|
return "SAVEPOINT " + self.quote_name(sid)
|
||||||
|
|
||||||
|
@ -332,8 +339,11 @@ class OracleParam(object):
|
||||||
parameter when executing the query.
|
parameter when executing the query.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, param, charset, strings_only=False):
|
def __init__(self, param, cursor, strings_only=False):
|
||||||
self.smart_str = smart_str(param, charset, strings_only)
|
if hasattr(param, 'bind_parameter'):
|
||||||
|
self.smart_str = param.bind_parameter(cursor)
|
||||||
|
else:
|
||||||
|
self.smart_str = smart_str(param, cursor.charset, strings_only)
|
||||||
if hasattr(param, 'input_size'):
|
if hasattr(param, 'input_size'):
|
||||||
# If parameter has `input_size` attribute, use that.
|
# If parameter has `input_size` attribute, use that.
|
||||||
self.input_size = param.input_size
|
self.input_size = param.input_size
|
||||||
|
@ -344,6 +354,19 @@ class OracleParam(object):
|
||||||
self.input_size = None
|
self.input_size = None
|
||||||
|
|
||||||
|
|
||||||
|
class InsertIdVar(object):
|
||||||
|
"""
|
||||||
|
A late-binding cursor variable that can be passed to Cursor.execute
|
||||||
|
as a parameter, in order to receive the id of the row created by an
|
||||||
|
insert statement.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def bind_parameter(self, cursor):
|
||||||
|
param = cursor.var(Database.NUMBER)
|
||||||
|
cursor._insert_id_var = param
|
||||||
|
return param
|
||||||
|
|
||||||
|
|
||||||
class FormatStylePlaceholderCursor(object):
|
class FormatStylePlaceholderCursor(object):
|
||||||
"""
|
"""
|
||||||
Django uses "format" (e.g. '%s') style placeholders, but Oracle uses ":var"
|
Django uses "format" (e.g. '%s') style placeholders, but Oracle uses ":var"
|
||||||
|
@ -363,7 +386,7 @@ class FormatStylePlaceholderCursor(object):
|
||||||
self.cursor.arraysize = 100
|
self.cursor.arraysize = 100
|
||||||
|
|
||||||
def _format_params(self, params):
|
def _format_params(self, params):
|
||||||
return tuple([OracleParam(p, self.charset, True) for p in params])
|
return tuple([OracleParam(p, self, True) for p in params])
|
||||||
|
|
||||||
def _guess_input_sizes(self, params_list):
|
def _guess_input_sizes(self, params_list):
|
||||||
sizes = [None] * len(params_list[0])
|
sizes = [None] * len(params_list[0])
|
||||||
|
|
|
@ -39,7 +39,7 @@ class DatabaseOperations(PostgresqlDatabaseOperations):
|
||||||
return cursor.query
|
return cursor.query
|
||||||
|
|
||||||
def return_insert_id(self):
|
def return_insert_id(self):
|
||||||
return "RETURNING %s"
|
return "RETURNING %s", ()
|
||||||
|
|
||||||
class DatabaseWrapper(BaseDatabaseWrapper):
|
class DatabaseWrapper(BaseDatabaseWrapper):
|
||||||
operators = {
|
operators = {
|
||||||
|
|
|
@ -306,17 +306,20 @@ class InsertQuery(Query):
|
||||||
result = ['INSERT INTO %s' % qn(opts.db_table)]
|
result = ['INSERT INTO %s' % qn(opts.db_table)]
|
||||||
result.append('(%s)' % ', '.join([qn(c) for c in self.columns]))
|
result.append('(%s)' % ', '.join([qn(c) for c in self.columns]))
|
||||||
result.append('VALUES (%s)' % ', '.join(self.values))
|
result.append('VALUES (%s)' % ', '.join(self.values))
|
||||||
|
params = self.params
|
||||||
if self.connection.features.can_return_id_from_insert:
|
if self.connection.features.can_return_id_from_insert:
|
||||||
col = "%s.%s" % (qn(opts.db_table), qn(opts.pk.column))
|
col = "%s.%s" % (qn(opts.db_table), qn(opts.pk.column))
|
||||||
result.append(self.connection.ops.return_insert_id() % col)
|
r_fmt, r_params = self.connection.ops.return_insert_id()
|
||||||
return ' '.join(result), self.params
|
result.append(r_fmt % col)
|
||||||
|
params = params + r_params
|
||||||
|
return ' '.join(result), params
|
||||||
|
|
||||||
def execute_sql(self, return_id=False):
|
def execute_sql(self, return_id=False):
|
||||||
cursor = super(InsertQuery, self).execute_sql(None)
|
cursor = super(InsertQuery, self).execute_sql(None)
|
||||||
if not (return_id and cursor):
|
if not (return_id and cursor):
|
||||||
return
|
return
|
||||||
if self.connection.features.can_return_id_from_insert:
|
if self.connection.features.can_return_id_from_insert:
|
||||||
return cursor.fetchone()[0]
|
return self.connection.ops.fetch_returned_insert_id(cursor)
|
||||||
return self.connection.ops.last_insert_id(cursor,
|
return self.connection.ops.last_insert_id(cursor,
|
||||||
self.model._meta.db_table, self.model._meta.pk.column)
|
self.model._meta.db_table, self.model._meta.pk.column)
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue