Made the caching of related and reverse related objects consistent in OneToOneFields. Fixed #13839. Refs #17439.

git-svn-id: http://code.djangoproject.com/svn/django/trunk@17890 bcc190cf-cafb-0310-a4f2-bffc1f526a37
This commit is contained in:
Aymeric Augustin 2012-04-10 12:29:25 +00:00
parent 1f11069aa5
commit b90d4e5b74
3 changed files with 131 additions and 24 deletions

View File

@ -249,11 +249,19 @@ class SingleRelatedObjectDescriptor(object):
if instance is None: if instance is None:
return self return self
try: try:
return getattr(instance, self.cache_name) rel_obj = getattr(instance, self.cache_name)
except AttributeError: except AttributeError:
params = {'%s__pk' % self.related.field.name: instance._get_pk_val()} params = {'%s__pk' % self.related.field.name: instance._get_pk_val()}
try:
rel_obj = self.get_query_set(instance=instance).get(**params) rel_obj = self.get_query_set(instance=instance).get(**params)
except self.related.model.DoesNotExist:
rel_obj = None
else:
setattr(rel_obj, self.related.field.get_cache_name(), instance)
setattr(instance, self.cache_name, rel_obj) setattr(instance, self.cache_name, rel_obj)
if rel_obj is None:
raise self.related.model.DoesNotExist
else:
return rel_obj return rel_obj
def __set__(self, instance, value): def __set__(self, instance, value):
@ -331,24 +339,27 @@ class ReverseSingleRelatedObjectDescriptor(object):
def __get__(self, instance, instance_type=None): def __get__(self, instance, instance_type=None):
if instance is None: if instance is None:
return self return self
try: try:
return getattr(instance, self.cache_name) rel_obj = getattr(instance, self.cache_name)
except AttributeError: except AttributeError:
val = getattr(instance, self.field.attname) val = getattr(instance, self.field.attname)
if val is None: if val is None:
# If NULL is an allowed value, return it. rel_obj = None
if self.field.null: else:
return None
raise self.field.rel.to.DoesNotExist
other_field = self.field.rel.get_related_field() other_field = self.field.rel.get_related_field()
if other_field.rel: if other_field.rel:
params = {'%s__pk' % self.field.rel.field_name: val} params = {'%s__pk' % self.field.rel.field_name: val}
else: else:
params = {'%s__exact' % self.field.rel.field_name: val} params = {'%s__exact' % self.field.rel.field_name: val}
qs = self.get_query_set(instance=instance) qs = self.get_query_set(instance=instance)
# Assuming the database enforces foreign keys, this won't fail.
rel_obj = qs.get(**params) rel_obj = qs.get(**params)
if not self.field.rel.multiple:
setattr(rel_obj, self.field.related.get_cache_name(), instance)
setattr(instance, self.cache_name, rel_obj) setattr(instance, self.cache_name, rel_obj)
if rel_obj is None and not self.field.null:
raise self.field.rel.to.DoesNotExist
else:
return rel_obj return rel_obj
def __set__(self, instance, value): def __set__(self, instance, value):
@ -385,17 +396,13 @@ class ReverseSingleRelatedObjectDescriptor(object):
# populated the cache, then we don't care - we're only accessing # populated the cache, then we don't care - we're only accessing
# the object to invalidate the accessor cache, so there's no # the object to invalidate the accessor cache, so there's no
# need to populate the cache just to expire it again. # need to populate the cache just to expire it again.
related = getattr(instance, self.field.get_cache_name(), None) related = getattr(instance, self.cache_name, None)
# If we've got an old related object, we need to clear out its # If we've got an old related object, we need to clear out its
# cache. This cache also might not exist if the related object # cache. This cache also might not exist if the related object
# hasn't been accessed yet. # hasn't been accessed yet.
if related: if related is not None:
cache_name = self.field.related.get_cache_name() setattr(related, self.field.related.get_cache_name(), None)
try:
delattr(related, cache_name)
except AttributeError:
pass
# Set the value of the related field # Set the value of the related field
try: try:
@ -405,9 +412,11 @@ class ReverseSingleRelatedObjectDescriptor(object):
setattr(instance, self.field.attname, val) setattr(instance, self.field.attname, val)
# Since we already know what the related object is, seed the related # Since we already know what the related object is, seed the related
# object cache now, too. This avoids another db hit if you get the # object caches now, too. This avoids another db hit if you get the
# object you just set. # object you just set.
setattr(instance, self.field.get_cache_name(), value) setattr(instance, self.cache_name, value)
if value is not None and not self.field.rel.multiple:
setattr(value, self.field.related.get_cache_name(), instance)
class ForeignRelatedObjectsDescriptor(object): class ForeignRelatedObjectsDescriptor(object):
# This class provides the functionality that makes the related-object # This class provides the functionality that makes the related-object

View File

@ -132,3 +132,73 @@ class OneToOneRegressionTests(TestCase):
Target.objects.exclude(pointer2=None), Target.objects.exclude(pointer2=None),
[] []
) )
def test_reverse_object_does_not_exist_cache(self):
"""
Regression for #13839 and #17439.
DoesNotExist on a reverse one-to-one relation is cached.
"""
p = Place(name='Zombie Cats', address='Not sure')
p.save()
with self.assertNumQueries(1):
with self.assertRaises(Restaurant.DoesNotExist):
p.restaurant
with self.assertNumQueries(0):
with self.assertRaises(Restaurant.DoesNotExist):
p.restaurant
def test_reverse_object_cached_when_related_is_accessed(self):
"""
Regression for #13839 and #17439.
The target of a one-to-one relation is cached
when the origin is accessed through the reverse relation.
"""
# Use a fresh object without caches
r = Restaurant.objects.get(pk=self.r1.pk)
p = r.place
with self.assertNumQueries(0):
self.assertEqual(p.restaurant, r)
def test_related_object_cached_when_reverse_is_accessed(self):
"""
Regression for #13839 and #17439.
The origin of a one-to-one relation is cached
when the target is accessed through the reverse relation.
"""
# Use a fresh object without caches
p = Place.objects.get(pk=self.p1.pk)
r = p.restaurant
with self.assertNumQueries(0):
self.assertEqual(r.place, p)
def test_reverse_object_cached_when_related_is_set(self):
"""
Regression for #13839 and #17439.
The target of a one-to-one relation is always cached.
"""
p = Place(name='Zombie Cats', address='Not sure')
p.save()
self.r1.place = p
self.r1.save()
with self.assertNumQueries(0):
self.assertEqual(p.restaurant, self.r1)
def test_reverse_object_cached_when_related_is_unset(self):
"""
Regression for #13839 and #17439.
The target of a one-to-one relation is always cached.
"""
b = UndergroundBar(place=self.p1, serves_cocktails=True)
b.save()
with self.assertNumQueries(0):
self.assertEqual(self.p1.undergroundbar, b)
b.place = None
b.save()
with self.assertNumQueries(0):
with self.assertRaises(UndergroundBar.DoesNotExist):
self.p1.undergroundbar

View File

@ -79,4 +79,32 @@ class ReverseSelectRelatedTestCase(TestCase):
p1 = Product.objects.create(name="Django Plushie", image=im) p1 = Product.objects.create(name="Django Plushie", image=im)
p2 = Product.objects.create(name="Talking Django Plushie") p2 = Product.objects.create(name="Talking Django Plushie")
self.assertEqual(len(Product.objects.select_related("image")), 2) with self.assertNumQueries(1):
result = sorted(Product.objects.select_related("image"), key=lambda x: x.name)
self.assertEqual([p.name for p in result], ["Django Plushie", "Talking Django Plushie"])
self.assertEqual(p1.image, im)
# Check for ticket #13839
self.assertIsNone(p2.image)
def test_missing_reverse(self):
"""
Ticket #13839: select_related() should NOT cache None
for missing objects on a reverse 1-1 relation.
"""
with self.assertNumQueries(1):
user = User.objects.select_related('userprofile').get(username='bob')
with self.assertRaises(UserProfile.DoesNotExist):
user.userprofile
def test_nullable_missing_reverse(self):
"""
Ticket #13839: select_related() should NOT cache None
for missing objects on a reverse 0-1 relation.
"""
Image.objects.create(name="imag1")
with self.assertNumQueries(1):
image = Image.objects.select_related('product').get()
with self.assertRaises(Product.DoesNotExist):
image.product