diff --git a/django/db/models/fields/related_descriptors.py b/django/db/models/fields/related_descriptors.py index 342cbb3465..e3c74cbcfa 100644 --- a/django/db/models/fields/related_descriptors.py +++ b/django/db/models/fields/related_descriptors.py @@ -576,7 +576,16 @@ def create_reverse_many_to_one_manager(superclass, rel): val = getattr(self.instance, field.attname) if val is None or (val == '' and empty_strings_as_null): return queryset.none() - queryset._known_related_objects = {self.field: {self.instance.pk: self.instance}} + if self.field.many_to_one: + # Guard against field-like objects such as GenericRelation + # that abuse create_reverse_many_to_one_manager() with reverse + # one-to-many relationships instead and break known related + # objects assignment. + rel_obj_id = tuple([ + getattr(self.instance, target_field.attname) + for target_field in self.field.get_path_info()[-1].target_fields + ]) + queryset._known_related_objects = {self.field: {rel_obj_id: self.instance}} return queryset def _remove_prefetched_objects(self): diff --git a/django/db/models/query.py b/django/db/models/query.py index 173dcef83a..1755bea1a8 100644 --- a/django/db/models/query.py +++ b/django/db/models/query.py @@ -61,6 +61,16 @@ class ModelIterable(BaseIterable): init_list = [f[0].target.attname for f in select[model_fields_start:model_fields_end]] related_populators = get_related_populators(klass_info, select, db) + known_related_objects = [ + (field, related_objs, [ + operator.attrgetter( + field.attname + if from_field == 'self' else + queryset.model._meta.get_field(from_field).attname + ) + for from_field in field.from_fields + ]) for field, related_objs in queryset._known_related_objects.items() + ] for row in compiler.results_iter(results): obj = model_cls.from_db(db, init_list, row[model_fields_start:model_fields_end]) for rel_populator in related_populators: @@ -69,19 +79,18 @@ class ModelIterable(BaseIterable): for attr_name, col_pos in annotation_col_map.items(): setattr(obj, attr_name, row[col_pos]) - # Add the known related objects to the model, if there are any - if queryset._known_related_objects: - for field, rel_objs in queryset._known_related_objects.items(): - # Avoid overwriting objects loaded e.g. by select_related - if field.is_cached(obj): - continue - pk = getattr(obj, field.get_attname()) - try: - rel_obj = rel_objs[pk] - except KeyError: - pass # may happen in qs1 | qs2 scenarios - else: - setattr(obj, field.name, rel_obj) + # Add the known related objects to the model. + for field, rel_objs, rel_getters in known_related_objects: + # Avoid overwriting objects loaded by, e.g., select_related(). + if field.is_cached(obj): + continue + rel_obj_id = tuple([rel_getter(obj) for rel_getter in rel_getters]) + try: + rel_obj = rel_objs[rel_obj_id] + except KeyError: + pass # May happen in qs1 | qs2 scenarios. + else: + setattr(obj, field.name, rel_obj) yield obj diff --git a/tests/foreign_object/tests.py b/tests/foreign_object/tests.py index 59d4357802..6af6def6d7 100644 --- a/tests/foreign_object/tests.py +++ b/tests/foreign_object/tests.py @@ -69,12 +69,10 @@ class MultiColumnFKTests(TestCase): membership_country_id=self.soviet_union.id, person_id=self.bob.id, group_id=self.republican.id) - self.assertQuerysetEqual( - self.bob.membership_set.all(), [ - self.cia.id - ], - attrgetter("group_id") - ) + with self.assertNumQueries(1): + membership = self.bob.membership_set.get() + self.assertEqual(membership.group_id, self.cia.id) + self.assertIs(membership.person, self.bob) def test_query_filters_correctly(self): @@ -198,8 +196,11 @@ class MultiColumnFKTests(TestCase): list(p.membership_set.all()) for p in Person.objects.prefetch_related('membership_set').order_by('pk')] - normal_membership_sets = [list(p.membership_set.all()) - for p in Person.objects.order_by('pk')] + with self.assertNumQueries(7): + normal_membership_sets = [ + list(p.membership_set.all()) + for p in Person.objects.order_by('pk') + ] self.assertEqual(membership_sets, normal_membership_sets) def test_m2m_through_forward_returns_valid_members(self): diff --git a/tests/many_to_one/models.py b/tests/many_to_one/models.py index 2f98bebdc0..ef784cfbe2 100644 --- a/tests/many_to_one/models.py +++ b/tests/many_to_one/models.py @@ -71,7 +71,7 @@ class Child(models.Model): class ToFieldChild(models.Model): - parent = models.ForeignKey(Parent, models.CASCADE, to_field='name') + parent = models.ForeignKey(Parent, models.CASCADE, to_field='name', related_name='to_field_children') # Multiple paths to the same model (#7110, #7125) diff --git a/tests/many_to_one/tests.py b/tests/many_to_one/tests.py index b04e6ad77a..28430256dc 100644 --- a/tests/many_to_one/tests.py +++ b/tests/many_to_one/tests.py @@ -672,3 +672,10 @@ class ManyToOneTests(TestCase): child = ToFieldChild.objects.create(parent=parent) with self.assertNumQueries(0): self.assertIs(child.parent, parent) + + def test_reverse_foreign_key_instance_to_field_caching(self): + parent = Parent.objects.create(name='a') + ToFieldChild.objects.create(parent=parent) + child = parent.to_field_children.get() + with self.assertNumQueries(0): + self.assertIs(child.parent, parent)