diff --git a/django/db/backends/__init__.py b/django/db/backends/__init__.py index efea9f802e..39883de35c 100644 --- a/django/db/backends/__init__.py +++ b/django/db/backends/__init__.py @@ -432,6 +432,9 @@ class BaseDatabaseFeatures(object): # What's the maximum length for index names? max_index_name_length = 63 + # Does it support foreign keys? + supports_foreign_keys = True + def __init__(self, connection): self.connection = connection diff --git a/django/db/backends/schema.py b/django/db/backends/schema.py index ae80f60c30..a9601221bb 100644 --- a/django/db/backends/schema.py +++ b/django/db/backends/schema.py @@ -187,7 +187,7 @@ class BaseDatabaseSchemaEditor(object): } ) # FK - if field.rel: + if field.rel and self.connection.features.supports_foreign_keys: to_table = field.rel.to._meta.db_table to_column = field.rel.to._meta.get_field(field.rel.field_name).column self.deferred_sql.append( @@ -311,7 +311,7 @@ class BaseDatabaseSchemaEditor(object): } } # Add any FK constraints later - if field.rel: + if field.rel and self.connection.features.supports_foreign_keys: to_table = field.rel.to._meta.db_table to_column = field.rel.to._meta.get_field(field.rel.field_name).column self.deferred_sql.append( diff --git a/django/db/backends/sqlite3/base.py b/django/db/backends/sqlite3/base.py index d0a6fda78e..45e7264e5c 100644 --- a/django/db/backends/sqlite3/base.py +++ b/django/db/backends/sqlite3/base.py @@ -96,6 +96,7 @@ class DatabaseFeatures(BaseDatabaseFeatures): supports_mixed_date_datetime_comparisons = False has_bulk_insert = True can_combine_inserts_with_and_without_auto_increment_pk = False + supports_foreign_keys = False @cached_property def supports_stddev(self): diff --git a/django/db/backends/sqlite3/introspection.py b/django/db/backends/sqlite3/introspection.py index 1df4c18c1c..62c53e075a 100644 --- a/django/db/backends/sqlite3/introspection.py +++ b/django/db/backends/sqlite3/introspection.py @@ -154,7 +154,7 @@ class DatabaseIntrospection(BaseDatabaseIntrospection): if len(info) != 1: continue name = info[0][2] # seqno, cid, name - indexes[name] = {'primary_key': False, + indexes[name] = {'primary_key': indexes.get(name, {}).get("primary_key", False), 'unique': unique} return indexes @@ -182,3 +182,37 @@ class DatabaseIntrospection(BaseDatabaseIntrospection): 'null_ok': not field[3], 'pk': field[5] # undocumented } for field in cursor.fetchall()] + + def get_constraints(self, cursor, table_name): + """ + Retrieves any constraints or keys (unique, pk, fk, check, index) across one or more columns. + """ + constraints = {} + # Get the index info + cursor.execute("PRAGMA index_list(%s)" % self.connection.ops.quote_name(table_name)) + for number, index, unique in cursor.fetchall(): + # Get the index info for that index + cursor.execute('PRAGMA index_info(%s)' % self.connection.ops.quote_name(index)) + for index_rank, column_rank, column in cursor.fetchall(): + if index not in constraints: + constraints[index] = { + "columns": set(), + "primary_key": False, + "unique": bool(unique), + "foreign_key": False, + "check": False, + "index": True, + } + constraints[index]['columns'].add(column) + # Get the PK + pk_column = self.get_primary_key_column(cursor, table_name) + if pk_column: + constraints["__primary__"] = { + "columns": set([pk_column]), + "primary_key": True, + "unique": False, # It's not actually a unique constraint + "foreign_key": False, + "check": False, + "index": False, + } + return constraints diff --git a/django/db/backends/sqlite3/schema.py b/django/db/backends/sqlite3/schema.py index bfd943c6fb..7938ad79cf 100644 --- a/django/db/backends/sqlite3/schema.py +++ b/django/db/backends/sqlite3/schema.py @@ -1,6 +1,116 @@ from django.db.backends.schema import BaseDatabaseSchemaEditor +from django.db.models.loading import cache +from django.db.models.fields.related import ManyToManyField class DatabaseSchemaEditor(BaseDatabaseSchemaEditor): sql_delete_table = "DROP TABLE %(table)s" + + def _remake_table(self, model, create_fields=[], delete_fields=[], alter_fields=[], rename_fields=[], override_uniques=None): + "Shortcut to transform a model from old_model into new_model" + # Work out the new fields dict / mapping + body = dict((f.name, f) for f in model._meta.local_fields) + mapping = dict((f.column, f.column) for f in model._meta.local_fields) + # If any of the new or altered fields is introducing a new PK, + # remove the old one + restore_pk_field = None + if any(f.primary_key for f in create_fields) or any(n.primary_key for o, n in alter_fields): + for name, field in list(body.items()): + if field.primary_key: + field.primary_key = False + restore_pk_field = field + if field.auto_created: + del body[name] + del mapping[field.column] + # Add in any created fields + for field in create_fields: + body[field.name] = field + # Add in any altered fields + for (old_field, new_field) in alter_fields: + del body[old_field.name] + del mapping[old_field.column] + body[new_field.name] = new_field + mapping[new_field.column] = old_field.column + # Remove any deleted fields + for field in delete_fields: + del body[field.name] + del mapping[field.column] + # Construct a new model for the new state + meta_contents = { + 'app_label': model._meta.app_label, + 'db_table': model._meta.db_table + "__new", + 'unique_together': model._meta.unique_together if override_uniques is None else override_uniques, + } + meta = type("Meta", tuple(), meta_contents) + body['Meta'] = meta + body['__module__'] = "__fake__" + with cache.temporary_state(): + del cache.app_models[model._meta.app_label][model._meta.object_name.lower()] + temp_model = type(model._meta.object_name, model.__bases__, body) + # Create a new table with that format + self.create_model(temp_model) + # Copy data from the old table + field_maps = list(mapping.items()) + self.execute("INSERT INTO %s (%s) SELECT %s FROM %s;" % ( + self.quote_name(temp_model._meta.db_table), + ', '.join([x for x, y in field_maps]), + ', '.join([y for x, y in field_maps]), + self.quote_name(model._meta.db_table), + )) + # Delete the old table + self.delete_model(model) + # Rename the new to the old + self.alter_db_table(model, temp_model._meta.db_table, model._meta.db_table) + # Run deferred SQL on correct table + for sql in self.deferred_sql: + self.execute(sql.replace(temp_model._meta.db_table, model._meta.db_table)) + self.deferred_sql = [] + # Fix any PK-removed field + if restore_pk_field: + restore_pk_field.primary_key = True + + def create_field(self, model, field): + """ + Creates a field on a model. + Usually involves adding a column, but may involve adding a + table instead (for M2M fields) + """ + # Special-case implicit M2M tables + if isinstance(field, ManyToManyField) and field.rel.through._meta.auto_created: + return self.create_model(field.rel.through) + # Detect bad field combinations + if (not field.null and + (not field.has_default() or field.get_default() is None) and + not field.empty_strings_allowed): + raise ValueError("You cannot add a null=False column without a default value on SQLite.") + self._remake_table(model, create_fields=[field]) + + def delete_field(self, model, field): + """ + Removes a field from a model. Usually involves deleting a column, + but for M2Ms may involve deleting a table. + """ + # Special-case implicit M2M tables + if isinstance(field, ManyToManyField) and field.rel.through._meta.auto_created: + return self.delete_model(field.rel.through) + # For everything else, remake. + self._remake_table(model, delete_fields=[field]) + + def alter_field(self, model, old_field, new_field, strict=False): + # Ensure this field is even column-based + old_type = old_field.db_type(connection=self.connection) + new_type = self._type_for_alter(new_field) + if old_type is None and new_type is None: + # TODO: Handle M2M fields being repointed + return + elif old_type is None or new_type is None: + raise ValueError("Cannot alter field %s into %s - they are not compatible types" % ( + old_field, + new_field, + )) + # Alter by remaking table + self._remake_table(model, alter_fields=[(old_field, new_field)]) + + def alter_unique_together(self, model, old_unique_together, new_unique_together): + self._remake_table(model, override_uniques=new_unique_together) diff --git a/django/db/models/loading.py b/django/db/models/loading.py index 0ed6caffa4..e0d943853b 100644 --- a/django/db/models/loading.py +++ b/django/db/models/loading.py @@ -265,6 +265,10 @@ class AppCache(object): self.app_models = state['app_models'] self.app_errors = state['app_errors'] + def temporary_state(self): + "Returns a context manager that restores the state on exit" + return StateContextManager(self) + def unregister_all(self): """ Wipes the AppCache clean of all registered models. @@ -275,6 +279,23 @@ class AppCache(object): self.app_models = SortedDict() self.app_errors = {} + +class StateContextManager(object): + """ + Context manager for locking cache state. + Useful for making temporary models you don't want to stay in the cache. + """ + + def __init__(self, cache): + self.cache = cache + + def __enter__(self): + self.state = self.cache.save_state() + + def __exit__(self, type, value, traceback): + self.cache.restore_state(self.state) + + cache = AppCache() # These methods were always module level, so are kept that way for backwards diff --git a/tests/modeltests/schema/models.py b/tests/modeltests/schema/models.py index 9d0a8a2074..b18d2a9c16 100644 --- a/tests/modeltests/schema/models.py +++ b/tests/modeltests/schema/models.py @@ -29,6 +29,17 @@ class Book(models.Model): managed = False +class BookWithSlug(models.Model): + author = models.ForeignKey(Author) + title = models.CharField(max_length=100, db_index=True) + pub_date = models.DateTimeField() + slug = models.CharField(max_length=20, unique=True) + + class Meta: + managed = False + db_table = "schema_book" + + class Tag(models.Model): title = models.CharField(max_length=255) slug = models.SlugField(unique=True) diff --git a/tests/modeltests/schema/tests.py b/tests/modeltests/schema/tests.py index db374dc7ad..c76ca8ca16 100644 --- a/tests/modeltests/schema/tests.py +++ b/tests/modeltests/schema/tests.py @@ -2,11 +2,12 @@ from __future__ import absolute_import import copy import datetime from django.test import TestCase +from django.utils.unittest import skipUnless from django.db import connection, DatabaseError, IntegrityError from django.db.models.fields import IntegerField, TextField, CharField, SlugField from django.db.models.fields.related import ManyToManyField from django.db.models.loading import cache -from .models import Author, Book, AuthorWithM2M, Tag, TagUniqueRename, UniqueTest +from .models import Author, Book, BookWithSlug, AuthorWithM2M, Tag, TagUniqueRename, UniqueTest class SchemaTests(TestCase): @@ -18,7 +19,7 @@ class SchemaTests(TestCase): as the code it is testing. """ - models = [Author, Book, AuthorWithM2M, Tag, UniqueTest] + models = [Author, Book, BookWithSlug, AuthorWithM2M, Tag, TagUniqueRename, UniqueTest] # Utility functions @@ -70,13 +71,21 @@ class SchemaTests(TestCase): def column_classes(self, model): cursor = connection.cursor() - return dict( + columns = dict( (d[0], (connection.introspection.get_field_type(d[1], d), d)) for d in connection.introspection.get_table_description( cursor, model._meta.db_table, ) ) + # SQLite has a different format for field_type + for name, (type, desc) in columns.items(): + if isinstance(type, tuple): + columns[name] = (type[0], desc) + # SQLite also doesn't error properly + if not columns: + raise DatabaseError("Table does not exist (empty pragma)") + return columns # Tests @@ -104,6 +113,7 @@ class SchemaTests(TestCase): lambda: list(Author.objects.all()), ) + @skipUnless(connection.features.supports_foreign_keys, "No FK support") def test_creation_fk(self): "Tests that creating tables out of FK order works" # Create the table @@ -449,13 +459,11 @@ class SchemaTests(TestCase): connection.introspection.get_indexes(connection.cursor(), Book._meta.db_table), ) # Add a unique column, verify that creates an implicit index - new_field = CharField(max_length=20, unique=True) - new_field.set_attributes_from_name("slug") editor = connection.schema_editor() editor.start() editor.create_field( Book, - new_field, + BookWithSlug._meta.get_field_by_name("slug")[0], ) editor.commit() self.assertIn( @@ -468,8 +476,8 @@ class SchemaTests(TestCase): editor = connection.schema_editor() editor.start() editor.alter_field( - Book, - new_field, + BookWithSlug, + BookWithSlug._meta.get_field_by_name("slug")[0], new_field2, strict = True, )