From 83f5478225588f31e7cbbfed63a4a2b936abc03f Mon Sep 17 00:00:00 2001 From: Simon Charette Date: Fri, 5 Apr 2024 23:08:49 -0400 Subject: [PATCH] Fixed #35356 -- Deferred self-referential foreign key fields adequately. While refs #34612 surfaced issues with reverse one-to-one fields deferrals, it missed that switching to storing remote fields would break self-referential relationships. This change switches to storing related objects in the select mask instead of remote fields to prevent collisions when dealing with self-referential relationships that might have a different directional mask. Despite fixing #21204 introduced a crash under some self-referential deferral conditions, it was simply not working even before that as it aggregated the sets of deferred fields by model. Thanks Joshua van Besouw for the report and Mariusz Felisiak for the review. --- django/db/models/sql/compiler.py | 8 ++++---- django/db/models/sql/query.py | 18 +++++------------- tests/defer_regress/models.py | 6 ++++++ tests/defer_regress/tests.py | 21 +++++++++++++++++++++ 4 files changed, 36 insertions(+), 17 deletions(-) diff --git a/django/db/models/sql/compiler.py b/django/db/models/sql/compiler.py index 676625df6fe..7541817ba09 100644 --- a/django/db/models/sql/compiler.py +++ b/django/db/models/sql/compiler.py @@ -1253,21 +1253,21 @@ class SQLCompiler: if restricted: related_fields = [ - (o.field, o.related_model) + (o, o.field, o.related_model) for o in opts.related_objects if o.field.unique and not o.many_to_many ] - for related_field, model in related_fields: - related_select_mask = select_mask.get(related_field) or {} + for related_object, related_field, model in related_fields: if not select_related_descend( related_field, restricted, requested, - related_select_mask, + select_mask, reverse=True, ): continue + related_select_mask = select_mask.get(related_object) or {} related_field_name = related_field.related_query_name() fields_found.add(related_field_name) diff --git a/django/db/models/sql/query.py b/django/db/models/sql/query.py index b3f130c0b44..c5c58b1788c 100644 --- a/django/db/models/sql/query.py +++ b/django/db/models/sql/query.py @@ -815,19 +815,17 @@ class Query(BaseExpression): if filtered_relation := self._filtered_relations.get(field_name): relation = opts.get_field(filtered_relation.relation_name) field_select_mask = select_mask.setdefault((field_name, relation), {}) - field = relation.field else: - reverse_rel = opts.get_field(field_name) + relation = opts.get_field(field_name) # While virtual fields such as many-to-many and generic foreign # keys cannot be effectively deferred we've historically # allowed them to be passed to QuerySet.defer(). Ignore such # field references until a layer of validation at mask # alteration time will be implemented eventually. - if not hasattr(reverse_rel, "field"): + if not hasattr(relation, "field"): continue - field = reverse_rel.field - field_select_mask = select_mask.setdefault(field, {}) - related_model = field.model._meta.concrete_model + field_select_mask = select_mask.setdefault(relation, {}) + related_model = relation.related_model._meta.concrete_model self._get_defer_select_mask( related_model._meta, field_mask, field_select_mask ) @@ -840,13 +838,7 @@ class Query(BaseExpression): # Only include fields mentioned in the mask. for field_name, field_mask in mask.items(): field = opts.get_field(field_name) - # Retrieve the actual field associated with reverse relationships - # as that's what is expected in the select mask. - if field in opts.related_objects: - field_key = field.field - else: - field_key = field - field_select_mask = select_mask.setdefault(field_key, {}) + field_select_mask = select_mask.setdefault(field, {}) if field_mask: if not field.is_relation: raise FieldError(next(iter(field_mask))) diff --git a/tests/defer_regress/models.py b/tests/defer_regress/models.py index dd492993b73..38ba4a622f6 100644 --- a/tests/defer_regress/models.py +++ b/tests/defer_regress/models.py @@ -10,6 +10,12 @@ class Item(models.Model): text = models.TextField(default="xyzzy") value = models.IntegerField() other_value = models.IntegerField(default=0) + source = models.OneToOneField( + "self", + related_name="destination", + on_delete=models.CASCADE, + null=True, + ) class RelatedItem(models.Model): diff --git a/tests/defer_regress/tests.py b/tests/defer_regress/tests.py index 10100e348db..1209325f219 100644 --- a/tests/defer_regress/tests.py +++ b/tests/defer_regress/tests.py @@ -309,6 +309,27 @@ class DeferRegressionTest(TestCase): with self.assertNumQueries(1): self.assertEqual(Item.objects.only("request").get(), item) + def test_self_referential_one_to_one(self): + first = Item.objects.create(name="first", value=1) + second = Item.objects.create(name="second", value=2, source=first) + with self.assertNumQueries(1): + deferred_first, deferred_second = ( + Item.objects.select_related("source", "destination") + .only("name", "source__name", "destination__value") + .order_by("pk") + ) + with self.assertNumQueries(0): + self.assertEqual(deferred_first.name, first.name) + self.assertEqual(deferred_second.name, second.name) + self.assertEqual(deferred_second.source.name, first.name) + self.assertEqual(deferred_first.destination.value, second.value) + with self.assertNumQueries(1): + self.assertEqual(deferred_first.value, first.value) + with self.assertNumQueries(1): + self.assertEqual(deferred_second.source.value, first.value) + with self.assertNumQueries(1): + self.assertEqual(deferred_first.destination.name, second.name) + class DeferDeletionSignalsTests(TestCase): senders = [Item, Proxy]