Fixed #29908 -- Fixed setting of foreign key after related set access if ForeignKey uses to_field.

Adjusted known related objects handling of target fields which relies on
from and to_fields and has the side effect of fixing a bug bug causing
N+1 queries when using reverse foreign objects.

Thanks Carsten Fuchs for the report.
This commit is contained in:
Simon Charette 2018-11-01 01:10:29 -04:00 committed by Tim Graham
parent 413583e2e2
commit 75dfa92a05
5 changed files with 49 additions and 23 deletions

View File

@ -576,7 +576,16 @@ def create_reverse_many_to_one_manager(superclass, rel):
val = getattr(self.instance, field.attname) val = getattr(self.instance, field.attname)
if val is None or (val == '' and empty_strings_as_null): if val is None or (val == '' and empty_strings_as_null):
return queryset.none() return queryset.none()
queryset._known_related_objects = {self.field: {self.instance.pk: self.instance}} if self.field.many_to_one:
# Guard against field-like objects such as GenericRelation
# that abuse create_reverse_many_to_one_manager() with reverse
# one-to-many relationships instead and break known related
# objects assignment.
rel_obj_id = tuple([
getattr(self.instance, target_field.attname)
for target_field in self.field.get_path_info()[-1].target_fields
])
queryset._known_related_objects = {self.field: {rel_obj_id: self.instance}}
return queryset return queryset
def _remove_prefetched_objects(self): def _remove_prefetched_objects(self):

View File

@ -61,6 +61,16 @@ class ModelIterable(BaseIterable):
init_list = [f[0].target.attname init_list = [f[0].target.attname
for f in select[model_fields_start:model_fields_end]] for f in select[model_fields_start:model_fields_end]]
related_populators = get_related_populators(klass_info, select, db) related_populators = get_related_populators(klass_info, select, db)
known_related_objects = [
(field, related_objs, [
operator.attrgetter(
field.attname
if from_field == 'self' else
queryset.model._meta.get_field(from_field).attname
)
for from_field in field.from_fields
]) for field, related_objs in queryset._known_related_objects.items()
]
for row in compiler.results_iter(results): for row in compiler.results_iter(results):
obj = model_cls.from_db(db, init_list, row[model_fields_start:model_fields_end]) obj = model_cls.from_db(db, init_list, row[model_fields_start:model_fields_end])
for rel_populator in related_populators: for rel_populator in related_populators:
@ -69,17 +79,16 @@ class ModelIterable(BaseIterable):
for attr_name, col_pos in annotation_col_map.items(): for attr_name, col_pos in annotation_col_map.items():
setattr(obj, attr_name, row[col_pos]) setattr(obj, attr_name, row[col_pos])
# Add the known related objects to the model, if there are any # Add the known related objects to the model.
if queryset._known_related_objects: for field, rel_objs, rel_getters in known_related_objects:
for field, rel_objs in queryset._known_related_objects.items(): # Avoid overwriting objects loaded by, e.g., select_related().
# Avoid overwriting objects loaded e.g. by select_related
if field.is_cached(obj): if field.is_cached(obj):
continue continue
pk = getattr(obj, field.get_attname()) rel_obj_id = tuple([rel_getter(obj) for rel_getter in rel_getters])
try: try:
rel_obj = rel_objs[pk] rel_obj = rel_objs[rel_obj_id]
except KeyError: except KeyError:
pass # may happen in qs1 | qs2 scenarios pass # May happen in qs1 | qs2 scenarios.
else: else:
setattr(obj, field.name, rel_obj) setattr(obj, field.name, rel_obj)

View File

@ -69,12 +69,10 @@ class MultiColumnFKTests(TestCase):
membership_country_id=self.soviet_union.id, person_id=self.bob.id, membership_country_id=self.soviet_union.id, person_id=self.bob.id,
group_id=self.republican.id) group_id=self.republican.id)
self.assertQuerysetEqual( with self.assertNumQueries(1):
self.bob.membership_set.all(), [ membership = self.bob.membership_set.get()
self.cia.id self.assertEqual(membership.group_id, self.cia.id)
], self.assertIs(membership.person, self.bob)
attrgetter("group_id")
)
def test_query_filters_correctly(self): def test_query_filters_correctly(self):
@ -198,8 +196,11 @@ class MultiColumnFKTests(TestCase):
list(p.membership_set.all()) list(p.membership_set.all())
for p in Person.objects.prefetch_related('membership_set').order_by('pk')] for p in Person.objects.prefetch_related('membership_set').order_by('pk')]
normal_membership_sets = [list(p.membership_set.all()) with self.assertNumQueries(7):
for p in Person.objects.order_by('pk')] normal_membership_sets = [
list(p.membership_set.all())
for p in Person.objects.order_by('pk')
]
self.assertEqual(membership_sets, normal_membership_sets) self.assertEqual(membership_sets, normal_membership_sets)
def test_m2m_through_forward_returns_valid_members(self): def test_m2m_through_forward_returns_valid_members(self):

View File

@ -71,7 +71,7 @@ class Child(models.Model):
class ToFieldChild(models.Model): class ToFieldChild(models.Model):
parent = models.ForeignKey(Parent, models.CASCADE, to_field='name') parent = models.ForeignKey(Parent, models.CASCADE, to_field='name', related_name='to_field_children')
# Multiple paths to the same model (#7110, #7125) # Multiple paths to the same model (#7110, #7125)

View File

@ -672,3 +672,10 @@ class ManyToOneTests(TestCase):
child = ToFieldChild.objects.create(parent=parent) child = ToFieldChild.objects.create(parent=parent)
with self.assertNumQueries(0): with self.assertNumQueries(0):
self.assertIs(child.parent, parent) self.assertIs(child.parent, parent)
def test_reverse_foreign_key_instance_to_field_caching(self):
parent = Parent.objects.create(name='a')
ToFieldChild.objects.create(parent=parent)
child = parent.to_field_children.get()
with self.assertNumQueries(0):
self.assertIs(child.parent, parent)