Fixed #3615: Added support for loading fixtures with forward references on database backends (such as MySQL/InnoDB) that do not support deferred constraint checking. Many thanks to jsdalton for coming up with a clever solution to this long-standing issue, and to jacob, ramiro, graham_king, and russellm for review/testing. (Apologies if I missed anyone else who helped here.)

git-svn-id: http://code.djangoproject.com/svn/django/trunk@16590 bcc190cf-cafb-0310-a4f2-bffc1f526a37
This commit is contained in:
Karen Tracey 2011-08-07 00:43:26 +00:00
parent e3c89346d2
commit be87f0b0ec
16 changed files with 356 additions and 23 deletions

View File

@ -1,3 +1,7 @@
# This is necessary in Python 2.5 to enable the with statement, in 2.6
# and up it is no longer necessary.
from __future__ import with_statement
import sys import sys
import os import os
import gzip import gzip
@ -166,12 +170,20 @@ class Command(BaseCommand):
(format, fixture_name, humanize(fixture_dir))) (format, fixture_name, humanize(fixture_dir)))
try: try:
objects = serializers.deserialize(format, fixture, using=using) objects = serializers.deserialize(format, fixture, using=using)
with connection.constraint_checks_disabled():
for obj in objects: for obj in objects:
objects_in_fixture += 1 objects_in_fixture += 1
if router.allow_syncdb(using, obj.object.__class__): if router.allow_syncdb(using, obj.object.__class__):
loaded_objects_in_fixture += 1 loaded_objects_in_fixture += 1
models.add(obj.object.__class__) models.add(obj.object.__class__)
obj.save(using=using) obj.save(using=using)
# Since we disabled constraint checks, we must manually check for
# any invalid keys that might have been added
table_names = [model._meta.db_table for model in models]
connection.check_constraints(table_names=table_names)
loaded_object_count += loaded_objects_in_fixture loaded_object_count += loaded_objects_in_fixture
fixture_object_count += objects_in_fixture fixture_object_count += objects_in_fixture
label_found = True label_found = True

View File

@ -3,6 +3,7 @@ try:
except ImportError: except ImportError:
import dummy_thread as thread import dummy_thread as thread
from threading import local from threading import local
from contextlib import contextmanager
from django.conf import settings from django.conf import settings
from django.db import DEFAULT_DB_ALIAS from django.db import DEFAULT_DB_ALIAS
@ -238,6 +239,35 @@ class BaseDatabaseWrapper(local):
if self.savepoint_state: if self.savepoint_state:
self._savepoint_commit(sid) self._savepoint_commit(sid)
@contextmanager
def constraint_checks_disabled(self):
disabled = self.disable_constraint_checking()
try:
yield
finally:
if disabled:
self.enable_constraint_checking()
def disable_constraint_checking(self):
"""
Backends can implement as needed to temporarily disable foreign key constraint
checking.
"""
pass
def enable_constraint_checking(self):
"""
Backends can implement as needed to re-enable foreign key constraint checking.
"""
pass
def check_constraints(self, table_names=None):
"""
Backends can override this method if they can apply constraint checking (e.g. via "SET CONSTRAINTS
ALL IMMEDIATE"). Should raise an IntegrityError if any invalid foreign key references are encountered.
"""
pass
def close(self): def close(self):
if self.connection is not None: if self.connection is not None:
self.connection.close() self.connection.close()
@ -869,6 +899,19 @@ class BaseDatabaseIntrospection(object):
return sequence_list return sequence_list
def get_key_columns(self, cursor, table_name):
"""
Backends can override this to return a list of (column_name, referenced_table_name,
referenced_column_name) for all key columns in given table.
"""
raise NotImplementedError
def get_primary_key_column(self, cursor, table_name):
"""
Backends can override this to return the column name of the primary key for the given table.
"""
raise NotImplementedError
class BaseDatabaseClient(object): class BaseDatabaseClient(object):
""" """
This class encapsulates all backend-specific methods for opening a This class encapsulates all backend-specific methods for opening a

View File

@ -34,6 +34,7 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
get_table_description = complain get_table_description = complain
get_relations = complain get_relations = complain
get_indexes = complain get_indexes = complain
get_key_columns = complain
class DatabaseWrapper(BaseDatabaseWrapper): class DatabaseWrapper(BaseDatabaseWrapper):
operators = {} operators = {}

View File

@ -349,3 +349,52 @@ class DatabaseWrapper(BaseDatabaseWrapper):
raise Exception('Unable to determine MySQL version from version string %r' % self.connection.get_server_info()) raise Exception('Unable to determine MySQL version from version string %r' % self.connection.get_server_info())
self.server_version = tuple([int(x) for x in m.groups()]) self.server_version = tuple([int(x) for x in m.groups()])
return self.server_version return self.server_version
def disable_constraint_checking(self):
"""
Disables foreign key checks, primarily for use in adding rows with forward references. Always returns True,
to indicate constraint checks need to be re-enabled.
"""
self.cursor().execute('SET foreign_key_checks=0')
return True
def enable_constraint_checking(self):
"""
Re-enable foreign key checks after they have been disabled.
"""
self.cursor().execute('SET foreign_key_checks=1')
def check_constraints(self, table_names=None):
"""
Checks each table name in table-names for rows with invalid foreign key references. This method is
intended to be used in conjunction with `disable_constraint_checking()` and `enable_constraint_checking()`, to
determine if rows with invalid references were entered while constraint checks were off.
Raises an IntegrityError on the first invalid foreign key reference encountered (if any) and provides
detailed information about the invalid reference in the error message.
Backends can override this method if they can more directly apply constraint checking (e.g. via "SET CONSTRAINTS
ALL IMMEDIATE")
"""
cursor = self.cursor()
if table_names is None:
table_names = self.introspection.get_table_list(cursor)
for table_name in table_names:
primary_key_column_name = self.introspection.get_primary_key_column(cursor, table_name)
if not primary_key_column_name:
continue
key_columns = self.introspection.get_key_columns(cursor, table_name)
for column_name, referenced_table_name, referenced_column_name in key_columns:
cursor.execute("""
SELECT REFERRING.`%s`, REFERRING.`%s` FROM `%s` as REFERRING
LEFT JOIN `%s` as REFERRED
ON (REFERRING.`%s` = REFERRED.`%s`)
WHERE REFERRING.`%s` IS NOT NULL AND REFERRED.`%s` IS NULL"""
% (primary_key_column_name, column_name, table_name, referenced_table_name,
column_name, referenced_column_name, column_name, referenced_column_name))
for bad_row in cursor.fetchall():
raise utils.IntegrityError("The row in table '%s' with primary key '%s' has an invalid "
"foreign key: %s.%s contains a value '%s' that does not have a corresponding value in %s.%s."
% (table_name, bad_row[0],
table_name, column_name, bad_row[1],
referenced_table_name, referenced_column_name))

View File

@ -51,10 +51,21 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
representing all relationships to the given table. Indexes are 0-based. representing all relationships to the given table. Indexes are 0-based.
""" """
my_field_dict = self._name_to_index(cursor, table_name) my_field_dict = self._name_to_index(cursor, table_name)
constraints = [] constraints = self.get_key_columns(cursor, table_name)
relations = {} relations = {}
for my_fieldname, other_table, other_field in constraints:
other_field_index = self._name_to_index(cursor, other_table)[other_field]
my_field_index = my_field_dict[my_fieldname]
relations[my_field_index] = (other_field_index, other_table)
return relations
def get_key_columns(self, cursor, table_name):
"""
Returns a list of (column_name, referenced_table_name, referenced_column_name) for all
key columns in given table.
"""
key_columns = []
try: try:
# This should work for MySQL 5.0.
cursor.execute(""" cursor.execute("""
SELECT column_name, referenced_table_name, referenced_column_name SELECT column_name, referenced_table_name, referenced_column_name
FROM information_schema.key_column_usage FROM information_schema.key_column_usage
@ -62,7 +73,7 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
AND table_schema = DATABASE() AND table_schema = DATABASE()
AND referenced_table_name IS NOT NULL AND referenced_table_name IS NOT NULL
AND referenced_column_name IS NOT NULL""", [table_name]) AND referenced_column_name IS NOT NULL""", [table_name])
constraints.extend(cursor.fetchall()) key_columns.extend(cursor.fetchall())
except (ProgrammingError, OperationalError): except (ProgrammingError, OperationalError):
# Fall back to "SHOW CREATE TABLE", for previous MySQL versions. # Fall back to "SHOW CREATE TABLE", for previous MySQL versions.
# Go through all constraints and save the equal matches. # Go through all constraints and save the equal matches.
@ -74,14 +85,17 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
if match == None: if match == None:
break break
pos = match.end() pos = match.end()
constraints.append(match.groups()) key_columns.append(match.groups())
return key_columns
for my_fieldname, other_table, other_field in constraints: def get_primary_key_column(self, cursor, table_name):
other_field_index = self._name_to_index(cursor, other_table)[other_field] """
my_field_index = my_field_dict[my_fieldname] Returns the name of the primary key column for the given table
relations[my_field_index] = (other_field_index, other_table) """
for column in self.get_indexes(cursor, table_name).iteritems():
return relations if column[1]['primary_key']:
return column[0]
return None
def get_indexes(self, cursor, table_name): def get_indexes(self, cursor, table_name):
""" """

View File

@ -428,6 +428,14 @@ class DatabaseWrapper(BaseDatabaseWrapper):
self.introspection = DatabaseIntrospection(self) self.introspection = DatabaseIntrospection(self)
self.validation = BaseDatabaseValidation(self) self.validation = BaseDatabaseValidation(self)
def check_constraints(self, table_names=None):
"""
To check constraints, we set constraints to immediate. Then, when, we're done we must ensure they
are returned to deferred.
"""
self.cursor().execute('SET CONSTRAINTS ALL IMMEDIATE')
self.cursor().execute('SET CONSTRAINTS ALL DEFERRED')
def _valid_connection(self): def _valid_connection(self):
return self.connection is not None return self.connection is not None

View File

@ -106,6 +106,14 @@ class DatabaseWrapper(BaseDatabaseWrapper):
self.validation = BaseDatabaseValidation(self) self.validation = BaseDatabaseValidation(self)
self._pg_version = None self._pg_version = None
def check_constraints(self, table_names=None):
"""
To check constraints, we set constraints to immediate. Then, when, we're done we must ensure they
are returned to deferred.
"""
self.cursor().execute('SET CONSTRAINTS ALL IMMEDIATE')
self.cursor().execute('SET CONSTRAINTS ALL DEFERRED')
def _get_pg_version(self): def _get_pg_version(self):
if self._pg_version is None: if self._pg_version is None:
self._pg_version = get_version(self.connection) self._pg_version = get_version(self.connection)

View File

@ -206,6 +206,40 @@ class DatabaseWrapper(BaseDatabaseWrapper):
connection_created.send(sender=self.__class__, connection=self) connection_created.send(sender=self.__class__, connection=self)
return self.connection.cursor(factory=SQLiteCursorWrapper) return self.connection.cursor(factory=SQLiteCursorWrapper)
def check_constraints(self, table_names=None):
"""
Checks each table name in table-names for rows with invalid foreign key references. This method is
intended to be used in conjunction with `disable_constraint_checking()` and `enable_constraint_checking()`, to
determine if rows with invalid references were entered while constraint checks were off.
Raises an IntegrityError on the first invalid foreign key reference encountered (if any) and provides
detailed information about the invalid reference in the error message.
Backends can override this method if they can more directly apply constraint checking (e.g. via "SET CONSTRAINTS
ALL IMMEDIATE")
"""
cursor = self.cursor()
if table_names is None:
table_names = self.introspection.get_table_list(cursor)
for table_name in table_names:
primary_key_column_name = self.introspection.get_primary_key_column(cursor, table_name)
if not primary_key_column_name:
continue
key_columns = self.introspection.get_key_columns(cursor, table_name)
for column_name, referenced_table_name, referenced_column_name in key_columns:
cursor.execute("""
SELECT REFERRING.`%s`, REFERRING.`%s` FROM `%s` as REFERRING
LEFT JOIN `%s` as REFERRED
ON (REFERRING.`%s` = REFERRED.`%s`)
WHERE REFERRING.`%s` IS NOT NULL AND REFERRED.`%s` IS NULL"""
% (primary_key_column_name, column_name, table_name, referenced_table_name,
column_name, referenced_column_name, column_name, referenced_column_name))
for bad_row in cursor.fetchall():
raise utils.IntegrityError("The row in table '%s' with primary key '%s' has an invalid "
"foreign key: %s.%s contains a value '%s' that does not have a corresponding value in %s.%s."
% (table_name, bad_row[0], table_name, column_name, bad_row[1],
referenced_table_name, referenced_column_name))
def close(self): def close(self):
# If database is in memory, closing the connection destroys the # If database is in memory, closing the connection destroys the
# database. To prevent accidental data loss, ignore close requests on # database. To prevent accidental data loss, ignore close requests on

View File

@ -103,6 +103,35 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
return relations return relations
def get_key_columns(self, cursor, table_name):
"""
Returns a list of (column_name, referenced_table_name, referenced_column_name) for all
key columns in given table.
"""
key_columns = []
# Schema for this table
cursor.execute("SELECT sql FROM sqlite_master WHERE tbl_name = %s AND type = %s", [table_name, "table"])
results = cursor.fetchone()[0].strip()
results = results[results.index('(')+1:results.rindex(')')]
# Walk through and look for references to other tables. SQLite doesn't
# really have enforced references, but since it echoes out the SQL used
# to create the table we can look for REFERENCES statements used there.
for field_index, field_desc in enumerate(results.split(',')):
field_desc = field_desc.strip()
if field_desc.startswith("UNIQUE"):
continue
m = re.search('"(.*)".*references (.*) \(["|](.*)["|]\)', field_desc, re.I)
if not m:
continue
# This will append (column_name, referenced_table_name, referenced_column_name) to key_columns
key_columns.append(tuple([s.strip('"') for s in m.groups()]))
return key_columns
def get_indexes(self, cursor, table_name): def get_indexes(self, cursor, table_name):
""" """
Returns a dictionary of fieldname -> infodict for the given table, Returns a dictionary of fieldname -> infodict for the given table,
@ -128,6 +157,21 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
indexes[name]['unique'] = True indexes[name]['unique'] = True
return indexes return indexes
def get_primary_key_column(self, cursor, table_name):
"""
Get the column name of the primary key for the given table.
"""
# Don't use PRAGMA because that causes issues with some transactions
cursor.execute("SELECT sql FROM sqlite_master WHERE tbl_name = %s AND type = %s", [table_name, "table"])
results = cursor.fetchone()[0].strip()
results = results[results.index('(')+1:results.rindex(')')]
for field_desc in results.split(','):
field_desc = field_desc.strip()
m = re.search('"(.*)".*PRIMARY KEY$', field_desc)
if m:
return m.groups()[0]
return None
def _table_info(self, cursor, name): def _table_info(self, cursor, name):
cursor.execute('PRAGMA table_info(%s)' % self.connection.ops.quote_name(name)) cursor.execute('PRAGMA table_info(%s)' % self.connection.ops.quote_name(name))
# cid, name, type, notnull, dflt_value, pk # cid, name, type, notnull, dflt_value, pk

View File

@ -142,6 +142,18 @@ currently the only engine that supports full-text indexing and searching.
The InnoDB_ engine is fully transactional and supports foreign key references The InnoDB_ engine is fully transactional and supports foreign key references
and is probably the best choice at this point in time. and is probably the best choice at this point in time.
.. versionchanged:: 1.4
In previous versions of Django, fixtures with forward references (i.e.
relations to rows that have not yet been inserted into the database) would fail
to load when using the InnoDB storage engine. This was due to the fact that InnoDB
deviates from the SQL standard by checking foreign key constraints immediately
instead of deferring the check until the transaction is committed. This
problem has been resolved in Django 1.4. Fixture data is now loaded with foreign key
checks turned off; foreign key checks are then re-enabled when the data has
finished loading, at which point the entire table is checked for invalid foreign
key references and an `IntegrityError` is raised if any are found.
.. _storage engines: http://dev.mysql.com/doc/refman/5.5/en/storage-engines.html .. _storage engines: http://dev.mysql.com/doc/refman/5.5/en/storage-engines.html
.. _MyISAM: http://dev.mysql.com/doc/refman/5.5/en/myisam-storage-engine.html .. _MyISAM: http://dev.mysql.com/doc/refman/5.5/en/myisam-storage-engine.html
.. _InnoDB: http://dev.mysql.com/doc/refman/5.5/en/innodb.html .. _InnoDB: http://dev.mysql.com/doc/refman/5.5/en/innodb.html

View File

@ -235,6 +235,9 @@ Django 1.4 also includes several smaller improvements worth noting:
to delete all files at the destination before copying or linking the static to delete all files at the destination before copying or linking the static
files. files.
* It is now possible to load fixtures containing forward references when using
MySQL with the InnoDB database engine.
.. _backwards-incompatible-changes-1.4: .. _backwards-incompatible-changes-1.4:
Backwards incompatible changes in 1.4 Backwards incompatible changes in 1.4

View File

@ -1,3 +1,7 @@
# This is necessary in Python 2.5 to enable the with statement, in 2.6
# and up it is no longer necessary.
from __future__ import with_statement
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
from datetime import datetime from datetime import datetime
from StringIO import StringIO from StringIO import StringIO
@ -5,7 +9,7 @@ from xml.dom import minidom
from django.conf import settings from django.conf import settings
from django.core import serializers from django.core import serializers
from django.db import transaction from django.db import transaction, connection
from django.test import TestCase, TransactionTestCase, Approximate from django.test import TestCase, TransactionTestCase, Approximate
from django.utils import simplejson, unittest from django.utils import simplejson, unittest
@ -252,6 +256,7 @@ class SerializersTransactionTestBase(object):
transaction.enter_transaction_management() transaction.enter_transaction_management()
transaction.managed(True) transaction.managed(True)
objs = serializers.deserialize(self.serializer_name, self.fwd_ref_str) objs = serializers.deserialize(self.serializer_name, self.fwd_ref_str)
with connection.constraint_checks_disabled():
for obj in objs: for obj in objs:
obj.save() obj.save()
transaction.commit() transaction.commit()

View File

@ -1,10 +1,11 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Unit and doctests for specific database backends. # Unit and doctests for specific database backends.
from __future__ import with_statement
import datetime import datetime
from django.conf import settings from django.conf import settings
from django.core.management.color import no_style from django.core.management.color import no_style
from django.db import backend, connection, connections, DEFAULT_DB_ALIAS, IntegrityError from django.db import backend, connection, connections, DEFAULT_DB_ALIAS, IntegrityError, transaction
from django.db.backends.signals import connection_created from django.db.backends.signals import connection_created
from django.db.backends.postgresql_psycopg2 import version as pg_version from django.db.backends.postgresql_psycopg2 import version as pg_version
from django.test import TestCase, skipUnlessDBFeature, TransactionTestCase from django.test import TestCase, skipUnlessDBFeature, TransactionTestCase
@ -328,7 +329,8 @@ class FkConstraintsTests(TransactionTestCase):
try: try:
a.save() a.save()
except IntegrityError: except IntegrityError:
pass return
self.skipTest("This backend does not support integrity checks.")
def test_integrity_checks_on_update(self): def test_integrity_checks_on_update(self):
""" """
@ -343,4 +345,60 @@ class FkConstraintsTests(TransactionTestCase):
try: try:
a.save() a.save()
except IntegrityError: except IntegrityError:
pass return
self.skipTest("This backend does not support integrity checks.")
def test_disable_constraint_checks_manually(self):
"""
When constraint checks are disabled, should be able to write bad data without IntegrityErrors.
"""
with transaction.commit_manually():
# Create an Article.
models.Article.objects.create(headline="Test article", pub_date=datetime.datetime(2010, 9, 4), reporter=self.r)
# Retrive it from the DB
a = models.Article.objects.get(headline="Test article")
a.reporter_id = 30
try:
connection.disable_constraint_checking()
a.save()
connection.enable_constraint_checking()
except IntegrityError:
self.fail("IntegrityError should not have occurred.")
finally:
transaction.rollback()
def test_disable_constraint_checks_context_manager(self):
"""
When constraint checks are disabled (using context manager), should be able to write bad data without IntegrityErrors.
"""
with transaction.commit_manually():
# Create an Article.
models.Article.objects.create(headline="Test article", pub_date=datetime.datetime(2010, 9, 4), reporter=self.r)
# Retrive it from the DB
a = models.Article.objects.get(headline="Test article")
a.reporter_id = 30
try:
with connection.constraint_checks_disabled():
a.save()
except IntegrityError:
self.fail("IntegrityError should not have occurred.")
finally:
transaction.rollback()
def test_check_constraints(self):
"""
Constraint checks should raise an IntegrityError when bad data is in the DB.
"""
with transaction.commit_manually():
# Create an Article.
models.Article.objects.create(headline="Test article", pub_date=datetime.datetime(2010, 9, 4), reporter=self.r)
# Retrive it from the DB
a = models.Article.objects.get(headline="Test article")
a.reporter_id = 30
try:
with connection.constraint_checks_disabled():
a.save()
with self.assertRaises(IntegrityError):
connection.check_constraints()
finally:
transaction.rollback()

View File

@ -362,6 +362,35 @@ class TestFixtures(TestCase):
% widget.pk % widget.pk
) )
def test_loaddata_works_when_fixture_has_forward_refs(self):
"""
Regression for #3615 - Forward references cause fixtures not to load in MySQL (InnoDB)
"""
management.call_command(
'loaddata',
'forward_ref.json',
verbosity=0,
commit=False
)
self.assertEqual(Book.objects.all()[0].id, 1)
self.assertEqual(Person.objects.all()[0].id, 4)
def test_loaddata_raises_error_when_fixture_has_invalid_foreign_key(self):
"""
Regression for #3615 - Ensure data with nonexistent child key references raises error
"""
stderr = StringIO()
management.call_command(
'loaddata',
'forward_ref_bad_data.json',
verbosity=0,
commit=False,
stderr=stderr,
)
self.assertTrue(
stderr.getvalue().startswith('Problem installing fixture')
)
class NaturalKeyFixtureTests(TestCase): class NaturalKeyFixtureTests(TestCase):
def assertRaisesMessage(self, exc, msg, func, *args, **kwargs): def assertRaisesMessage(self, exc, msg, func, *args, **kwargs):

View File

@ -95,6 +95,16 @@ class IntrospectionTests(TestCase):
# That's {field_index: (field_index_other_table, other_table)} # That's {field_index: (field_index_other_table, other_table)}
self.assertEqual(relations, {3: (0, Reporter._meta.db_table)}) self.assertEqual(relations, {3: (0, Reporter._meta.db_table)})
def test_get_key_columns(self):
cursor = connection.cursor()
key_columns = connection.introspection.get_key_columns(cursor, Article._meta.db_table)
self.assertEqual(key_columns, [(u'reporter_id', Reporter._meta.db_table, u'id')])
def test_get_primary_key_column(self):
cursor = connection.cursor()
primary_key_column = connection.introspection.get_primary_key_column(cursor, Article._meta.db_table)
self.assertEqual(primary_key_column, u'id')
def test_get_indexes(self): def test_get_indexes(self):
cursor = connection.cursor() cursor = connection.cursor()
indexes = connection.introspection.get_indexes(cursor, Article._meta.db_table) indexes = connection.introspection.get_indexes(cursor, Article._meta.db_table)

View File

@ -6,6 +6,8 @@ test case that is capable of testing the capabilities of
the serializers. This includes all valid data values, plus the serializers. This includes all valid data values, plus
forward, backwards and self references. forward, backwards and self references.
""" """
# This is necessary in Python 2.5 to enable the with statement, in 2.6
# and up it is no longer necessary.
from __future__ import with_statement from __future__ import with_statement
import datetime import datetime
@ -382,6 +384,7 @@ def serializerTest(format, self):
objects = [] objects = []
instance_count = {} instance_count = {}
for (func, pk, klass, datum) in test_data: for (func, pk, klass, datum) in test_data:
with connection.constraint_checks_disabled():
objects.extend(func[0](pk, klass, datum)) objects.extend(func[0](pk, klass, datum))
# Get a count of the number of objects created for each class # Get a count of the number of objects created for each class