Fixed #10473: Added Oracle support for "RETURNING" ids from insert statements.
git-svn-id: http://code.djangoproject.com/svn/django/trunk@10044 bcc190cf-cafb-0310-a4f2-bffc1f526a37
This commit is contained in:
parent
6d17020c1a
commit
c3dc837950
|
@ -162,6 +162,14 @@ class BaseDatabaseOperations(object):
|
|||
"""
|
||||
return None
|
||||
|
||||
def fetch_returned_insert_id(self, cursor):
|
||||
"""
|
||||
Given a cursor object that has just performed an INSERT...RETURNING
|
||||
statement into a table that has an auto-incrementing ID, returns the
|
||||
newly created ID.
|
||||
"""
|
||||
return cursor.fetchone()[0]
|
||||
|
||||
def field_cast_sql(self, db_type):
|
||||
"""
|
||||
Given a column type (e.g. 'BLOB', 'VARCHAR'), returns the SQL necessary
|
||||
|
@ -249,10 +257,10 @@ class BaseDatabaseOperations(object):
|
|||
|
||||
def return_insert_id(self):
|
||||
"""
|
||||
For backends that support returning the last insert ID as part of an
|
||||
insert query, this method returns the SQL to append to the INSERT
|
||||
query. The returned fragment should contain a format string to hold
|
||||
hold the appropriate column.
|
||||
For backends that support returning the last insert ID as part
|
||||
of an insert query, this method returns the SQL and params to
|
||||
append to the INSERT query. The returned fragment should
|
||||
contain a format string to hold the appropriate column.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
|
|
@ -37,6 +37,7 @@ class DatabaseFeatures(BaseDatabaseFeatures):
|
|||
uses_custom_query_class = True
|
||||
interprets_empty_strings_as_nulls = True
|
||||
uses_savepoints = True
|
||||
can_return_id_from_insert = True
|
||||
|
||||
|
||||
class DatabaseOperations(BaseDatabaseOperations):
|
||||
|
@ -97,6 +98,9 @@ WHEN (new.%(col_name)s IS NULL)
|
|||
def drop_sequence_sql(self, table):
|
||||
return "DROP SEQUENCE %s;" % self.quote_name(get_sequence_name(table))
|
||||
|
||||
def fetch_returned_insert_id(self, cursor):
|
||||
return long(cursor._insert_id_var.getvalue())
|
||||
|
||||
def field_cast_sql(self, db_type):
|
||||
if db_type and db_type.endswith('LOB'):
|
||||
return "DBMS_LOB.SUBSTR(%s)"
|
||||
|
@ -152,6 +156,9 @@ WHEN (new.%(col_name)s IS NULL)
|
|||
connection.cursor()
|
||||
return connection.ops.regex_lookup(lookup_type)
|
||||
|
||||
def return_insert_id(self):
|
||||
return "RETURNING %s INTO %%s", (InsertIdVar(),)
|
||||
|
||||
def savepoint_create_sql(self, sid):
|
||||
return "SAVEPOINT " + self.quote_name(sid)
|
||||
|
||||
|
@ -332,8 +339,11 @@ class OracleParam(object):
|
|||
parameter when executing the query.
|
||||
"""
|
||||
|
||||
def __init__(self, param, charset, strings_only=False):
|
||||
self.smart_str = smart_str(param, charset, strings_only)
|
||||
def __init__(self, param, cursor, strings_only=False):
|
||||
if hasattr(param, 'bind_parameter'):
|
||||
self.smart_str = param.bind_parameter(cursor)
|
||||
else:
|
||||
self.smart_str = smart_str(param, cursor.charset, strings_only)
|
||||
if hasattr(param, 'input_size'):
|
||||
# If parameter has `input_size` attribute, use that.
|
||||
self.input_size = param.input_size
|
||||
|
@ -344,6 +354,19 @@ class OracleParam(object):
|
|||
self.input_size = None
|
||||
|
||||
|
||||
class InsertIdVar(object):
|
||||
"""
|
||||
A late-binding cursor variable that can be passed to Cursor.execute
|
||||
as a parameter, in order to receive the id of the row created by an
|
||||
insert statement.
|
||||
"""
|
||||
|
||||
def bind_parameter(self, cursor):
|
||||
param = cursor.var(Database.NUMBER)
|
||||
cursor._insert_id_var = param
|
||||
return param
|
||||
|
||||
|
||||
class FormatStylePlaceholderCursor(object):
|
||||
"""
|
||||
Django uses "format" (e.g. '%s') style placeholders, but Oracle uses ":var"
|
||||
|
@ -363,7 +386,7 @@ class FormatStylePlaceholderCursor(object):
|
|||
self.cursor.arraysize = 100
|
||||
|
||||
def _format_params(self, params):
|
||||
return tuple([OracleParam(p, self.charset, True) for p in params])
|
||||
return tuple([OracleParam(p, self, True) for p in params])
|
||||
|
||||
def _guess_input_sizes(self, params_list):
|
||||
sizes = [None] * len(params_list[0])
|
||||
|
|
|
@ -39,7 +39,7 @@ class DatabaseOperations(PostgresqlDatabaseOperations):
|
|||
return cursor.query
|
||||
|
||||
def return_insert_id(self):
|
||||
return "RETURNING %s"
|
||||
return "RETURNING %s", ()
|
||||
|
||||
class DatabaseWrapper(BaseDatabaseWrapper):
|
||||
operators = {
|
||||
|
|
|
@ -306,17 +306,20 @@ class InsertQuery(Query):
|
|||
result = ['INSERT INTO %s' % qn(opts.db_table)]
|
||||
result.append('(%s)' % ', '.join([qn(c) for c in self.columns]))
|
||||
result.append('VALUES (%s)' % ', '.join(self.values))
|
||||
params = self.params
|
||||
if self.connection.features.can_return_id_from_insert:
|
||||
col = "%s.%s" % (qn(opts.db_table), qn(opts.pk.column))
|
||||
result.append(self.connection.ops.return_insert_id() % col)
|
||||
return ' '.join(result), self.params
|
||||
r_fmt, r_params = self.connection.ops.return_insert_id()
|
||||
result.append(r_fmt % col)
|
||||
params = params + r_params
|
||||
return ' '.join(result), params
|
||||
|
||||
def execute_sql(self, return_id=False):
|
||||
cursor = super(InsertQuery, self).execute_sql(None)
|
||||
if not (return_id and cursor):
|
||||
return
|
||||
if self.connection.features.can_return_id_from_insert:
|
||||
return cursor.fetchone()[0]
|
||||
return self.connection.ops.fetch_returned_insert_id(cursor)
|
||||
return self.connection.ops.last_insert_id(cursor,
|
||||
self.model._meta.db_table, self.model._meta.pk.column)
|
||||
|
||||
|
|
Loading…
Reference in New Issue