Switch several assertNumQueries to use the context manager, which is much more beautiful.

git-svn-id: http://code.djangoproject.com/svn/django/trunk@16986 bcc190cf-cafb-0310-a4f2-bffc1f526a37
This commit is contained in:
Alex Gaynor 2011-10-14 17:03:08 +00:00
parent 6c91521902
commit 69e1e6187a
6 changed files with 35 additions and 60 deletions

View File

@ -1,9 +1,10 @@
from __future__ import absolute_import from __future__ import with_statement, absolute_import
from django.test import TestCase from django.test import TestCase
from .models import Domain, Kingdom, Phylum, Klass, Order, Family, Genus, Species from .models import Domain, Kingdom, Phylum, Klass, Order, Family, Genus, Species
class SelectRelatedTests(TestCase): class SelectRelatedTests(TestCase):
def create_tree(self, stringtree): def create_tree(self, stringtree):
@ -41,29 +42,27 @@ class SelectRelatedTests(TestCase):
""" """
Normally, accessing FKs doesn't fill in related objects Normally, accessing FKs doesn't fill in related objects
""" """
def test(): with self.assertNumQueries(8):
fly = Species.objects.get(name="melanogaster") fly = Species.objects.get(name="melanogaster")
domain = fly.genus.family.order.klass.phylum.kingdom.domain domain = fly.genus.family.order.klass.phylum.kingdom.domain
self.assertEqual(domain.name, 'Eukaryota') self.assertEqual(domain.name, 'Eukaryota')
self.assertNumQueries(8, test)
def test_access_fks_with_select_related(self): def test_access_fks_with_select_related(self):
""" """
A select_related() call will fill in those related objects without any A select_related() call will fill in those related objects without any
extra queries extra queries
""" """
def test(): with self.assertNumQueries(1):
person = Species.objects.select_related(depth=10).get(name="sapiens") person = Species.objects.select_related(depth=10).get(name="sapiens")
domain = person.genus.family.order.klass.phylum.kingdom.domain domain = person.genus.family.order.klass.phylum.kingdom.domain
self.assertEqual(domain.name, 'Eukaryota') self.assertEqual(domain.name, 'Eukaryota')
self.assertNumQueries(1, test)
def test_list_without_select_related(self): def test_list_without_select_related(self):
""" """
select_related() also of course applies to entire lists, not just select_related() also of course applies to entire lists, not just
items. This test verifies the expected behavior without select_related. items. This test verifies the expected behavior without select_related.
""" """
def test(): with self.assertNumQueries(9):
world = Species.objects.all() world = Species.objects.all()
families = [o.genus.family.name for o in world] families = [o.genus.family.name for o in world]
self.assertEqual(sorted(families), [ self.assertEqual(sorted(families), [
@ -72,14 +71,13 @@ class SelectRelatedTests(TestCase):
'Fabaceae', 'Fabaceae',
'Hominidae', 'Hominidae',
]) ])
self.assertNumQueries(9, test)
def test_list_with_select_related(self): def test_list_with_select_related(self):
""" """
select_related() also of course applies to entire lists, not just select_related() also of course applies to entire lists, not just
items. This test verifies the expected behavior with select_related. items. This test verifies the expected behavior with select_related.
""" """
def test(): with self.assertNumQueries(1):
world = Species.objects.all().select_related() world = Species.objects.all().select_related()
families = [o.genus.family.name for o in world] families = [o.genus.family.name for o in world]
self.assertEqual(sorted(families), [ self.assertEqual(sorted(families), [
@ -88,21 +86,19 @@ class SelectRelatedTests(TestCase):
'Fabaceae', 'Fabaceae',
'Hominidae', 'Hominidae',
]) ])
self.assertNumQueries(1, test)
def test_depth(self, depth=1, expected=7): def test_depth(self, depth=1, expected=7):
""" """
The "depth" argument to select_related() will stop the descent at a The "depth" argument to select_related() will stop the descent at a
particular level. particular level.
""" """
def test(): # Notice: one fewer queries than above because of depth=1
with self.assertNumQueries(expected):
pea = Species.objects.select_related(depth=depth).get(name="sativum") pea = Species.objects.select_related(depth=depth).get(name="sativum")
self.assertEqual( self.assertEqual(
pea.genus.family.order.klass.phylum.kingdom.domain.name, pea.genus.family.order.klass.phylum.kingdom.domain.name,
'Eukaryota' 'Eukaryota'
) )
# Notice: one fewer queries than above because of depth=1
self.assertNumQueries(expected, test)
def test_larger_depth(self): def test_larger_depth(self):
""" """
@ -116,12 +112,11 @@ class SelectRelatedTests(TestCase):
The "depth" argument to select_related() will stop the descent at a The "depth" argument to select_related() will stop the descent at a
particular level. This can be used on lists as well. particular level. This can be used on lists as well.
""" """
def test(): with self.assertNumQueries(5):
world = Species.objects.all().select_related(depth=2) world = Species.objects.all().select_related(depth=2)
orders = [o.genus.family.order.name for o in world] orders = [o.genus.family.order.name for o in world]
self.assertEqual(sorted(orders), self.assertEqual(sorted(orders),
['Agaricales', 'Diptera', 'Fabales', 'Primates']) ['Agaricales', 'Diptera', 'Fabales', 'Primates'])
self.assertNumQueries(5, test)
def test_select_related_with_extra(self): def test_select_related_with_extra(self):
s = Species.objects.all().select_related(depth=1)\ s = Species.objects.all().select_related(depth=1)\
@ -137,31 +132,28 @@ class SelectRelatedTests(TestCase):
In this case, we explicitly say to select the 'genus' and In this case, we explicitly say to select the 'genus' and
'genus.family' models, leading to the same number of queries as before. 'genus.family' models, leading to the same number of queries as before.
""" """
def test(): with self.assertNumQueries(1):
world = Species.objects.select_related('genus__family') world = Species.objects.select_related('genus__family')
families = [o.genus.family.name for o in world] families = [o.genus.family.name for o in world]
self.assertEqual(sorted(families), self.assertEqual(sorted(families),
['Amanitacae', 'Drosophilidae', 'Fabaceae', 'Hominidae']) ['Amanitacae', 'Drosophilidae', 'Fabaceae', 'Hominidae'])
self.assertNumQueries(1, test)
def test_more_certain_fields(self): def test_more_certain_fields(self):
""" """
In this case, we explicitly say to select the 'genus' and In this case, we explicitly say to select the 'genus' and
'genus.family' models, leading to the same number of queries as before. 'genus.family' models, leading to the same number of queries as before.
""" """
def test(): with self.assertNumQueries(2):
world = Species.objects.filter(genus__name='Amanita')\ world = Species.objects.filter(genus__name='Amanita')\
.select_related('genus__family') .select_related('genus__family')
orders = [o.genus.family.order.name for o in world] orders = [o.genus.family.order.name for o in world]
self.assertEqual(orders, [u'Agaricales']) self.assertEqual(orders, [u'Agaricales'])
self.assertNumQueries(2, test)
def test_field_traversal(self): def test_field_traversal(self):
def test(): with self.assertNumQueries(1):
s = Species.objects.all().select_related('genus__family__order' s = Species.objects.all().select_related('genus__family__order'
).order_by('id')[0:1].get().genus.family.order.name ).order_by('id')[0:1].get().genus.family.order.name
self.assertEqual(s, u'Diptera') self.assertEqual(s, u'Diptera')
self.assertNumQueries(1, test)
def test_depth_fields_fails(self): def test_depth_fields_fails(self):
self.assertRaises(TypeError, self.assertRaises(TypeError,

View File

@ -58,26 +58,23 @@ class GetUniqueCheckTests(unittest.TestCase):
class PerformUniqueChecksTest(TestCase): class PerformUniqueChecksTest(TestCase):
def test_primary_key_unique_check_not_performed_when_adding_and_pk_not_specified(self): def test_primary_key_unique_check_not_performed_when_adding_and_pk_not_specified(self):
# Regression test for #12560 # Regression test for #12560
def test(): with self.assertNumQueries(0):
mtv = ModelToValidate(number=10, name='Some Name') mtv = ModelToValidate(number=10, name='Some Name')
setattr(mtv, '_adding', True) setattr(mtv, '_adding', True)
mtv.full_clean() mtv.full_clean()
self.assertNumQueries(0, test)
def test_primary_key_unique_check_performed_when_adding_and_pk_specified(self): def test_primary_key_unique_check_performed_when_adding_and_pk_specified(self):
# Regression test for #12560 # Regression test for #12560
def test(): with self.assertNumQueries(1):
mtv = ModelToValidate(number=10, name='Some Name', id=123) mtv = ModelToValidate(number=10, name='Some Name', id=123)
setattr(mtv, '_adding', True) setattr(mtv, '_adding', True)
mtv.full_clean() mtv.full_clean()
self.assertNumQueries(1, test)
def test_primary_key_unique_check_not_performed_when_not_adding(self): def test_primary_key_unique_check_not_performed_when_not_adding(self):
# Regression test for #12132 # Regression test for #12132
def test(): with self.assertNumQueries(0):
mtv = ModelToValidate(number=10, name='Some Name') mtv = ModelToValidate(number=10, name='Some Name')
mtv.full_clean() mtv.full_clean()
self.assertNumQueries(0, test)
def test_unique_for_date(self): def test_unique_for_date(self):
p1 = Post.objects.create(title="Django 1.0 is released", p1 = Post.objects.create(title="Django 1.0 is released",

View File

@ -45,9 +45,8 @@ class CommentTemplateTagTests(CommentTestCase):
self.testRenderCommentForm("{% render_comment_form for a %}") self.testRenderCommentForm("{% render_comment_form for a %}")
def testRenderCommentFormFromObjectWithQueryCount(self): def testRenderCommentFormFromObjectWithQueryCount(self):
def test(): with self.assertNumQueries(1):
self.testRenderCommentFormFromObject() self.testRenderCommentFormFromObject()
self.assertNumQueries(1, test)
def verifyGetCommentCount(self, tag=None): def verifyGetCommentCount(self, tag=None):
t = "{% load comments %}" + (tag or "{% get_comment_count for comment_tests.article a.id as cc %}") + "{{ cc }}" t = "{% load comments %}" + (tag or "{% get_comment_count for comment_tests.article a.id as cc %}") + "{{ cc }}"

View File

@ -1,4 +1,4 @@
from __future__ import absolute_import from __future__ import with_statement, absolute_import
from operator import attrgetter from operator import attrgetter
@ -21,22 +21,18 @@ class DeferRegressionTest(TestCase):
obj = Item.objects.only("name", "other_value").get(name="first") obj = Item.objects.only("name", "other_value").get(name="first")
# Accessing "name" doesn't trigger a new database query. Accessing # Accessing "name" doesn't trigger a new database query. Accessing
# "value" or "text" should. # "value" or "text" should.
def test(): with self.assertNumQueries(0):
self.assertEqual(obj.name, "first") self.assertEqual(obj.name, "first")
self.assertEqual(obj.other_value, 0) self.assertEqual(obj.other_value, 0)
self.assertNumQueries(0, test)
def test(): with self.assertNumQueries(1):
self.assertEqual(obj.value, 42) self.assertEqual(obj.value, 42)
self.assertNumQueries(1, test)
def test(): with self.assertNumQueries(1):
self.assertEqual(obj.text, "xyzzy") self.assertEqual(obj.text, "xyzzy")
self.assertNumQueries(1, test)
def test(): with self.assertNumQueries(0):
self.assertEqual(obj.text, "xyzzy") self.assertEqual(obj.text, "xyzzy")
self.assertNumQueries(0, test)
# Regression test for #10695. Make sure different instances don't # Regression test for #10695. Make sure different instances don't
# inadvertently share data in the deferred descriptor objects. # inadvertently share data in the deferred descriptor objects.

View File

@ -1,5 +1,5 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
from __future__ import absolute_import from __future__ import with_statement, absolute_import
import datetime import datetime
@ -28,11 +28,10 @@ class TestTicket12510(TestCase):
self.groups = [Group.objects.create(name=name) for name in 'abc'] self.groups = [Group.objects.create(name=name) for name in 'abc']
def test_choices_not_fetched_when_not_rendering(self): def test_choices_not_fetched_when_not_rendering(self):
def test(): # only one query is required to pull the model from DB
with self.assertNumQueries(1):
field = ModelChoiceField(Group.objects.order_by('-name')) field = ModelChoiceField(Group.objects.order_by('-name'))
self.assertEqual('a', field.clean(self.groups[0].pk).name) self.assertEqual('a', field.clean(self.groups[0].pk).name)
# only one query is required to pull the model from DB
self.assertNumQueries(1, test)
class ModelFormCallableModelDefault(TestCase): class ModelFormCallableModelDefault(TestCase):
def test_no_empty_option(self): def test_no_empty_option(self):

View File

@ -1,10 +1,11 @@
from __future__ import absolute_import from __future__ import with_statement, absolute_import
from django.test import TestCase from django.test import TestCase
from .models import (User, UserProfile, UserStat, UserStatResult, StatDetails, from .models import (User, UserProfile, UserStat, UserStatResult, StatDetails,
AdvancedUserStat, Image, Product) AdvancedUserStat, Image, Product)
class ReverseSelectRelatedTestCase(TestCase): class ReverseSelectRelatedTestCase(TestCase):
def setUp(self): def setUp(self):
user = User.objects.create(username="test") user = User.objects.create(username="test")
@ -22,65 +23,56 @@ class ReverseSelectRelatedTestCase(TestCase):
StatDetails.objects.create(base_stats=advstat, comments=250) StatDetails.objects.create(base_stats=advstat, comments=250)
def test_basic(self): def test_basic(self):
def test(): with self.assertNumQueries(1):
u = User.objects.select_related("userprofile").get(username="test") u = User.objects.select_related("userprofile").get(username="test")
self.assertEqual(u.userprofile.state, "KS") self.assertEqual(u.userprofile.state, "KS")
self.assertNumQueries(1, test)
def test_follow_next_level(self): def test_follow_next_level(self):
def test(): with self.assertNumQueries(1):
u = User.objects.select_related("userstat__results").get(username="test") u = User.objects.select_related("userstat__results").get(username="test")
self.assertEqual(u.userstat.posts, 150) self.assertEqual(u.userstat.posts, 150)
self.assertEqual(u.userstat.results.results, 'first results') self.assertEqual(u.userstat.results.results, 'first results')
self.assertNumQueries(1, test)
def test_follow_two(self): def test_follow_two(self):
def test(): with self.assertNumQueries(1):
u = User.objects.select_related("userprofile", "userstat").get(username="test") u = User.objects.select_related("userprofile", "userstat").get(username="test")
self.assertEqual(u.userprofile.state, "KS") self.assertEqual(u.userprofile.state, "KS")
self.assertEqual(u.userstat.posts, 150) self.assertEqual(u.userstat.posts, 150)
self.assertNumQueries(1, test)
def test_follow_two_next_level(self): def test_follow_two_next_level(self):
def test(): with self.assertNumQueries(1):
u = User.objects.select_related("userstat__results", "userstat__statdetails").get(username="test") u = User.objects.select_related("userstat__results", "userstat__statdetails").get(username="test")
self.assertEqual(u.userstat.results.results, 'first results') self.assertEqual(u.userstat.results.results, 'first results')
self.assertEqual(u.userstat.statdetails.comments, 259) self.assertEqual(u.userstat.statdetails.comments, 259)
self.assertNumQueries(1, test)
def test_forward_and_back(self): def test_forward_and_back(self):
def test(): with self.assertNumQueries(1):
stat = UserStat.objects.select_related("user__userprofile").get(user__username="test") stat = UserStat.objects.select_related("user__userprofile").get(user__username="test")
self.assertEqual(stat.user.userprofile.state, 'KS') self.assertEqual(stat.user.userprofile.state, 'KS')
self.assertEqual(stat.user.userstat.posts, 150) self.assertEqual(stat.user.userstat.posts, 150)
self.assertNumQueries(1, test)
def test_back_and_forward(self): def test_back_and_forward(self):
def test(): with self.assertNumQueries(1):
u = User.objects.select_related("userstat").get(username="test") u = User.objects.select_related("userstat").get(username="test")
self.assertEqual(u.userstat.user.username, 'test') self.assertEqual(u.userstat.user.username, 'test')
self.assertNumQueries(1, test)
def test_not_followed_by_default(self): def test_not_followed_by_default(self):
def test(): with self.assertNumQueries(2):
u = User.objects.select_related().get(username="test") u = User.objects.select_related().get(username="test")
self.assertEqual(u.userstat.posts, 150) self.assertEqual(u.userstat.posts, 150)
self.assertNumQueries(2, test)
def test_follow_from_child_class(self): def test_follow_from_child_class(self):
def test(): with self.assertNumQueries(1):
stat = AdvancedUserStat.objects.select_related('user', 'statdetails').get(posts=200) stat = AdvancedUserStat.objects.select_related('user', 'statdetails').get(posts=200)
self.assertEqual(stat.statdetails.comments, 250) self.assertEqual(stat.statdetails.comments, 250)
self.assertEqual(stat.user.username, 'bob') self.assertEqual(stat.user.username, 'bob')
self.assertNumQueries(1, test)
def test_follow_inheritance(self): def test_follow_inheritance(self):
def test(): with self.assertNumQueries(1):
stat = UserStat.objects.select_related('user', 'advanceduserstat').get(posts=200) stat = UserStat.objects.select_related('user', 'advanceduserstat').get(posts=200)
self.assertEqual(stat.advanceduserstat.posts, 200) self.assertEqual(stat.advanceduserstat.posts, 200)
self.assertEqual(stat.user.username, 'bob') self.assertEqual(stat.user.username, 'bob')
self.assertEqual(stat.advanceduserstat.user.username, 'bob') self.assertEqual(stat.advanceduserstat.user.username, 'bob')
self.assertNumQueries(1, test)
def test_nullable_relation(self): def test_nullable_relation(self):
im = Image.objects.create(name="imag1") im = Image.objects.create(name="imag1")