diff --git a/django/db/backends/__init__.py b/django/db/backends/__init__.py index 8a718dba9a..1c057c3358 100644 --- a/django/db/backends/__init__.py +++ b/django/db/backends/__init__.py @@ -162,6 +162,14 @@ class BaseDatabaseOperations(object): """ return None + def fetch_returned_insert_id(self, cursor): + """ + Given a cursor object that has just performed an INSERT...RETURNING + statement into a table that has an auto-incrementing ID, returns the + newly created ID. + """ + return cursor.fetchone()[0] + def field_cast_sql(self, db_type): """ Given a column type (e.g. 'BLOB', 'VARCHAR'), returns the SQL necessary @@ -249,10 +257,10 @@ class BaseDatabaseOperations(object): def return_insert_id(self): """ - For backends that support returning the last insert ID as part of an - insert query, this method returns the SQL to append to the INSERT - query. The returned fragment should contain a format string to hold - hold the appropriate column. + For backends that support returning the last insert ID as part + of an insert query, this method returns the SQL and params to + append to the INSERT query. The returned fragment should + contain a format string to hold the appropriate column. """ pass diff --git a/django/db/backends/oracle/base.py b/django/db/backends/oracle/base.py index e8570c9fce..109ec5bd9f 100644 --- a/django/db/backends/oracle/base.py +++ b/django/db/backends/oracle/base.py @@ -37,6 +37,7 @@ class DatabaseFeatures(BaseDatabaseFeatures): uses_custom_query_class = True interprets_empty_strings_as_nulls = True uses_savepoints = True + can_return_id_from_insert = True class DatabaseOperations(BaseDatabaseOperations): @@ -97,6 +98,9 @@ WHEN (new.%(col_name)s IS NULL) def drop_sequence_sql(self, table): return "DROP SEQUENCE %s;" % self.quote_name(get_sequence_name(table)) + def fetch_returned_insert_id(self, cursor): + return long(cursor._insert_id_var.getvalue()) + def field_cast_sql(self, db_type): if db_type and db_type.endswith('LOB'): return "DBMS_LOB.SUBSTR(%s)" @@ -152,6 +156,9 @@ WHEN (new.%(col_name)s IS NULL) connection.cursor() return connection.ops.regex_lookup(lookup_type) + def return_insert_id(self): + return "RETURNING %s INTO %%s", (InsertIdVar(),) + def savepoint_create_sql(self, sid): return "SAVEPOINT " + self.quote_name(sid) @@ -332,8 +339,11 @@ class OracleParam(object): parameter when executing the query. """ - def __init__(self, param, charset, strings_only=False): - self.smart_str = smart_str(param, charset, strings_only) + def __init__(self, param, cursor, strings_only=False): + if hasattr(param, 'bind_parameter'): + self.smart_str = param.bind_parameter(cursor) + else: + self.smart_str = smart_str(param, cursor.charset, strings_only) if hasattr(param, 'input_size'): # If parameter has `input_size` attribute, use that. self.input_size = param.input_size @@ -344,6 +354,19 @@ class OracleParam(object): self.input_size = None +class InsertIdVar(object): + """ + A late-binding cursor variable that can be passed to Cursor.execute + as a parameter, in order to receive the id of the row created by an + insert statement. + """ + + def bind_parameter(self, cursor): + param = cursor.var(Database.NUMBER) + cursor._insert_id_var = param + return param + + class FormatStylePlaceholderCursor(object): """ Django uses "format" (e.g. '%s') style placeholders, but Oracle uses ":var" @@ -363,7 +386,7 @@ class FormatStylePlaceholderCursor(object): self.cursor.arraysize = 100 def _format_params(self, params): - return tuple([OracleParam(p, self.charset, True) for p in params]) + return tuple([OracleParam(p, self, True) for p in params]) def _guess_input_sizes(self, params_list): sizes = [None] * len(params_list[0]) diff --git a/django/db/backends/postgresql_psycopg2/base.py b/django/db/backends/postgresql_psycopg2/base.py index ea2f9f41b1..db4369f038 100644 --- a/django/db/backends/postgresql_psycopg2/base.py +++ b/django/db/backends/postgresql_psycopg2/base.py @@ -39,7 +39,7 @@ class DatabaseOperations(PostgresqlDatabaseOperations): return cursor.query def return_insert_id(self): - return "RETURNING %s" + return "RETURNING %s", () class DatabaseWrapper(BaseDatabaseWrapper): operators = { diff --git a/django/db/models/sql/subqueries.py b/django/db/models/sql/subqueries.py index f421643dc8..aabe34a4ca 100644 --- a/django/db/models/sql/subqueries.py +++ b/django/db/models/sql/subqueries.py @@ -306,17 +306,20 @@ class InsertQuery(Query): result = ['INSERT INTO %s' % qn(opts.db_table)] result.append('(%s)' % ', '.join([qn(c) for c in self.columns])) result.append('VALUES (%s)' % ', '.join(self.values)) + params = self.params if self.connection.features.can_return_id_from_insert: col = "%s.%s" % (qn(opts.db_table), qn(opts.pk.column)) - result.append(self.connection.ops.return_insert_id() % col) - return ' '.join(result), self.params + r_fmt, r_params = self.connection.ops.return_insert_id() + result.append(r_fmt % col) + params = params + r_params + return ' '.join(result), params def execute_sql(self, return_id=False): cursor = super(InsertQuery, self).execute_sql(None) if not (return_id and cursor): return if self.connection.features.can_return_id_from_insert: - return cursor.fetchone()[0] + return self.connection.ops.fetch_returned_insert_id(cursor) return self.connection.ops.last_insert_id(cursor, self.model._meta.db_table, self.model._meta.pk.column)