Fixed #10473: Added Oracle support for "RETURNING" ids from insert statements.

git-svn-id: http://code.djangoproject.com/svn/django/trunk@10044 bcc190cf-cafb-0310-a4f2-bffc1f526a37
This commit is contained in:
Ian Kelly 2009-03-12 23:41:27 +00:00
parent 6d17020c1a
commit c3dc837950
4 changed files with 45 additions and 11 deletions

View File

@ -162,6 +162,14 @@ class BaseDatabaseOperations(object):
"""
return None
def fetch_returned_insert_id(self, cursor):
"""
Given a cursor object that has just performed an INSERT...RETURNING
statement into a table that has an auto-incrementing ID, returns the
newly created ID.
"""
return cursor.fetchone()[0]
def field_cast_sql(self, db_type):
"""
Given a column type (e.g. 'BLOB', 'VARCHAR'), returns the SQL necessary
@ -249,10 +257,10 @@ class BaseDatabaseOperations(object):
def return_insert_id(self):
"""
For backends that support returning the last insert ID as part of an
insert query, this method returns the SQL to append to the INSERT
query. The returned fragment should contain a format string to hold
hold the appropriate column.
For backends that support returning the last insert ID as part
of an insert query, this method returns the SQL and params to
append to the INSERT query. The returned fragment should
contain a format string to hold the appropriate column.
"""
pass

View File

@ -37,6 +37,7 @@ class DatabaseFeatures(BaseDatabaseFeatures):
uses_custom_query_class = True
interprets_empty_strings_as_nulls = True
uses_savepoints = True
can_return_id_from_insert = True
class DatabaseOperations(BaseDatabaseOperations):
@ -97,6 +98,9 @@ WHEN (new.%(col_name)s IS NULL)
def drop_sequence_sql(self, table):
return "DROP SEQUENCE %s;" % self.quote_name(get_sequence_name(table))
def fetch_returned_insert_id(self, cursor):
return long(cursor._insert_id_var.getvalue())
def field_cast_sql(self, db_type):
if db_type and db_type.endswith('LOB'):
return "DBMS_LOB.SUBSTR(%s)"
@ -152,6 +156,9 @@ WHEN (new.%(col_name)s IS NULL)
connection.cursor()
return connection.ops.regex_lookup(lookup_type)
def return_insert_id(self):
return "RETURNING %s INTO %%s", (InsertIdVar(),)
def savepoint_create_sql(self, sid):
return "SAVEPOINT " + self.quote_name(sid)
@ -332,8 +339,11 @@ class OracleParam(object):
parameter when executing the query.
"""
def __init__(self, param, charset, strings_only=False):
self.smart_str = smart_str(param, charset, strings_only)
def __init__(self, param, cursor, strings_only=False):
if hasattr(param, 'bind_parameter'):
self.smart_str = param.bind_parameter(cursor)
else:
self.smart_str = smart_str(param, cursor.charset, strings_only)
if hasattr(param, 'input_size'):
# If parameter has `input_size` attribute, use that.
self.input_size = param.input_size
@ -344,6 +354,19 @@ class OracleParam(object):
self.input_size = None
class InsertIdVar(object):
"""
A late-binding cursor variable that can be passed to Cursor.execute
as a parameter, in order to receive the id of the row created by an
insert statement.
"""
def bind_parameter(self, cursor):
param = cursor.var(Database.NUMBER)
cursor._insert_id_var = param
return param
class FormatStylePlaceholderCursor(object):
"""
Django uses "format" (e.g. '%s') style placeholders, but Oracle uses ":var"
@ -363,7 +386,7 @@ class FormatStylePlaceholderCursor(object):
self.cursor.arraysize = 100
def _format_params(self, params):
return tuple([OracleParam(p, self.charset, True) for p in params])
return tuple([OracleParam(p, self, True) for p in params])
def _guess_input_sizes(self, params_list):
sizes = [None] * len(params_list[0])

View File

@ -39,7 +39,7 @@ class DatabaseOperations(PostgresqlDatabaseOperations):
return cursor.query
def return_insert_id(self):
return "RETURNING %s"
return "RETURNING %s", ()
class DatabaseWrapper(BaseDatabaseWrapper):
operators = {

View File

@ -306,17 +306,20 @@ class InsertQuery(Query):
result = ['INSERT INTO %s' % qn(opts.db_table)]
result.append('(%s)' % ', '.join([qn(c) for c in self.columns]))
result.append('VALUES (%s)' % ', '.join(self.values))
params = self.params
if self.connection.features.can_return_id_from_insert:
col = "%s.%s" % (qn(opts.db_table), qn(opts.pk.column))
result.append(self.connection.ops.return_insert_id() % col)
return ' '.join(result), self.params
r_fmt, r_params = self.connection.ops.return_insert_id()
result.append(r_fmt % col)
params = params + r_params
return ' '.join(result), params
def execute_sql(self, return_id=False):
cursor = super(InsertQuery, self).execute_sql(None)
if not (return_id and cursor):
return
if self.connection.features.can_return_id_from_insert:
return cursor.fetchone()[0]
return self.connection.ops.fetch_returned_insert_id(cursor)
return self.connection.ops.last_insert_id(cursor,
self.model._meta.db_table, self.model._meta.pk.column)