diff --git a/django/db/models/fields/related.py b/django/db/models/fields/related.py index 7034d34867..fbb40d9175 100644 --- a/django/db/models/fields/related.py +++ b/django/db/models/fields/related.py @@ -249,11 +249,19 @@ class SingleRelatedObjectDescriptor(object): if instance is None: return self try: - return getattr(instance, self.cache_name) + rel_obj = getattr(instance, self.cache_name) except AttributeError: params = {'%s__pk' % self.related.field.name: instance._get_pk_val()} - rel_obj = self.get_query_set(instance=instance).get(**params) + try: + rel_obj = self.get_query_set(instance=instance).get(**params) + except self.related.model.DoesNotExist: + rel_obj = None + else: + setattr(rel_obj, self.related.field.get_cache_name(), instance) setattr(instance, self.cache_name, rel_obj) + if rel_obj is None: + raise self.related.model.DoesNotExist + else: return rel_obj def __set__(self, instance, value): @@ -331,24 +339,27 @@ class ReverseSingleRelatedObjectDescriptor(object): def __get__(self, instance, instance_type=None): if instance is None: return self - try: - return getattr(instance, self.cache_name) + rel_obj = getattr(instance, self.cache_name) except AttributeError: val = getattr(instance, self.field.attname) if val is None: - # If NULL is an allowed value, return it. - if self.field.null: - return None - raise self.field.rel.to.DoesNotExist - other_field = self.field.rel.get_related_field() - if other_field.rel: - params = {'%s__pk' % self.field.rel.field_name: val} + rel_obj = None else: - params = {'%s__exact' % self.field.rel.field_name: val} - qs = self.get_query_set(instance=instance) - rel_obj = qs.get(**params) + other_field = self.field.rel.get_related_field() + if other_field.rel: + params = {'%s__pk' % self.field.rel.field_name: val} + else: + params = {'%s__exact' % self.field.rel.field_name: val} + qs = self.get_query_set(instance=instance) + # Assuming the database enforces foreign keys, this won't fail. + rel_obj = qs.get(**params) + if not self.field.rel.multiple: + setattr(rel_obj, self.field.related.get_cache_name(), instance) setattr(instance, self.cache_name, rel_obj) + if rel_obj is None and not self.field.null: + raise self.field.rel.to.DoesNotExist + else: return rel_obj def __set__(self, instance, value): @@ -385,17 +396,13 @@ class ReverseSingleRelatedObjectDescriptor(object): # populated the cache, then we don't care - we're only accessing # the object to invalidate the accessor cache, so there's no # need to populate the cache just to expire it again. - related = getattr(instance, self.field.get_cache_name(), None) + related = getattr(instance, self.cache_name, None) # If we've got an old related object, we need to clear out its # cache. This cache also might not exist if the related object # hasn't been accessed yet. - if related: - cache_name = self.field.related.get_cache_name() - try: - delattr(related, cache_name) - except AttributeError: - pass + if related is not None: + setattr(related, self.field.related.get_cache_name(), None) # Set the value of the related field try: @@ -405,9 +412,11 @@ class ReverseSingleRelatedObjectDescriptor(object): setattr(instance, self.field.attname, val) # Since we already know what the related object is, seed the related - # object cache now, too. This avoids another db hit if you get the + # object caches now, too. This avoids another db hit if you get the # object you just set. - setattr(instance, self.field.get_cache_name(), value) + setattr(instance, self.cache_name, value) + if value is not None and not self.field.rel.multiple: + setattr(value, self.field.related.get_cache_name(), instance) class ForeignRelatedObjectsDescriptor(object): # This class provides the functionality that makes the related-object diff --git a/tests/regressiontests/one_to_one_regress/tests.py b/tests/regressiontests/one_to_one_regress/tests.py index 88980c2191..eced88598b 100644 --- a/tests/regressiontests/one_to_one_regress/tests.py +++ b/tests/regressiontests/one_to_one_regress/tests.py @@ -132,3 +132,73 @@ class OneToOneRegressionTests(TestCase): Target.objects.exclude(pointer2=None), [] ) + + def test_reverse_object_does_not_exist_cache(self): + """ + Regression for #13839 and #17439. + + DoesNotExist on a reverse one-to-one relation is cached. + """ + p = Place(name='Zombie Cats', address='Not sure') + p.save() + with self.assertNumQueries(1): + with self.assertRaises(Restaurant.DoesNotExist): + p.restaurant + with self.assertNumQueries(0): + with self.assertRaises(Restaurant.DoesNotExist): + p.restaurant + + def test_reverse_object_cached_when_related_is_accessed(self): + """ + Regression for #13839 and #17439. + + The target of a one-to-one relation is cached + when the origin is accessed through the reverse relation. + """ + # Use a fresh object without caches + r = Restaurant.objects.get(pk=self.r1.pk) + p = r.place + with self.assertNumQueries(0): + self.assertEqual(p.restaurant, r) + + def test_related_object_cached_when_reverse_is_accessed(self): + """ + Regression for #13839 and #17439. + + The origin of a one-to-one relation is cached + when the target is accessed through the reverse relation. + """ + # Use a fresh object without caches + p = Place.objects.get(pk=self.p1.pk) + r = p.restaurant + with self.assertNumQueries(0): + self.assertEqual(r.place, p) + + def test_reverse_object_cached_when_related_is_set(self): + """ + Regression for #13839 and #17439. + + The target of a one-to-one relation is always cached. + """ + p = Place(name='Zombie Cats', address='Not sure') + p.save() + self.r1.place = p + self.r1.save() + with self.assertNumQueries(0): + self.assertEqual(p.restaurant, self.r1) + + def test_reverse_object_cached_when_related_is_unset(self): + """ + Regression for #13839 and #17439. + + The target of a one-to-one relation is always cached. + """ + b = UndergroundBar(place=self.p1, serves_cocktails=True) + b.save() + with self.assertNumQueries(0): + self.assertEqual(self.p1.undergroundbar, b) + b.place = None + b.save() + with self.assertNumQueries(0): + with self.assertRaises(UndergroundBar.DoesNotExist): + self.p1.undergroundbar diff --git a/tests/regressiontests/select_related_onetoone/tests.py b/tests/regressiontests/select_related_onetoone/tests.py index d57ee90570..1373f04717 100644 --- a/tests/regressiontests/select_related_onetoone/tests.py +++ b/tests/regressiontests/select_related_onetoone/tests.py @@ -79,4 +79,32 @@ class ReverseSelectRelatedTestCase(TestCase): p1 = Product.objects.create(name="Django Plushie", image=im) p2 = Product.objects.create(name="Talking Django Plushie") - self.assertEqual(len(Product.objects.select_related("image")), 2) + with self.assertNumQueries(1): + result = sorted(Product.objects.select_related("image"), key=lambda x: x.name) + self.assertEqual([p.name for p in result], ["Django Plushie", "Talking Django Plushie"]) + + self.assertEqual(p1.image, im) + # Check for ticket #13839 + self.assertIsNone(p2.image) + + def test_missing_reverse(self): + """ + Ticket #13839: select_related() should NOT cache None + for missing objects on a reverse 1-1 relation. + """ + with self.assertNumQueries(1): + user = User.objects.select_related('userprofile').get(username='bob') + with self.assertRaises(UserProfile.DoesNotExist): + user.userprofile + + def test_nullable_missing_reverse(self): + """ + Ticket #13839: select_related() should NOT cache None + for missing objects on a reverse 0-1 relation. + """ + Image.objects.create(name="imag1") + + with self.assertNumQueries(1): + image = Image.objects.select_related('product').get() + with self.assertRaises(Product.DoesNotExist): + image.product