Fixed #10473: Added Oracle support for "RETURNING" ids from insert statements.

git-svn-id: http://code.djangoproject.com/svn/django/trunk@10044 bcc190cf-cafb-0310-a4f2-bffc1f526a37
This commit is contained in:
Ian Kelly 2009-03-12 23:41:27 +00:00
parent 6d17020c1a
commit c3dc837950
4 changed files with 45 additions and 11 deletions

View File

@ -162,6 +162,14 @@ class BaseDatabaseOperations(object):
""" """
return None return None
def fetch_returned_insert_id(self, cursor):
"""
Given a cursor object that has just performed an INSERT...RETURNING
statement into a table that has an auto-incrementing ID, returns the
newly created ID.
"""
return cursor.fetchone()[0]
def field_cast_sql(self, db_type): def field_cast_sql(self, db_type):
""" """
Given a column type (e.g. 'BLOB', 'VARCHAR'), returns the SQL necessary Given a column type (e.g. 'BLOB', 'VARCHAR'), returns the SQL necessary
@ -249,10 +257,10 @@ class BaseDatabaseOperations(object):
def return_insert_id(self): def return_insert_id(self):
""" """
For backends that support returning the last insert ID as part of an For backends that support returning the last insert ID as part
insert query, this method returns the SQL to append to the INSERT of an insert query, this method returns the SQL and params to
query. The returned fragment should contain a format string to hold append to the INSERT query. The returned fragment should
hold the appropriate column. contain a format string to hold the appropriate column.
""" """
pass pass

View File

@ -37,6 +37,7 @@ class DatabaseFeatures(BaseDatabaseFeatures):
uses_custom_query_class = True uses_custom_query_class = True
interprets_empty_strings_as_nulls = True interprets_empty_strings_as_nulls = True
uses_savepoints = True uses_savepoints = True
can_return_id_from_insert = True
class DatabaseOperations(BaseDatabaseOperations): class DatabaseOperations(BaseDatabaseOperations):
@ -97,6 +98,9 @@ WHEN (new.%(col_name)s IS NULL)
def drop_sequence_sql(self, table): def drop_sequence_sql(self, table):
return "DROP SEQUENCE %s;" % self.quote_name(get_sequence_name(table)) return "DROP SEQUENCE %s;" % self.quote_name(get_sequence_name(table))
def fetch_returned_insert_id(self, cursor):
return long(cursor._insert_id_var.getvalue())
def field_cast_sql(self, db_type): def field_cast_sql(self, db_type):
if db_type and db_type.endswith('LOB'): if db_type and db_type.endswith('LOB'):
return "DBMS_LOB.SUBSTR(%s)" return "DBMS_LOB.SUBSTR(%s)"
@ -152,6 +156,9 @@ WHEN (new.%(col_name)s IS NULL)
connection.cursor() connection.cursor()
return connection.ops.regex_lookup(lookup_type) return connection.ops.regex_lookup(lookup_type)
def return_insert_id(self):
return "RETURNING %s INTO %%s", (InsertIdVar(),)
def savepoint_create_sql(self, sid): def savepoint_create_sql(self, sid):
return "SAVEPOINT " + self.quote_name(sid) return "SAVEPOINT " + self.quote_name(sid)
@ -332,8 +339,11 @@ class OracleParam(object):
parameter when executing the query. parameter when executing the query.
""" """
def __init__(self, param, charset, strings_only=False): def __init__(self, param, cursor, strings_only=False):
self.smart_str = smart_str(param, charset, strings_only) if hasattr(param, 'bind_parameter'):
self.smart_str = param.bind_parameter(cursor)
else:
self.smart_str = smart_str(param, cursor.charset, strings_only)
if hasattr(param, 'input_size'): if hasattr(param, 'input_size'):
# If parameter has `input_size` attribute, use that. # If parameter has `input_size` attribute, use that.
self.input_size = param.input_size self.input_size = param.input_size
@ -344,6 +354,19 @@ class OracleParam(object):
self.input_size = None self.input_size = None
class InsertIdVar(object):
"""
A late-binding cursor variable that can be passed to Cursor.execute
as a parameter, in order to receive the id of the row created by an
insert statement.
"""
def bind_parameter(self, cursor):
param = cursor.var(Database.NUMBER)
cursor._insert_id_var = param
return param
class FormatStylePlaceholderCursor(object): class FormatStylePlaceholderCursor(object):
""" """
Django uses "format" (e.g. '%s') style placeholders, but Oracle uses ":var" Django uses "format" (e.g. '%s') style placeholders, but Oracle uses ":var"
@ -363,7 +386,7 @@ class FormatStylePlaceholderCursor(object):
self.cursor.arraysize = 100 self.cursor.arraysize = 100
def _format_params(self, params): def _format_params(self, params):
return tuple([OracleParam(p, self.charset, True) for p in params]) return tuple([OracleParam(p, self, True) for p in params])
def _guess_input_sizes(self, params_list): def _guess_input_sizes(self, params_list):
sizes = [None] * len(params_list[0]) sizes = [None] * len(params_list[0])

View File

@ -39,7 +39,7 @@ class DatabaseOperations(PostgresqlDatabaseOperations):
return cursor.query return cursor.query
def return_insert_id(self): def return_insert_id(self):
return "RETURNING %s" return "RETURNING %s", ()
class DatabaseWrapper(BaseDatabaseWrapper): class DatabaseWrapper(BaseDatabaseWrapper):
operators = { operators = {

View File

@ -306,17 +306,20 @@ class InsertQuery(Query):
result = ['INSERT INTO %s' % qn(opts.db_table)] result = ['INSERT INTO %s' % qn(opts.db_table)]
result.append('(%s)' % ', '.join([qn(c) for c in self.columns])) result.append('(%s)' % ', '.join([qn(c) for c in self.columns]))
result.append('VALUES (%s)' % ', '.join(self.values)) result.append('VALUES (%s)' % ', '.join(self.values))
params = self.params
if self.connection.features.can_return_id_from_insert: if self.connection.features.can_return_id_from_insert:
col = "%s.%s" % (qn(opts.db_table), qn(opts.pk.column)) col = "%s.%s" % (qn(opts.db_table), qn(opts.pk.column))
result.append(self.connection.ops.return_insert_id() % col) r_fmt, r_params = self.connection.ops.return_insert_id()
return ' '.join(result), self.params result.append(r_fmt % col)
params = params + r_params
return ' '.join(result), params
def execute_sql(self, return_id=False): def execute_sql(self, return_id=False):
cursor = super(InsertQuery, self).execute_sql(None) cursor = super(InsertQuery, self).execute_sql(None)
if not (return_id and cursor): if not (return_id and cursor):
return return
if self.connection.features.can_return_id_from_insert: if self.connection.features.can_return_id_from_insert:
return cursor.fetchone()[0] return self.connection.ops.fetch_returned_insert_id(cursor)
return self.connection.ops.last_insert_id(cursor, return self.connection.ops.last_insert_id(cursor,
self.model._meta.db_table, self.model._meta.pk.column) self.model._meta.db_table, self.model._meta.pk.column)