From 157604a87fa7e1331c25fcbed558f0799aa5b8df Mon Sep 17 00:00:00 2001 From: Andrew Godwin Date: Tue, 13 Aug 2013 20:54:57 +0100 Subject: [PATCH] Oracle schema backend, passes most tests and is pretty complete. --- django/db/backends/oracle/base.py | 1 + django/db/backends/oracle/introspection.py | 139 +++++++++++++++++++++ django/db/backends/oracle/schema.py | 77 ++++++++++++ django/db/backends/schema.py | 35 ++++-- tests/schema/tests.py | 4 +- 5 files changed, 247 insertions(+), 9 deletions(-) diff --git a/django/db/backends/oracle/base.py b/django/db/backends/oracle/base.py index 9b08bc097a..b6812a6d3e 100644 --- a/django/db/backends/oracle/base.py +++ b/django/db/backends/oracle/base.py @@ -92,6 +92,7 @@ class DatabaseFeatures(BaseDatabaseFeatures): supports_tablespaces = True supports_sequence_reset = False supports_combined_alters = False + max_index_name_length = 30 class DatabaseOperations(BaseDatabaseOperations): diff --git a/django/db/backends/oracle/introspection.py b/django/db/backends/oracle/introspection.py index a2fad92509..e4ef1ae81b 100644 --- a/django/db/backends/oracle/introspection.py +++ b/django/db/backends/oracle/introspection.py @@ -134,3 +134,142 @@ class DatabaseIntrospection(BaseDatabaseIntrospection): indexes[row[0]] = {'primary_key': bool(row[1]), 'unique': bool(row[2])} return indexes + + def get_constraints(self, cursor, table_name): + """ + Retrieves any constraints or keys (unique, pk, fk, check, index) across one or more columns. + """ + constraints = {} + # Loop over the constraints, getting PKs and uniques + cursor.execute(""" + SELECT + user_constraints.constraint_name, + LOWER(cols.column_name) AS column_name, + CASE user_constraints.constraint_type + WHEN 'P' THEN 1 + ELSE 0 + END AS is_primary_key, + CASE user_indexes.uniqueness + WHEN 'UNIQUE' THEN 1 + ELSE 0 + END AS is_unique, + CASE user_constraints.constraint_type + WHEN 'C' THEN 1 + ELSE 0 + END AS is_check_constraint + FROM + user_constraints + INNER JOIN + user_indexes ON user_indexes.index_name = user_constraints.index_name + LEFT OUTER JOIN + user_cons_columns cols ON user_constraints.constraint_name = cols.constraint_name + WHERE + ( + user_constraints.constraint_type = 'P' OR + user_constraints.constraint_type = 'U' + ) + AND user_constraints.table_name = UPPER(%s) + ORDER BY cols.position + """, [table_name]) + for constraint, column, pk, unique, check in cursor.fetchall(): + # If we're the first column, make the record + if constraint not in constraints: + constraints[constraint] = { + "columns": [], + "primary_key": pk, + "unique": unique, + "foreign_key": None, + "check": check, + "index": True, + } + # Record the details + constraints[constraint]['columns'].append(column) + # Check constraints + cursor.execute(""" + SELECT + cons.constraint_name, + LOWER(cols.column_name) AS column_name + FROM + user_constraints cons + LEFT OUTER JOIN + user_cons_columns cols ON cons.constraint_name = cols.constraint_name + WHERE + cons.constraint_type = 'C' AND + cons.table_name = UPPER(%s) + ORDER BY cols.position + """, [table_name]) + for constraint, column in cursor.fetchall(): + # If we're the first column, make the record + if constraint not in constraints: + constraints[constraint] = { + "columns": [], + "primary_key": False, + "unique": False, + "foreign_key": None, + "check": True, + "index": False, + } + # Record the details + constraints[constraint]['columns'].append(column) + # Foreign key constraints + cursor.execute(""" + SELECT + cons.constraint_name, + LOWER(cols.column_name) AS column_name, + LOWER(rcons.table_name), + LOWER(rcols.column_name) + FROM + user_constraints cons + INNER JOIN + user_constraints rcons ON cons.r_constraint_name = rcons.constraint_name + INNER JOIN + user_cons_columns rcols ON rcols.constraint_name = rcons.constraint_name + LEFT OUTER JOIN + user_cons_columns cols ON cons.constraint_name = cols.constraint_name + WHERE + cons.constraint_type = 'R' AND + cons.table_name = UPPER(%s) + ORDER BY cols.position + """, [table_name]) + for constraint, column, other_table, other_column in cursor.fetchall(): + # If we're the first column, make the record + if constraint not in constraints: + constraints[constraint] = { + "columns": [], + "primary_key": False, + "unique": False, + "foreign_key": (other_table, other_column), + "check": False, + "index": False, + } + # Record the details + constraints[constraint]['columns'].append(column) + # Now get indexes + cursor.execute(""" + SELECT + index_name, + LOWER(column_name) + FROM + user_ind_columns cols + WHERE + table_name = UPPER(%s) AND + NOT EXISTS ( + SELECT 1 + FROM user_constraints cons + WHERE cols.index_name = cons.index_name + ) + """, [table_name]) + for constraint, column in cursor.fetchall(): + # If we're the first column, make the record + if constraint not in constraints: + constraints[constraint] = { + "columns": [], + "primary_key": False, + "unique": False, + "foreign_key": None, + "check": False, + "index": True, + } + # Record the details + constraints[constraint]['columns'].append(column) + return constraints diff --git a/django/db/backends/oracle/schema.py b/django/db/backends/oracle/schema.py index 4a679e79eb..c78294cad5 100644 --- a/django/db/backends/oracle/schema.py +++ b/django/db/backends/oracle/schema.py @@ -1,4 +1,6 @@ +import copy from django.db.backends.schema import BaseDatabaseSchemaEditor +from django.db.utils import DatabaseError class DatabaseSchemaEditor(BaseDatabaseSchemaEditor): @@ -12,3 +14,78 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor): sql_delete_column = "ALTER TABLE %(table)s DROP COLUMN %(column)s" sql_delete_table = "DROP TABLE %(table)s CASCADE CONSTRAINTS" + def delete_model(self, model): + # Run superclass action + super(DatabaseSchemaEditor, self).delete_model(model) + # Clean up any autoincrement trigger + self.execute(""" + DECLARE + i INTEGER; + BEGIN + SELECT COUNT(*) INTO i FROM USER_CATALOG + WHERE TABLE_NAME = '%(sq_name)s' AND TABLE_TYPE = 'SEQUENCE'; + IF i = 1 THEN + EXECUTE IMMEDIATE 'DROP SEQUENCE "%(sq_name)s"'; + END IF; + END; + /""" % {'sq_name': self.connection.ops._get_sequence_name(model._meta.db_table)}) + + def alter_field(self, model, old_field, new_field, strict=False): + try: + # Run superclass action + super(DatabaseSchemaEditor, self).alter_field(model, old_field, new_field, strict) + except DatabaseError as e: + description = str(e) + # If we're changing to/from LOB fields, we need to do a + # SQLite-ish workaround + if 'ORA-22858' in description or 'ORA-22859' in description: + self._alter_field_lob_workaround(model, old_field, new_field) + else: + raise + + def _alter_field_lob_workaround(self, model, old_field, new_field): + """ + Oracle refuses to change a column type from/to LOB to/from a regular + column. In Django, this shows up when the field is changed from/to + a TextField. + What we need to do instead is: + - Add the desired field with a temporary name + - Update the table to transfer values from old to new + - Drop old column + - Rename the new column + """ + # Make a new field that's like the new one but with a temporary + # column name. + new_temp_field = copy.deepcopy(new_field) + new_temp_field.column = self._generate_temp_name(new_field.column) + # Add it + self.add_field(model, new_temp_field) + # Transfer values across + self.execute("UPDATE %s set %s=%s" % ( + self.quote_name(model._meta.db_table), + self.quote_name(new_temp_field.column), + self.quote_name(old_field.column), + )) + # Drop the old field + self.remove_field(model, old_field) + # Rename the new field + self.alter_field(model, new_temp_field, new_field) + # Close the connection to force cx_Oracle to get column types right + # on a new cursor + self.connection.close() + + def normalize_name(self, name): + """ + Get the properly shortened and uppercased identifier as returned by quote_name(), but without the actual quotes. + """ + nn = self.quote_name(name) + if nn[0] == '"' and nn[-1] == '"': + nn = nn[1:-1] + return nn + + def _generate_temp_name(self, for_name): + """ + Generates temporary names for workarounds that need temp columns + """ + suffix = hex(hash(for_name)).upper()[1:] + return self.normalize_name(for_name + "_" + suffix) diff --git a/django/db/backends/schema.py b/django/db/backends/schema.py index 19a737883f..7beae7417a 100644 --- a/django/db/backends/schema.py +++ b/django/db/backends/schema.py @@ -113,6 +113,11 @@ class BaseDatabaseSchemaEditor(object): sql += " %s" % self.connection.ops.tablespace_sql(tablespace, inline=True) # Work out nullability null = field.null + # If we were told to include a default value, do so + default_value = self.effective_default(field) + if include_default and default_value is not None: + sql += " DEFAULT %s" + params += [default_value] # Oracle treats the empty string ('') as null, so coerce the null # option whenever '' is a possible value. if (field.empty_strings_allowed and not field.primary_key and @@ -127,11 +132,6 @@ class BaseDatabaseSchemaEditor(object): sql += " PRIMARY KEY" elif field.unique: sql += " UNIQUE" - # If we were told to include a default value, do so - default_value = self.effective_default(field) - if include_default and default_value is not None: - sql += " DEFAULT %s" - params += [default_value] # Return the sql return sql, params @@ -176,7 +176,7 @@ class BaseDatabaseSchemaEditor(object): )) params.extend(extra_params) # Indexes - if field.db_index: + if field.db_index and not field.unique: self.deferred_sql.append( self.sql_create_index % { "name": self._create_index_name(model, [field.column], suffix=""), @@ -198,6 +198,11 @@ class BaseDatabaseSchemaEditor(object): "to_column": self.quote_name(to_column), } ) + # Autoincrement SQL + if field.get_internal_type() == "AutoField": + autoinc_sql = self.connection.ops.autoinc_sql(model._meta.db_table, field.column) + if autoinc_sql: + self.deferred_sql.extend(autoinc_sql) # Add any unique_togethers for fields in model._meta.unique_together: columns = [model._meta.get_field_by_name(field)[0].column for field in fields] @@ -353,6 +358,16 @@ class BaseDatabaseSchemaEditor(object): } } self.execute(sql) + # Add an index, if required + if field.db_index and not field.unique: + self.deferred_sql.append( + self.sql_create_index % { + "name": self._create_index_name(model, [field.column], suffix=""), + "table": self.quote_name(model._meta.db_table), + "columns": self.quote_name(field.column), + "extra": "", + } + ) # Add any FK constraints later if field.rel and self.connection.features.supports_foreign_keys: to_table = field.rel.to._meta.db_table @@ -412,7 +427,7 @@ class BaseDatabaseSchemaEditor(object): new_field, )) # Has unique been removed? - if old_field.unique and not new_field.unique: + if old_field.unique and (not new_field.unique or (not old_field.primary_key and new_field.primary_key)): # Find the unique constraint for this field constraint_names = self._constraint_names(model, [old_field.column], unique=True) if strict and len(constraint_names) != 1: @@ -647,9 +662,15 @@ class BaseDatabaseSchemaEditor(object): if len(index_name) > self.connection.features.max_index_name_length: part = ('_%s%s%s' % (column_names[0], index_unique_name, suffix)) index_name = '%s%s' % (table_name[:(self.connection.features.max_index_name_length - len(part))], part) + # It shouldn't start with an underscore (Oracle hates this) + if index_name[0] == "_": + index_name = index_name[1:] # If it's STILL too long, just hash it down if len(index_name) > self.connection.features.max_index_name_length: index_name = hashlib.md5(index_name).hexdigest()[:self.connection.features.max_index_name_length] + # It can't start with a number on Oracle, so prepend D if we need to + if index_name[0].isdigit(): + index_name = "D%s" % index_name[:-1] return index_name def _constraint_names(self, model, column_names=None, unique=None, primary_key=None, index=None, foreign_key=None, check=None): diff --git a/tests/schema/tests.py b/tests/schema/tests.py index f6e45599b8..d4e76e8567 100644 --- a/tests/schema/tests.py +++ b/tests/schema/tests.py @@ -167,7 +167,7 @@ class SchemaTests(TransactionTestCase): # Ensure the field is right to begin with columns = self.column_classes(Author) self.assertEqual(columns['name'][0], "CharField") - self.assertEqual(columns['name'][1][6], False) + self.assertEqual(bool(columns['name'][1][6]), bool(connection.features.interprets_empty_strings_as_nulls)) # Alter the name field to a TextField new_field = TextField(null=True) new_field.set_attributes_from_name("name") @@ -195,7 +195,7 @@ class SchemaTests(TransactionTestCase): # Ensure the field is right afterwards columns = self.column_classes(Author) self.assertEqual(columns['name'][0], "TextField") - self.assertEqual(columns['name'][1][6], False) + self.assertEqual(columns['name'][1][6], bool(connection.features.interprets_empty_strings_as_nulls)) def test_rename(self): """