Fixed #26552 -- Deferred constraint checks when reloading the database with data for tests.

deserialize_db_from_string() loads the full serialized database
contents, which might contain forward references and cycles. That
caused IntegrityError because constraints were checked immediately.

Now, it loads data in a transaction with constraint checks deferred
until the end of the transaction.
This commit is contained in:
Matthijs Kooijman 2019-12-02 00:42:06 +01:00 committed by Mariusz Felisiak
parent b330b918e9
commit 98f23a8af0
3 changed files with 40 additions and 2 deletions

View File

@ -6,6 +6,7 @@ from django.apps import apps
from django.conf import settings from django.conf import settings
from django.core import serializers from django.core import serializers
from django.db import router from django.db import router
from django.db.transaction import atomic
# The prefix to put on the default database name when creating # The prefix to put on the default database name when creating
# the test database. # the test database.
@ -126,8 +127,16 @@ class BaseDatabaseCreation:
the serialize_db_to_string() method. the serialize_db_to_string() method.
""" """
data = StringIO(data) data = StringIO(data)
for obj in serializers.deserialize("json", data, using=self.connection.alias): # Load data in a transaction to handle forward references and cycles.
obj.save() with atomic(using=self.connection.alias):
# Disable constraint checks, because some databases (MySQL) doesn't
# support deferred checks.
with self.connection.constraint_checks_disabled():
for obj in serializers.deserialize('json', data, using=self.connection.alias):
obj.save()
# Manually check for any invalid keys that might have been added,
# because constraint checks were disabled.
self.connection.check_constraints()
def _get_database_display_str(self, verbosity, database_name): def _get_database_display_str(self, verbosity, database_name):
""" """

View File

@ -7,6 +7,8 @@ from django.db.backends.base.creation import (
) )
from django.test import SimpleTestCase from django.test import SimpleTestCase
from ..models import Object, ObjectReference
def get_connection_copy(): def get_connection_copy():
# Get a copy of the default connection. (Can't use django.db.connection # Get a copy of the default connection. (Can't use django.db.connection
@ -73,3 +75,29 @@ class TestDbCreationTests(SimpleTestCase):
finally: finally:
with mock.patch.object(creation, '_destroy_test_db'): with mock.patch.object(creation, '_destroy_test_db'):
creation.destroy_test_db(old_database_name, verbosity=0) creation.destroy_test_db(old_database_name, verbosity=0)
class TestDeserializeDbFromString(SimpleTestCase):
databases = {'default'}
def test_circular_reference(self):
# deserialize_db_from_string() handles circular references.
data = """
[
{
"model": "backends.object",
"pk": 1,
"fields": {"obj_ref": 1, "related_objects": []}
},
{
"model": "backends.objectreference",
"pk": 1,
"fields": {"obj": 1}
}
]
"""
connection.creation.deserialize_db_from_string(data)
obj = Object.objects.get()
obj_ref = ObjectReference.objects.get()
self.assertEqual(obj.obj_ref, obj_ref)
self.assertEqual(obj_ref.obj, obj)

View File

@ -89,6 +89,7 @@ class Item(models.Model):
class Object(models.Model): class Object(models.Model):
related_objects = models.ManyToManyField("self", db_constraint=False, symmetrical=False) related_objects = models.ManyToManyField("self", db_constraint=False, symmetrical=False)
obj_ref = models.ForeignKey('ObjectReference', models.CASCADE, null=True)
def __str__(self): def __str__(self):
return str(self.id) return str(self.id)