From 312eb5cb11d09c0c41b2740e2e9aef838d60c8b5 Mon Sep 17 00:00:00 2001 From: Peter Inglesby Date: Fri, 13 Jul 2018 22:54:47 +0100 Subject: [PATCH] Fixed #26291 -- Allowed loaddata to handle forward references in natural_key fixtures. --- django/core/management/commands/loaddata.py | 6 ++ django/core/serializers/base.py | 42 +++++++- django/core/serializers/python.py | 18 +++- django/core/serializers/xml_serializer.py | 46 ++++++++- docs/releases/2.2.txt | 5 +- docs/topics/serialization.txt | 65 +++++++++++-- .../fixtures/forward_reference_fk.json | 20 ++++ .../fixtures/forward_reference_m2m.json | 23 +++++ tests/fixtures/models.py | 18 ++++ tests/fixtures/tests.py | 21 +++- tests/serializers/models/natural.py | 18 ++++ tests/serializers/test_natural.py | 95 ++++++++++++++++++- 12 files changed, 352 insertions(+), 25 deletions(-) create mode 100644 tests/fixtures/fixtures/forward_reference_fk.json create mode 100644 tests/fixtures/fixtures/forward_reference_m2m.json diff --git a/django/core/management/commands/loaddata.py b/django/core/management/commands/loaddata.py index 0b6a1e0898..aae156f911 100644 --- a/django/core/management/commands/loaddata.py +++ b/django/core/management/commands/loaddata.py @@ -109,8 +109,11 @@ class Command(BaseCommand): return with connection.constraint_checks_disabled(): + self.objs_with_deferred_fields = [] for fixture_label in fixture_labels: self.load_label(fixture_label) + for obj in self.objs_with_deferred_fields: + obj.save_deferred_fields(using=self.using) # Since we disabled constraint checks, we must manually check for # any invalid keys that might have been added @@ -163,6 +166,7 @@ class Command(BaseCommand): objects = serializers.deserialize( ser_fmt, fixture, using=self.using, ignorenonexistent=self.ignore, + handle_forward_references=True, ) for obj in objects: @@ -189,6 +193,8 @@ class Command(BaseCommand): 'error_msg': e, },) raise + if obj.deferred_fields: + self.objs_with_deferred_fields.append(obj) if objects and 
show_progress: self.stdout.write('') # add a newline after progress indicator self.loaded_object_count += loaded_objects_in_fixture diff --git a/django/core/serializers/base.py b/django/core/serializers/base.py index 4ff0c753f8..0d7946dc3f 100644 --- a/django/core/serializers/base.py +++ b/django/core/serializers/base.py @@ -3,8 +3,11 @@ Module for abstract serializer/unserializer base classes. """ from io import StringIO +from django.core.exceptions import ObjectDoesNotExist from django.db import models +DEFER_FIELD = object() + class SerializerDoesNotExist(KeyError): """The requested serializer was not found.""" @@ -201,9 +204,10 @@ class DeserializedObject: (and not touch the many-to-many stuff.) """ - def __init__(self, obj, m2m_data=None): + def __init__(self, obj, m2m_data=None, deferred_fields=None): self.object = obj self.m2m_data = m2m_data + self.deferred_fields = deferred_fields def __repr__(self): return "<%s: %s(pk=%s)>" % ( @@ -225,6 +229,25 @@ class DeserializedObject: # the m2m data twice. self.m2m_data = None + def save_deferred_fields(self, using=None): + self.m2m_data = {} + for field, field_value in self.deferred_fields.items(): + opts = self.object._meta + label = opts.app_label + '.' 
+ opts.model_name + if isinstance(field.remote_field, models.ManyToManyRel): + try: + values = deserialize_m2m_values(field, field_value, using, handle_forward_references=False) + except M2MDeserializationError as e: + raise DeserializationError.WithData(e.original_exc, label, self.object.pk, e.pk) + self.m2m_data[field.name] = values + elif isinstance(field.remote_field, models.ManyToOneRel): + try: + value = deserialize_fk_value(field, field_value, using, handle_forward_references=False) + except Exception as e: + raise DeserializationError.WithData(e, label, self.object.pk, field_value) + setattr(self.object, field.attname, value) + self.save() + def build_instance(Model, data, db): """ @@ -244,7 +267,7 @@ def build_instance(Model, data, db): return obj -def deserialize_m2m_values(field, field_value, using): +def deserialize_m2m_values(field, field_value, using, handle_forward_references): model = field.remote_field.model if hasattr(model._default_manager, 'get_by_natural_key'): def m2m_convert(value): @@ -262,10 +285,13 @@ def deserialize_m2m_values(field, field_value, using): values.append(m2m_convert(pk)) return values except Exception as e: - raise M2MDeserializationError(e, pk) + if isinstance(e, ObjectDoesNotExist) and handle_forward_references: + return DEFER_FIELD + else: + raise M2MDeserializationError(e, pk) -def deserialize_fk_value(field, field_value, using): +def deserialize_fk_value(field, field_value, using, handle_forward_references): if field_value is None: return None model = field.remote_field.model @@ -273,7 +299,13 @@ def deserialize_fk_value(field, field_value, using): field_name = field.remote_field.field_name if (hasattr(default_manager, 'get_by_natural_key') and hasattr(field_value, '__iter__') and not isinstance(field_value, str)): - obj = default_manager.db_manager(using).get_by_natural_key(*field_value) + try: + obj = default_manager.db_manager(using).get_by_natural_key(*field_value) + except ObjectDoesNotExist: + if 
handle_forward_references: + return DEFER_FIELD + else: + raise value = getattr(obj, field_name) # If this is a natural foreign key to an object that has a FK/O2O as # the foreign key, use the FK value. diff --git a/django/core/serializers/python.py b/django/core/serializers/python.py index 922e2c6a8f..08739c98fc 100644 --- a/django/core/serializers/python.py +++ b/django/core/serializers/python.py @@ -83,6 +83,7 @@ def Deserializer(object_list, *, using=DEFAULT_DB_ALIAS, ignorenonexistent=False It's expected that you pass the Python objects themselves (instead of a stream or a string) to the constructor """ + handle_forward_references = options.pop('handle_forward_references', False) field_names_cache = {} # Model: for d in object_list: @@ -101,6 +102,7 @@ def Deserializer(object_list, *, using=DEFAULT_DB_ALIAS, ignorenonexistent=False except Exception as e: raise base.DeserializationError.WithData(e, d['model'], d.get('pk'), None) m2m_data = {} + deferred_fields = {} if Model not in field_names_cache: field_names_cache[Model] = {f.name for f in Model._meta.get_fields()} @@ -118,17 +120,23 @@ def Deserializer(object_list, *, using=DEFAULT_DB_ALIAS, ignorenonexistent=False # Handle M2M relations if field.remote_field and isinstance(field.remote_field, models.ManyToManyRel): try: - values = base.deserialize_m2m_values(field, field_value, using) + values = base.deserialize_m2m_values(field, field_value, using, handle_forward_references) except base.M2MDeserializationError as e: raise base.DeserializationError.WithData(e.original_exc, d['model'], d.get('pk'), e.pk) - m2m_data[field.name] = values + if values == base.DEFER_FIELD: + deferred_fields[field] = field_value + else: + m2m_data[field.name] = values # Handle FK fields elif field.remote_field and isinstance(field.remote_field, models.ManyToOneRel): try: - value = base.deserialize_fk_value(field, field_value, using) + value = base.deserialize_fk_value(field, field_value, using, handle_forward_references) except 
Exception as e: raise base.DeserializationError.WithData(e, d['model'], d.get('pk'), field_value) - data[field.attname] = value + if value == base.DEFER_FIELD: + deferred_fields[field] = field_value + else: + data[field.attname] = value # Handle all other fields else: try: @@ -137,7 +145,7 @@ def Deserializer(object_list, *, using=DEFAULT_DB_ALIAS, ignorenonexistent=False raise base.DeserializationError.WithData(e, d['model'], d.get('pk'), field_value) obj = base.build_instance(Model, data, using) - yield base.DeserializedObject(obj, m2m_data) + yield base.DeserializedObject(obj, m2m_data, deferred_fields) def _get_model(model_identifier): diff --git a/django/core/serializers/xml_serializer.py b/django/core/serializers/xml_serializer.py index 076be1fe1f..47c6b88e93 100644 --- a/django/core/serializers/xml_serializer.py +++ b/django/core/serializers/xml_serializer.py @@ -8,6 +8,7 @@ from xml.sax.expatreader import ExpatParser as _ExpatParser from django.apps import apps from django.conf import settings +from django.core.exceptions import ObjectDoesNotExist from django.core.serializers import base from django.db import DEFAULT_DB_ALIAS, models from django.utils.xmlutils import ( @@ -151,6 +152,7 @@ class Deserializer(base.Deserializer): def __init__(self, stream_or_string, *, using=DEFAULT_DB_ALIAS, ignorenonexistent=False, **options): super().__init__(stream_or_string, **options) + self.handle_forward_references = options.pop('handle_forward_references', False) self.event_stream = pulldom.parse(self.stream, self._make_parser()) self.db = using self.ignore = ignorenonexistent @@ -181,6 +183,7 @@ class Deserializer(base.Deserializer): # Also start building a dict of m2m data (this is saved as # {m2m_accessor_attribute : [list_of_related_objects]}) m2m_data = {} + deferred_fields = {} field_names = {f.name for f in Model._meta.get_fields()} # Deserialize each field. 
@@ -200,9 +203,26 @@ class Deserializer(base.Deserializer): # As is usually the case, relation fields get the special treatment. if field.remote_field and isinstance(field.remote_field, models.ManyToManyRel): - m2m_data[field.name] = self._handle_m2m_field_node(field_node, field) + value = self._handle_m2m_field_node(field_node, field) + if value == base.DEFER_FIELD: + deferred_fields[field] = [ + [ + getInnerText(nat_node).strip() + for nat_node in obj_node.getElementsByTagName('natural') + ] + for obj_node in field_node.getElementsByTagName('object') + ] + else: + m2m_data[field.name] = value elif field.remote_field and isinstance(field.remote_field, models.ManyToOneRel): - data[field.attname] = self._handle_fk_field_node(field_node, field) + value = self._handle_fk_field_node(field_node, field) + if value == base.DEFER_FIELD: + deferred_fields[field] = [ + getInnerText(k).strip() + for k in field_node.getElementsByTagName('natural') + ] + else: + data[field.attname] = value else: if field_node.getElementsByTagName('None'): value = None @@ -213,7 +233,7 @@ class Deserializer(base.Deserializer): obj = base.build_instance(Model, data, self.db) # Return a DeserializedObject so that the m2m data has a place to live. 
- return base.DeserializedObject(obj, m2m_data) + return base.DeserializedObject(obj, m2m_data, deferred_fields) def _handle_fk_field_node(self, node, field): """ @@ -229,7 +249,13 @@ class Deserializer(base.Deserializer): if keys: # If there are 'natural' subelements, it must be a natural key field_value = [getInnerText(k).strip() for k in keys] - obj = model._default_manager.db_manager(self.db).get_by_natural_key(*field_value) + try: + obj = model._default_manager.db_manager(self.db).get_by_natural_key(*field_value) + except ObjectDoesNotExist: + if self.handle_forward_references: + return base.DEFER_FIELD + else: + raise obj_pk = getattr(obj, field.remote_field.field_name) # If this is a natural foreign key to an object that # has a FK/O2O as the foreign key, use the FK value @@ -264,7 +290,17 @@ class Deserializer(base.Deserializer): else: def m2m_convert(n): return model._meta.pk.to_python(n.getAttribute('pk')) - return [m2m_convert(c) for c in node.getElementsByTagName("object")] + values = [] + try: + for c in node.getElementsByTagName('object'): + values.append(m2m_convert(c)) + except Exception as e: + if isinstance(e, ObjectDoesNotExist) and self.handle_forward_references: + return base.DEFER_FIELD + else: + raise base.M2MDeserializationError(e, c) + else: + return values def _get_model_from_node(self, node, attr): """ diff --git a/docs/releases/2.2.txt b/docs/releases/2.2.txt index 3d1f2783ab..ce8bcd6fa4 100644 --- a/docs/releases/2.2.txt +++ b/docs/releases/2.2.txt @@ -184,7 +184,10 @@ Requests and Responses Serialization ~~~~~~~~~~~~~ -* ... +* You can now deserialize data using natural keys containing :ref:`forward + references <natural-keys-and-forward-references>` by passing + ``handle_forward_references=True`` to ``serializers.deserialize()``. + Additionally, :djadmin:`loaddata` handles forward references automatically. 
Signals ~~~~~~~ diff --git a/docs/topics/serialization.txt b/docs/topics/serialization.txt index 58a00bec20..433fb23718 100644 --- a/docs/topics/serialization.txt +++ b/docs/topics/serialization.txt @@ -514,17 +514,68 @@ command line flags to generate natural keys. natural keys during serialization, but *not* be able to load those key values, just don't define the ``get_by_natural_key()`` method. +.. _natural-keys-and-forward-references: + +Natural keys and forward references +----------------------------------- + +.. versionadded:: 2.2 + +Sometimes when you use :ref:`natural foreign keys +<topics-serialization-natural-keys>` you'll need to deserialize data where +an object has a foreign key referencing another object that hasn't yet been +deserialized. This is called a "forward reference". + +For instance, suppose you have the following objects in your fixture:: + + ... + { + "model": "store.book", + "fields": { + "name": "Mostly Harmless", + "author": ["Douglas", "Adams"] + } + }, + ... + { + "model": "store.person", + "fields": { + "first_name": "Douglas", + "last_name": "Adams" + } + }, + ... + +In order to handle this situation, you need to pass +``handle_forward_references=True`` to ``serializers.deserialize()``. This will +set the ``deferred_fields`` attribute on the ``DeserializedObject`` instances. +You'll need to keep track of ``DeserializedObject`` instances where this +attribute isn't ``None`` and later call ``save_deferred_fields()`` on them. + +Typical usage looks like this:: + + objs_with_deferred_fields = [] + + for obj in serializers.deserialize('xml', data, handle_forward_references=True): + obj.save() + if obj.deferred_fields is not None: + objs_with_deferred_fields.append(obj) + + for obj in objs_with_deferred_fields: + obj.save_deferred_fields() + +For this to work, the ``ForeignKey`` on the referencing model must have +``null=True``. 
+ Dependencies during serialization --------------------------------- -Since natural keys rely on database lookups to resolve references, it -is important that the data exists before it is referenced. You can't make -a "forward reference" with natural keys -- the data you're referencing -must exist before you include a natural key reference to that data. +It's often possible to avoid explicitly having to handle forward references by +taking care with the ordering of objects within a fixture. -To accommodate this limitation, calls to :djadmin:`dumpdata` that use -the :option:`dumpdata --natural-foreign` option will serialize any model with a -``natural_key()`` method before serializing standard primary key objects. +To help with this, calls to :djadmin:`dumpdata` that use the :option:`dumpdata +--natural-foreign` option will serialize any model with a ``natural_key()`` +method before serializing standard primary key objects. However, this may not always be enough. If your natural key refers to another object (by using a foreign key or natural key to another object diff --git a/tests/fixtures/fixtures/forward_reference_fk.json b/tests/fixtures/fixtures/forward_reference_fk.json new file mode 100644 index 0000000000..d3122c8823 --- /dev/null +++ b/tests/fixtures/fixtures/forward_reference_fk.json @@ -0,0 +1,20 @@ +[ + { + "model": "fixtures.naturalkeything", + "fields": { + "key": "t1", + "other_thing": [ + "t2" + ] + } + }, + { + "model": "fixtures.naturalkeything", + "fields": { + "key": "t2", + "other_thing": [ + "t1" + ] + } + } +] diff --git a/tests/fixtures/fixtures/forward_reference_m2m.json b/tests/fixtures/fixtures/forward_reference_m2m.json new file mode 100644 index 0000000000..9e0f18e294 --- /dev/null +++ b/tests/fixtures/fixtures/forward_reference_m2m.json @@ -0,0 +1,23 @@ +[ + { + "model": "fixtures.naturalkeything", + "fields": { + "key": "t1", + "other_things": [ + ["t2"], ["t3"] + ] + } + }, + { + "model": "fixtures.naturalkeything", + "fields": { + 
"key": "t2" + } + }, + { + "model": "fixtures.naturalkeything", + "fields": { + "key": "t3" + } + } +] diff --git a/tests/fixtures/models.py b/tests/fixtures/models.py index 07ef1587d9..be6674fa61 100644 --- a/tests/fixtures/models.py +++ b/tests/fixtures/models.py @@ -116,3 +116,21 @@ class Book(models.Model): class PrimaryKeyUUIDModel(models.Model): id = models.UUIDField(primary_key=True, default=uuid.uuid4) + + +class NaturalKeyThing(models.Model): + key = models.CharField(max_length=100) + other_thing = models.ForeignKey('NaturalKeyThing', on_delete=models.CASCADE, null=True) + other_things = models.ManyToManyField('NaturalKeyThing', related_name='thing_m2m_set') + + class Manager(models.Manager): + def get_by_natural_key(self, key): + return self.get(key=key) + + objects = Manager() + + def natural_key(self): + return (self.key,) + + def __str__(self): + return self.key diff --git a/tests/fixtures/tests.py b/tests/fixtures/tests.py index d48062f8dd..37b21fa296 100644 --- a/tests/fixtures/tests.py +++ b/tests/fixtures/tests.py @@ -17,7 +17,8 @@ from django.db import IntegrityError, connection from django.test import TestCase, TransactionTestCase, skipUnlessDBFeature from .models import ( - Article, Category, PrimaryKeyUUIDModel, ProxySpy, Spy, Tag, Visa, + Article, Category, NaturalKeyThing, PrimaryKeyUUIDModel, ProxySpy, Spy, + Tag, Visa, ) @@ -780,3 +781,21 @@ class FixtureTransactionTests(DumpDataAssertMixin, TransactionTestCase): '', '', ]) + + +class ForwardReferenceTests(TestCase): + def test_forward_reference_fk(self): + management.call_command('loaddata', 'forward_reference_fk.json', verbosity=0) + self.assertEqual(NaturalKeyThing.objects.count(), 2) + t1, t2 = NaturalKeyThing.objects.all() + self.assertEqual(t1.other_thing, t2) + self.assertEqual(t2.other_thing, t1) + + def test_forward_reference_m2m(self): + management.call_command('loaddata', 'forward_reference_m2m.json', verbosity=0) + self.assertEqual(NaturalKeyThing.objects.count(), 3) + t1 = 
NaturalKeyThing.objects.get_by_natural_key('t1') + self.assertQuerysetEqual( + t1.other_things.order_by('key'), + ['', ''] + ) diff --git a/tests/serializers/models/natural.py b/tests/serializers/models/natural.py index b50956692e..cd24b67f0a 100644 --- a/tests/serializers/models/natural.py +++ b/tests/serializers/models/natural.py @@ -19,3 +19,21 @@ class NaturalKeyAnchor(models.Model): class FKDataNaturalKey(models.Model): data = models.ForeignKey(NaturalKeyAnchor, models.SET_NULL, null=True) + + +class NaturalKeyThing(models.Model): + key = models.CharField(max_length=100) + other_thing = models.ForeignKey('NaturalKeyThing', on_delete=models.CASCADE, null=True) + other_things = models.ManyToManyField('NaturalKeyThing', related_name='thing_m2m_set') + + class Manager(models.Manager): + def get_by_natural_key(self, key): + return self.get(key=key) + + objects = Manager() + + def natural_key(self): + return (self.key,) + + def __str__(self): + return self.key diff --git a/tests/serializers/test_natural.py b/tests/serializers/test_natural.py index 776f554a1c..7e9e9382da 100644 --- a/tests/serializers/test_natural.py +++ b/tests/serializers/test_natural.py @@ -2,7 +2,7 @@ from django.core import serializers from django.db import connection from django.test import TestCase -from .models import Child, FKDataNaturalKey, NaturalKeyAnchor +from .models import Child, FKDataNaturalKey, NaturalKeyAnchor, NaturalKeyThing from .tests import register_tests @@ -93,7 +93,100 @@ def natural_pk_mti_test(self, format): self.assertEqual(child.child_data, child.parent_data) +def forward_ref_fk_test(self, format): + t1 = NaturalKeyThing.objects.create(key='t1') + t2 = NaturalKeyThing.objects.create(key='t2', other_thing=t1) + t1.other_thing = t2 + t1.save() + string_data = serializers.serialize( + format, [t1, t2], use_natural_primary_keys=True, + use_natural_foreign_keys=True, + ) + NaturalKeyThing.objects.all().delete() + objs_with_deferred_fields = [] + for obj in 
serializers.deserialize(format, string_data, handle_forward_references=True): + obj.save() + if obj.deferred_fields: + objs_with_deferred_fields.append(obj) + for obj in objs_with_deferred_fields: + obj.save_deferred_fields() + t1 = NaturalKeyThing.objects.get(key='t1') + t2 = NaturalKeyThing.objects.get(key='t2') + self.assertEqual(t1.other_thing, t2) + self.assertEqual(t2.other_thing, t1) + + +def forward_ref_fk_with_error_test(self, format): + t1 = NaturalKeyThing.objects.create(key='t1') + t2 = NaturalKeyThing.objects.create(key='t2', other_thing=t1) + t1.other_thing = t2 + t1.save() + string_data = serializers.serialize( + format, [t1], use_natural_primary_keys=True, + use_natural_foreign_keys=True, + ) + NaturalKeyThing.objects.all().delete() + objs_with_deferred_fields = [] + for obj in serializers.deserialize(format, string_data, handle_forward_references=True): + obj.save() + if obj.deferred_fields: + objs_with_deferred_fields.append(obj) + obj = objs_with_deferred_fields[0] + msg = 'NaturalKeyThing matching query does not exist' + with self.assertRaisesMessage(serializers.base.DeserializationError, msg): + obj.save_deferred_fields() + + +def forward_ref_m2m_test(self, format): + t1 = NaturalKeyThing.objects.create(key='t1') + t2 = NaturalKeyThing.objects.create(key='t2') + t3 = NaturalKeyThing.objects.create(key='t3') + t1.other_things.set([t2, t3]) + string_data = serializers.serialize( + format, [t1, t2, t3], use_natural_primary_keys=True, + use_natural_foreign_keys=True, + ) + NaturalKeyThing.objects.all().delete() + objs_with_deferred_fields = [] + for obj in serializers.deserialize(format, string_data, handle_forward_references=True): + obj.save() + if obj.deferred_fields: + objs_with_deferred_fields.append(obj) + for obj in objs_with_deferred_fields: + obj.save_deferred_fields() + t1 = NaturalKeyThing.objects.get(key='t1') + t2 = NaturalKeyThing.objects.get(key='t2') + t3 = NaturalKeyThing.objects.get(key='t3') + 
self.assertCountEqual(t1.other_things.all(), [t2, t3]) + + +def forward_ref_m2m_with_error_test(self, format): + t1 = NaturalKeyThing.objects.create(key='t1') + t2 = NaturalKeyThing.objects.create(key='t2') + t3 = NaturalKeyThing.objects.create(key='t3') + t1.other_things.set([t2, t3]) + t1.save() + string_data = serializers.serialize( + format, [t1, t2], use_natural_primary_keys=True, + use_natural_foreign_keys=True, + ) + NaturalKeyThing.objects.all().delete() + objs_with_deferred_fields = [] + for obj in serializers.deserialize(format, string_data, handle_forward_references=True): + obj.save() + if obj.deferred_fields: + objs_with_deferred_fields.append(obj) + obj = objs_with_deferred_fields[0] + msg = 'NaturalKeyThing matching query does not exist' + with self.assertRaisesMessage(serializers.base.DeserializationError, msg): + obj.save_deferred_fields() + + # Dynamically register tests for each serializer register_tests(NaturalKeySerializerTests, 'test_%s_natural_key_serializer', natural_key_serializer_test) register_tests(NaturalKeySerializerTests, 'test_%s_serializer_natural_keys', natural_key_test) register_tests(NaturalKeySerializerTests, 'test_%s_serializer_natural_pks_mti', natural_pk_mti_test) +register_tests(NaturalKeySerializerTests, 'test_%s_forward_references_fks', forward_ref_fk_test) +register_tests(NaturalKeySerializerTests, 'test_%s_forward_references_fk_errors', forward_ref_fk_with_error_test) +register_tests(NaturalKeySerializerTests, 'test_%s_forward_references_m2ms', forward_ref_m2m_test) +register_tests(NaturalKeySerializerTests, 'test_%s_forward_references_m2m_errors', forward_ref_m2m_with_error_test)