Fixed #26291 -- Allowed loaddata to handle forward references in natural_key fixtures.

This commit is contained in:
Peter Inglesby 2018-07-13 22:54:47 +01:00 committed by Tim Graham
parent 8f75d21a2e
commit 312eb5cb11
12 changed files with 352 additions and 25 deletions

View File

@ -109,8 +109,11 @@ class Command(BaseCommand):
return
with connection.constraint_checks_disabled():
self.objs_with_deferred_fields = []
for fixture_label in fixture_labels:
self.load_label(fixture_label)
for obj in self.objs_with_deferred_fields:
obj.save_deferred_fields(using=self.using)
# Since we disabled constraint checks, we must manually check for
# any invalid keys that might have been added
@ -163,6 +166,7 @@ class Command(BaseCommand):
objects = serializers.deserialize(
ser_fmt, fixture, using=self.using, ignorenonexistent=self.ignore,
handle_forward_references=True,
)
for obj in objects:
@ -189,6 +193,8 @@ class Command(BaseCommand):
'error_msg': e,
},)
raise
if obj.deferred_fields:
self.objs_with_deferred_fields.append(obj)
if objects and show_progress:
self.stdout.write('') # add a newline after progress indicator
self.loaded_object_count += loaded_objects_in_fixture

View File

@ -3,8 +3,11 @@ Module for abstract serializer/unserializer base classes.
"""
from io import StringIO
from django.core.exceptions import ObjectDoesNotExist
from django.db import models
DEFER_FIELD = object()
class SerializerDoesNotExist(KeyError):
"""The requested serializer was not found."""
@ -201,9 +204,10 @@ class DeserializedObject:
(and not touch the many-to-many stuff.)
"""
def __init__(self, obj, m2m_data=None):
def __init__(self, obj, m2m_data=None, deferred_fields=None):
self.object = obj
self.m2m_data = m2m_data
self.deferred_fields = deferred_fields
def __repr__(self):
return "<%s: %s(pk=%s)>" % (
@ -225,6 +229,25 @@ class DeserializedObject:
# the m2m data twice.
self.m2m_data = None
def save_deferred_fields(self, using=None):
self.m2m_data = {}
for field, field_value in self.deferred_fields.items():
opts = self.object._meta
label = opts.app_label + '.' + opts.model_name
if isinstance(field.remote_field, models.ManyToManyRel):
try:
values = deserialize_m2m_values(field, field_value, using, handle_forward_references=False)
except M2MDeserializationError as e:
raise DeserializationError.WithData(e.original_exc, label, self.object.pk, e.pk)
self.m2m_data[field.name] = values
elif isinstance(field.remote_field, models.ManyToOneRel):
try:
value = deserialize_fk_value(field, field_value, using, handle_forward_references=False)
except Exception as e:
raise DeserializationError.WithData(e, label, self.object.pk, field_value)
setattr(self.object, field.attname, value)
self.save()
def build_instance(Model, data, db):
"""
@ -244,7 +267,7 @@ def build_instance(Model, data, db):
return obj
def deserialize_m2m_values(field, field_value, using):
def deserialize_m2m_values(field, field_value, using, handle_forward_references):
model = field.remote_field.model
if hasattr(model._default_manager, 'get_by_natural_key'):
def m2m_convert(value):
@ -262,10 +285,13 @@ def deserialize_m2m_values(field, field_value, using):
values.append(m2m_convert(pk))
return values
except Exception as e:
raise M2MDeserializationError(e, pk)
if isinstance(e, ObjectDoesNotExist) and handle_forward_references:
return DEFER_FIELD
else:
raise M2MDeserializationError(e, pk)
def deserialize_fk_value(field, field_value, using):
def deserialize_fk_value(field, field_value, using, handle_forward_references):
if field_value is None:
return None
model = field.remote_field.model
@ -273,7 +299,13 @@ def deserialize_fk_value(field, field_value, using):
field_name = field.remote_field.field_name
if (hasattr(default_manager, 'get_by_natural_key') and
hasattr(field_value, '__iter__') and not isinstance(field_value, str)):
obj = default_manager.db_manager(using).get_by_natural_key(*field_value)
try:
obj = default_manager.db_manager(using).get_by_natural_key(*field_value)
except ObjectDoesNotExist:
if handle_forward_references:
return DEFER_FIELD
else:
raise
value = getattr(obj, field_name)
# If this is a natural foreign key to an object that has a FK/O2O as
# the foreign key, use the FK value.

View File

@ -83,6 +83,7 @@ def Deserializer(object_list, *, using=DEFAULT_DB_ALIAS, ignorenonexistent=False
It's expected that you pass the Python objects themselves (instead of a
stream or a string) to the constructor
"""
handle_forward_references = options.pop('handle_forward_references', False)
field_names_cache = {} # Model: <list of field_names>
for d in object_list:
@ -101,6 +102,7 @@ def Deserializer(object_list, *, using=DEFAULT_DB_ALIAS, ignorenonexistent=False
except Exception as e:
raise base.DeserializationError.WithData(e, d['model'], d.get('pk'), None)
m2m_data = {}
deferred_fields = {}
if Model not in field_names_cache:
field_names_cache[Model] = {f.name for f in Model._meta.get_fields()}
@ -118,17 +120,23 @@ def Deserializer(object_list, *, using=DEFAULT_DB_ALIAS, ignorenonexistent=False
# Handle M2M relations
if field.remote_field and isinstance(field.remote_field, models.ManyToManyRel):
try:
values = base.deserialize_m2m_values(field, field_value, using)
values = base.deserialize_m2m_values(field, field_value, using, handle_forward_references)
except base.M2MDeserializationError as e:
raise base.DeserializationError.WithData(e.original_exc, d['model'], d.get('pk'), e.pk)
m2m_data[field.name] = values
if values == base.DEFER_FIELD:
deferred_fields[field] = field_value
else:
m2m_data[field.name] = values
# Handle FK fields
elif field.remote_field and isinstance(field.remote_field, models.ManyToOneRel):
try:
value = base.deserialize_fk_value(field, field_value, using)
value = base.deserialize_fk_value(field, field_value, using, handle_forward_references)
except Exception as e:
raise base.DeserializationError.WithData(e, d['model'], d.get('pk'), field_value)
data[field.attname] = value
if value == base.DEFER_FIELD:
deferred_fields[field] = field_value
else:
data[field.attname] = value
# Handle all other fields
else:
try:
@ -137,7 +145,7 @@ def Deserializer(object_list, *, using=DEFAULT_DB_ALIAS, ignorenonexistent=False
raise base.DeserializationError.WithData(e, d['model'], d.get('pk'), field_value)
obj = base.build_instance(Model, data, using)
yield base.DeserializedObject(obj, m2m_data)
yield base.DeserializedObject(obj, m2m_data, deferred_fields)
def _get_model(model_identifier):

View File

@ -8,6 +8,7 @@ from xml.sax.expatreader import ExpatParser as _ExpatParser
from django.apps import apps
from django.conf import settings
from django.core.exceptions import ObjectDoesNotExist
from django.core.serializers import base
from django.db import DEFAULT_DB_ALIAS, models
from django.utils.xmlutils import (
@ -151,6 +152,7 @@ class Deserializer(base.Deserializer):
def __init__(self, stream_or_string, *, using=DEFAULT_DB_ALIAS, ignorenonexistent=False, **options):
super().__init__(stream_or_string, **options)
self.handle_forward_references = options.pop('handle_forward_references', False)
self.event_stream = pulldom.parse(self.stream, self._make_parser())
self.db = using
self.ignore = ignorenonexistent
@ -181,6 +183,7 @@ class Deserializer(base.Deserializer):
# Also start building a dict of m2m data (this is saved as
# {m2m_accessor_attribute : [list_of_related_objects]})
m2m_data = {}
deferred_fields = {}
field_names = {f.name for f in Model._meta.get_fields()}
# Deserialize each field.
@ -200,9 +203,26 @@ class Deserializer(base.Deserializer):
# As is usually the case, relation fields get the special treatment.
if field.remote_field and isinstance(field.remote_field, models.ManyToManyRel):
m2m_data[field.name] = self._handle_m2m_field_node(field_node, field)
value = self._handle_m2m_field_node(field_node, field)
if value == base.DEFER_FIELD:
deferred_fields[field] = [
[
getInnerText(nat_node).strip()
for nat_node in obj_node.getElementsByTagName('natural')
]
for obj_node in field_node.getElementsByTagName('object')
]
else:
m2m_data[field.name] = value
elif field.remote_field and isinstance(field.remote_field, models.ManyToOneRel):
data[field.attname] = self._handle_fk_field_node(field_node, field)
value = self._handle_fk_field_node(field_node, field)
if value == base.DEFER_FIELD:
deferred_fields[field] = [
getInnerText(k).strip()
for k in field_node.getElementsByTagName('natural')
]
else:
data[field.attname] = value
else:
if field_node.getElementsByTagName('None'):
value = None
@ -213,7 +233,7 @@ class Deserializer(base.Deserializer):
obj = base.build_instance(Model, data, self.db)
# Return a DeserializedObject so that the m2m data has a place to live.
return base.DeserializedObject(obj, m2m_data)
return base.DeserializedObject(obj, m2m_data, deferred_fields)
def _handle_fk_field_node(self, node, field):
"""
@ -229,7 +249,13 @@ class Deserializer(base.Deserializer):
if keys:
# If there are 'natural' subelements, it must be a natural key
field_value = [getInnerText(k).strip() for k in keys]
obj = model._default_manager.db_manager(self.db).get_by_natural_key(*field_value)
try:
obj = model._default_manager.db_manager(self.db).get_by_natural_key(*field_value)
except ObjectDoesNotExist:
if self.handle_forward_references:
return base.DEFER_FIELD
else:
raise
obj_pk = getattr(obj, field.remote_field.field_name)
# If this is a natural foreign key to an object that
# has a FK/O2O as the foreign key, use the FK value
@ -264,7 +290,17 @@ class Deserializer(base.Deserializer):
else:
def m2m_convert(n):
return model._meta.pk.to_python(n.getAttribute('pk'))
return [m2m_convert(c) for c in node.getElementsByTagName("object")]
values = []
try:
for c in node.getElementsByTagName('object'):
values.append(m2m_convert(c))
except Exception as e:
if isinstance(e, ObjectDoesNotExist) and self.handle_forward_references:
return base.DEFER_FIELD
else:
raise base.M2MDeserializationError(e, c)
else:
return values
def _get_model_from_node(self, node, attr):
"""

View File

@ -184,7 +184,10 @@ Requests and Responses
Serialization
~~~~~~~~~~~~~
* ...
* You can now deserialize data using natural keys containing :ref:`forward
references <natural-keys-and-forward-references>` by passing
``handle_forward_references=True`` to ``serializers.deserialize()``.
Additionally, :djadmin:`loaddata` handles forward references automatically.
Signals
~~~~~~~

View File

@ -514,17 +514,68 @@ command line flags to generate natural keys.
natural keys during serialization, but *not* be able to load those
key values, just don't define the ``get_by_natural_key()`` method.
.. _natural-keys-and-forward-references:
Natural keys and forward references
-----------------------------------
.. versionadded:: 2.2
Sometimes when you use :ref:`natural foreign keys
<topics-serialization-natural-keys>` you'll need to deserialize data where
an object has a foreign key referencing another object that hasn't yet been
deserialized. This is called a "forward reference".
For instance, suppose you have the following objects in your fixture::
...
{
"model": "store.book",
"fields": {
"name": "Mostly Harmless",
"author": ["Douglas", "Adams"]
}
},
...
{
"model": "store.person",
"fields": {
"first_name": "Douglas",
"last_name": "Adams"
}
},
...
In order to handle this situation, you need to pass
``handle_forward_references=True`` to ``serializers.deserialize()``. This will
set the ``deferred_fields`` attribute on the ``DeserializedObject`` instances.
You'll need to keep track of ``DeserializedObject`` instances where this
attribute isn't ``None`` and later call ``save_deferred_fields()`` on them.
Typical usage looks like this::
objs_with_deferred_fields = []
for obj in serializers.deserialize('xml', data, handle_forward_references=True):
obj.save()
if obj.deferred_fields is not None:
objs_with_deferred_fields.append(obj)
for obj in objs_with_deferred_fields:
obj.save_deferred_fields()
For this to work, the ``ForeignKey`` on the referencing model must have
``null=True``.
Dependencies during serialization
---------------------------------
Since natural keys rely on database lookups to resolve references, it
is important that the data exists before it is referenced. You can't make
a "forward reference" with natural keys -- the data you're referencing
must exist before you include a natural key reference to that data.
It's often possible to avoid explicitly having to handle forward references by
taking care with the ordering of objects within a fixture.
To accommodate this limitation, calls to :djadmin:`dumpdata` that use
the :option:`dumpdata --natural-foreign` option will serialize any model with a
``natural_key()`` method before serializing standard primary key objects.
To help with this, calls to :djadmin:`dumpdata` that use the :option:`dumpdata
--natural-foreign` option will serialize any model with a ``natural_key()``
method before serializing standard primary key objects.
However, this may not always be enough. If your natural key refers to
another object (by using a foreign key or natural key to another object

View File

@ -0,0 +1,20 @@
[
{
"model": "fixtures.naturalkeything",
"fields": {
"key": "t1",
"other_thing": [
"t2"
]
}
},
{
"model": "fixtures.naturalkeything",
"fields": {
"key": "t2",
"other_thing": [
"t1"
]
}
}
]

View File

@ -0,0 +1,23 @@
[
{
"model": "fixtures.naturalkeything",
"fields": {
"key": "t1",
"other_things": [
["t2"], ["t3"]
]
}
},
{
"model": "fixtures.naturalkeything",
"fields": {
"key": "t2"
}
},
{
"model": "fixtures.naturalkeything",
"fields": {
"key": "t3"
}
}
]

View File

@ -116,3 +116,21 @@ class Book(models.Model):
class PrimaryKeyUUIDModel(models.Model):
id = models.UUIDField(primary_key=True, default=uuid.uuid4)
class NaturalKeyThing(models.Model):
key = models.CharField(max_length=100)
other_thing = models.ForeignKey('NaturalKeyThing', on_delete=models.CASCADE, null=True)
other_things = models.ManyToManyField('NaturalKeyThing', related_name='thing_m2m_set')
class Manager(models.Manager):
def get_by_natural_key(self, key):
return self.get(key=key)
objects = Manager()
def natural_key(self):
return (self.key,)
def __str__(self):
return self.key

View File

@ -17,7 +17,8 @@ from django.db import IntegrityError, connection
from django.test import TestCase, TransactionTestCase, skipUnlessDBFeature
from .models import (
Article, Category, PrimaryKeyUUIDModel, ProxySpy, Spy, Tag, Visa,
Article, Category, NaturalKeyThing, PrimaryKeyUUIDModel, ProxySpy, Spy,
Tag, Visa,
)
@ -780,3 +781,21 @@ class FixtureTransactionTests(DumpDataAssertMixin, TransactionTestCase):
'<Article: Time to reform copyright>',
'<Article: Poker has no place on ESPN>',
])
class ForwardReferenceTests(TestCase):
def test_forward_reference_fk(self):
management.call_command('loaddata', 'forward_reference_fk.json', verbosity=0)
self.assertEqual(NaturalKeyThing.objects.count(), 2)
t1, t2 = NaturalKeyThing.objects.all()
self.assertEqual(t1.other_thing, t2)
self.assertEqual(t2.other_thing, t1)
def test_forward_reference_m2m(self):
management.call_command('loaddata', 'forward_reference_m2m.json', verbosity=0)
self.assertEqual(NaturalKeyThing.objects.count(), 3)
t1 = NaturalKeyThing.objects.get_by_natural_key('t1')
self.assertQuerysetEqual(
t1.other_things.order_by('key'),
['<NaturalKeyThing: t2>', '<NaturalKeyThing: t3>']
)

View File

@ -19,3 +19,21 @@ class NaturalKeyAnchor(models.Model):
class FKDataNaturalKey(models.Model):
data = models.ForeignKey(NaturalKeyAnchor, models.SET_NULL, null=True)
class NaturalKeyThing(models.Model):
key = models.CharField(max_length=100)
other_thing = models.ForeignKey('NaturalKeyThing', on_delete=models.CASCADE, null=True)
other_things = models.ManyToManyField('NaturalKeyThing', related_name='thing_m2m_set')
class Manager(models.Manager):
def get_by_natural_key(self, key):
return self.get(key=key)
objects = Manager()
def natural_key(self):
return (self.key,)
def __str__(self):
return self.key

View File

@ -2,7 +2,7 @@ from django.core import serializers
from django.db import connection
from django.test import TestCase
from .models import Child, FKDataNaturalKey, NaturalKeyAnchor
from .models import Child, FKDataNaturalKey, NaturalKeyAnchor, NaturalKeyThing
from .tests import register_tests
@ -93,7 +93,100 @@ def natural_pk_mti_test(self, format):
self.assertEqual(child.child_data, child.parent_data)
def forward_ref_fk_test(self, format):
t1 = NaturalKeyThing.objects.create(key='t1')
t2 = NaturalKeyThing.objects.create(key='t2', other_thing=t1)
t1.other_thing = t2
t1.save()
string_data = serializers.serialize(
format, [t1, t2], use_natural_primary_keys=True,
use_natural_foreign_keys=True,
)
NaturalKeyThing.objects.all().delete()
objs_with_deferred_fields = []
for obj in serializers.deserialize(format, string_data, handle_forward_references=True):
obj.save()
if obj.deferred_fields:
objs_with_deferred_fields.append(obj)
for obj in objs_with_deferred_fields:
obj.save_deferred_fields()
t1 = NaturalKeyThing.objects.get(key='t1')
t2 = NaturalKeyThing.objects.get(key='t2')
self.assertEqual(t1.other_thing, t2)
self.assertEqual(t2.other_thing, t1)
def forward_ref_fk_with_error_test(self, format):
t1 = NaturalKeyThing.objects.create(key='t1')
t2 = NaturalKeyThing.objects.create(key='t2', other_thing=t1)
t1.other_thing = t2
t1.save()
string_data = serializers.serialize(
format, [t1], use_natural_primary_keys=True,
use_natural_foreign_keys=True,
)
NaturalKeyThing.objects.all().delete()
objs_with_deferred_fields = []
for obj in serializers.deserialize(format, string_data, handle_forward_references=True):
obj.save()
if obj.deferred_fields:
objs_with_deferred_fields.append(obj)
obj = objs_with_deferred_fields[0]
msg = 'NaturalKeyThing matching query does not exist'
with self.assertRaisesMessage(serializers.base.DeserializationError, msg):
obj.save_deferred_fields()
def forward_ref_m2m_test(self, format):
t1 = NaturalKeyThing.objects.create(key='t1')
t2 = NaturalKeyThing.objects.create(key='t2')
t3 = NaturalKeyThing.objects.create(key='t3')
t1.other_things.set([t2, t3])
string_data = serializers.serialize(
format, [t1, t2, t3], use_natural_primary_keys=True,
use_natural_foreign_keys=True,
)
NaturalKeyThing.objects.all().delete()
objs_with_deferred_fields = []
for obj in serializers.deserialize(format, string_data, handle_forward_references=True):
obj.save()
if obj.deferred_fields:
objs_with_deferred_fields.append(obj)
for obj in objs_with_deferred_fields:
obj.save_deferred_fields()
t1 = NaturalKeyThing.objects.get(key='t1')
t2 = NaturalKeyThing.objects.get(key='t2')
t3 = NaturalKeyThing.objects.get(key='t3')
self.assertCountEqual(t1.other_things.all(), [t2, t3])
def forward_ref_m2m_with_error_test(self, format):
t1 = NaturalKeyThing.objects.create(key='t1')
t2 = NaturalKeyThing.objects.create(key='t2')
t3 = NaturalKeyThing.objects.create(key='t3')
t1.other_things.set([t2, t3])
t1.save()
string_data = serializers.serialize(
format, [t1, t2], use_natural_primary_keys=True,
use_natural_foreign_keys=True,
)
NaturalKeyThing.objects.all().delete()
objs_with_deferred_fields = []
for obj in serializers.deserialize(format, string_data, handle_forward_references=True):
obj.save()
if obj.deferred_fields:
objs_with_deferred_fields.append(obj)
obj = objs_with_deferred_fields[0]
msg = 'NaturalKeyThing matching query does not exist'
with self.assertRaisesMessage(serializers.base.DeserializationError, msg):
obj.save_deferred_fields()
# Dynamically register tests for each serializer
register_tests(NaturalKeySerializerTests, 'test_%s_natural_key_serializer', natural_key_serializer_test)
register_tests(NaturalKeySerializerTests, 'test_%s_serializer_natural_keys', natural_key_test)
register_tests(NaturalKeySerializerTests, 'test_%s_serializer_natural_pks_mti', natural_pk_mti_test)
register_tests(NaturalKeySerializerTests, 'test_%s_forward_references_fks', forward_ref_fk_test)
register_tests(NaturalKeySerializerTests, 'test_%s_forward_references_fk_errors', forward_ref_fk_with_error_test)
register_tests(NaturalKeySerializerTests, 'test_%s_forward_references_m2ms', forward_ref_m2m_test)
register_tests(NaturalKeySerializerTests, 'test_%s_forward_references_m2m_errors', forward_ref_m2m_with_error_test)