Fixed #5416 -- Added TestCase.assertNumQueries, which tests that a given function executes the correct number of queries.

git-svn-id: http://code.djangoproject.com/svn/django/trunk@14183 bcc190cf-cafb-0310-a4f2-bffc1f526a37
This commit is contained in:
Alex Gaynor 2010-10-12 03:33:19 +00:00
parent ceef628c19
commit 5506653b77
13 changed files with 253 additions and 181 deletions

View File

@ -21,6 +21,7 @@ class BaseDatabaseWrapper(local):
self.settings_dict = settings_dict self.settings_dict = settings_dict
self.alias = alias self.alias = alias
self.vendor = 'unknown' self.vendor = 'unknown'
self.use_debug_cursor = None
def __eq__(self, other): def __eq__(self, other):
return self.settings_dict == other.settings_dict return self.settings_dict == other.settings_dict
@ -74,7 +75,8 @@ class BaseDatabaseWrapper(local):
def cursor(self): def cursor(self):
from django.conf import settings from django.conf import settings
cursor = self._cursor() cursor = self._cursor()
if settings.DEBUG: if (self.use_debug_cursor or
(self.use_debug_cursor is None and settings.DEBUG)):
return self.make_debug_cursor(cursor) return self.make_debug_cursor(cursor)
return cursor return cursor

View File

@ -1,4 +1,5 @@
import re import re
import sys
from urlparse import urlsplit, urlunsplit from urlparse import urlsplit, urlunsplit
from xml.dom.minidom import parseString, Node from xml.dom.minidom import parseString, Node
@ -205,6 +206,33 @@ class DocTestRunner(doctest.DocTestRunner):
for conn in connections: for conn in connections:
transaction.rollback_unless_managed(using=conn) transaction.rollback_unless_managed(using=conn)
class _AssertNumQueriesContext(object):
def __init__(self, test_case, num, connection):
self.test_case = test_case
self.num = num
self.connection = connection
def __enter__(self):
self.old_debug_cursor = self.connection.use_debug_cursor
self.connection.use_debug_cursor = True
self.starting_queries = len(self.connection.queries)
return self
def __exit__(self, exc_type, exc_value, traceback):
if exc_type is not None:
return
self.connection.use_debug_cursor = self.old_debug_cursor
final_queries = len(self.connection.queries)
executed = final_queries - self.starting_queries
self.test_case.assertEqual(
executed, self.num, "%d queries executed, %d expected" % (
executed, self.num
)
)
class TransactionTestCase(unittest.TestCase): class TransactionTestCase(unittest.TestCase):
# The class we'll use for the test client self.client. # The class we'll use for the test client self.client.
# Can be overridden in derived classes. # Can be overridden in derived classes.
@ -469,6 +497,22 @@ class TransactionTestCase(unittest.TestCase):
def assertQuerysetEqual(self, qs, values, transform=repr): def assertQuerysetEqual(self, qs, values, transform=repr):
return self.assertEqual(map(transform, qs), values) return self.assertEqual(map(transform, qs), values)
def assertNumQueries(self, num, func=None, *args, **kwargs):
using = kwargs.pop("using", DEFAULT_DB_ALIAS)
connection = connections[using]
context = _AssertNumQueriesContext(self, num, connection)
if func is None:
return context
# Basically emulate the `with` statement here.
context.__enter__()
try:
func(*args, **kwargs)
finally:
context.__exit__(*sys.exc_info())
def connections_support_transactions(): def connections_support_transactions():
""" """
Returns True if all connections support transactions. This is messy Returns True if all connections support transactions. This is messy

View File

@ -1372,6 +1372,32 @@ cause of an failure in your test suite.
implicit ordering, you will need to apply a ``order_by()`` clause to your implicit ordering, you will need to apply a ``order_by()`` clause to your
queryset to ensure that the test will pass reliably. queryset to ensure that the test will pass reliably.
.. method:: TestCase.assertNumQueries(num, func, *args, **kwargs):
.. versionadded:: 1.3
Asserts that when ``func`` is called with ``*args`` and ``**kwargs`` that
``num`` database queries are executed.
If a ``"using"`` key is present in ``kwargs`` it is used as the database
alias for which to check the number of queries. If you wish to call a
function with a ``using`` parameter you can do it by wrapping the call with
a ``lambda`` to add an extra parameter::
self.assertNumQueries(7, lambda: my_function(using=7))
If you're using Python 2.5 or greater you can also use this as a context
manager::
# This is necessary in Python 2.5 to enable the with statement, in 2.6
# and up it is no longer necessary.
from __future__ import with_statement
with self.assertNumQueries(2):
Person.objects.create(name="Aaron")
Person.objects.create(name="Daniel")
.. _topics-testing-email: .. _topics-testing-email:
E-mail services E-mail services

View File

@ -1,6 +1,4 @@
from django.test import TestCase from django.test import TestCase
from django.conf import settings
from django import db
from models import Domain, Kingdom, Phylum, Klass, Order, Family, Genus, Species from models import Domain, Kingdom, Phylum, Klass, Order, Family, Genus, Species
@ -36,73 +34,73 @@ class SelectRelatedTests(TestCase):
# queries so we'll set it to True here and reset it at the end of the # queries so we'll set it to True here and reset it at the end of the
# test case. # test case.
self.create_base_data() self.create_base_data()
settings.DEBUG = True
db.reset_queries()
def tearDown(self):
settings.DEBUG = False
def test_access_fks_without_select_related(self): def test_access_fks_without_select_related(self):
""" """
Normally, accessing FKs doesn't fill in related objects Normally, accessing FKs doesn't fill in related objects
""" """
fly = Species.objects.get(name="melanogaster") def test():
domain = fly.genus.family.order.klass.phylum.kingdom.domain fly = Species.objects.get(name="melanogaster")
self.assertEqual(domain.name, 'Eukaryota') domain = fly.genus.family.order.klass.phylum.kingdom.domain
self.assertEqual(len(db.connection.queries), 8) self.assertEqual(domain.name, 'Eukaryota')
self.assertNumQueries(8, test)
def test_access_fks_with_select_related(self): def test_access_fks_with_select_related(self):
""" """
A select_related() call will fill in those related objects without any A select_related() call will fill in those related objects without any
extra queries extra queries
""" """
person = Species.objects.select_related(depth=10).get(name="sapiens") def test():
domain = person.genus.family.order.klass.phylum.kingdom.domain person = Species.objects.select_related(depth=10).get(name="sapiens")
self.assertEqual(domain.name, 'Eukaryota') domain = person.genus.family.order.klass.phylum.kingdom.domain
self.assertEqual(len(db.connection.queries), 1) self.assertEqual(domain.name, 'Eukaryota')
self.assertNumQueries(1, test)
def test_list_without_select_related(self): def test_list_without_select_related(self):
""" """
select_related() also of course applies to entire lists, not just select_related() also of course applies to entire lists, not just
items. This test verifies the expected behavior without select_related. items. This test verifies the expected behavior without select_related.
""" """
world = Species.objects.all() def test():
families = [o.genus.family.name for o in world] world = Species.objects.all()
self.assertEqual(families, [ families = [o.genus.family.name for o in world]
'Drosophilidae', self.assertEqual(families, [
'Hominidae', 'Drosophilidae',
'Fabaceae', 'Hominidae',
'Amanitacae', 'Fabaceae',
]) 'Amanitacae',
self.assertEqual(len(db.connection.queries), 9) ])
self.assertNumQueries(9, test)
def test_list_with_select_related(self): def test_list_with_select_related(self):
""" """
select_related() also of course applies to entire lists, not just select_related() also of course applies to entire lists, not just
items. This test verifies the expected behavior with select_related. items. This test verifies the expected behavior with select_related.
""" """
world = Species.objects.all().select_related() def test():
families = [o.genus.family.name for o in world] world = Species.objects.all().select_related()
self.assertEqual(families, [ families = [o.genus.family.name for o in world]
'Drosophilidae', self.assertEqual(families, [
'Hominidae', 'Drosophilidae',
'Fabaceae', 'Hominidae',
'Amanitacae', 'Fabaceae',
]) 'Amanitacae',
self.assertEqual(len(db.connection.queries), 1) ])
self.assertNumQueries(1, test)
def test_depth(self, depth=1, expected=7): def test_depth(self, depth=1, expected=7):
""" """
The "depth" argument to select_related() will stop the descent at a The "depth" argument to select_related() will stop the descent at a
particular level. particular level.
""" """
pea = Species.objects.select_related(depth=depth).get(name="sativum") def test():
self.assertEqual( pea = Species.objects.select_related(depth=depth).get(name="sativum")
pea.genus.family.order.klass.phylum.kingdom.domain.name, self.assertEqual(
'Eukaryota' pea.genus.family.order.klass.phylum.kingdom.domain.name,
) 'Eukaryota'
)
# Notice: one fewer queries than above because of depth=1 # Notice: one fewer queries than above because of depth=1
self.assertEqual(len(db.connection.queries), expected) self.assertNumQueries(expected, test)
def test_larger_depth(self): def test_larger_depth(self):
""" """
@ -116,11 +114,12 @@ class SelectRelatedTests(TestCase):
The "depth" argument to select_related() will stop the descent at a The "depth" argument to select_related() will stop the descent at a
particular level. This can be used on lists as well. particular level. This can be used on lists as well.
""" """
world = Species.objects.all().select_related(depth=2) def test():
orders = [o.genus.family.order.name for o in world] world = Species.objects.all().select_related(depth=2)
self.assertEqual(orders, orders = [o.genus.family.order.name for o in world]
['Diptera', 'Primates', 'Fabales', 'Agaricales']) self.assertEqual(orders,
self.assertEqual(len(db.connection.queries), 5) ['Diptera', 'Primates', 'Fabales', 'Agaricales'])
self.assertNumQueries(5, test)
def test_select_related_with_extra(self): def test_select_related_with_extra(self):
s = Species.objects.all().select_related(depth=1)\ s = Species.objects.all().select_related(depth=1)\
@ -136,28 +135,31 @@ class SelectRelatedTests(TestCase):
In this case, we explicitly say to select the 'genus' and In this case, we explicitly say to select the 'genus' and
'genus.family' models, leading to the same number of queries as before. 'genus.family' models, leading to the same number of queries as before.
""" """
world = Species.objects.select_related('genus__family') def test():
families = [o.genus.family.name for o in world] world = Species.objects.select_related('genus__family')
self.assertEqual(families, families = [o.genus.family.name for o in world]
['Drosophilidae', 'Hominidae', 'Fabaceae', 'Amanitacae']) self.assertEqual(families,
self.assertEqual(len(db.connection.queries), 1) ['Drosophilidae', 'Hominidae', 'Fabaceae', 'Amanitacae'])
self.assertNumQueries(1, test)
def test_more_certain_fields(self): def test_more_certain_fields(self):
""" """
In this case, we explicitly say to select the 'genus' and In this case, we explicitly say to select the 'genus' and
'genus.family' models, leading to the same number of queries as before. 'genus.family' models, leading to the same number of queries as before.
""" """
world = Species.objects.filter(genus__name='Amanita')\ def test():
.select_related('genus__family') world = Species.objects.filter(genus__name='Amanita')\
orders = [o.genus.family.order.name for o in world] .select_related('genus__family')
self.assertEqual(orders, [u'Agaricales']) orders = [o.genus.family.order.name for o in world]
self.assertEqual(len(db.connection.queries), 2) self.assertEqual(orders, [u'Agaricales'])
self.assertNumQueries(2, test)
def test_field_traversal(self): def test_field_traversal(self):
s = Species.objects.all().select_related('genus__family__order' def test():
).order_by('id')[0:1].get().genus.family.order.name s = Species.objects.all().select_related('genus__family__order'
self.assertEqual(s, u'Diptera') ).order_by('id')[0:1].get().genus.family.order.name
self.assertEqual(len(db.connection.queries), 1) self.assertEqual(s, u'Diptera')
self.assertNumQueries(1, test)
def test_depth_fields_fails(self): def test_depth_fields_fails(self):
self.assertRaises(TypeError, self.assertRaises(TypeError,

View File

@ -2,9 +2,11 @@ import datetime
from django.conf import settings from django.conf import settings
from django.db import connection from django.db import connection
from django.test import TestCase
from django.utils import unittest from django.utils import unittest
from models import CustomPKModel, UniqueTogetherModel, UniqueFieldsModel, UniqueForDateModel, ModelToValidate from models import (CustomPKModel, UniqueTogetherModel, UniqueFieldsModel,
UniqueForDateModel, ModelToValidate)
class GetUniqueCheckTests(unittest.TestCase): class GetUniqueCheckTests(unittest.TestCase):
@ -51,37 +53,26 @@ class GetUniqueCheckTests(unittest.TestCase):
), m._get_unique_checks(exclude='start_date') ), m._get_unique_checks(exclude='start_date')
) )
class PerformUniqueChecksTest(unittest.TestCase): class PerformUniqueChecksTest(TestCase):
def setUp(self):
# Set debug to True to gain access to connection.queries.
self._old_debug, settings.DEBUG = settings.DEBUG, True
super(PerformUniqueChecksTest, self).setUp()
def tearDown(self):
# Restore old debug value.
settings.DEBUG = self._old_debug
super(PerformUniqueChecksTest, self).tearDown()
def test_primary_key_unique_check_not_performed_when_adding_and_pk_not_specified(self): def test_primary_key_unique_check_not_performed_when_adding_and_pk_not_specified(self):
# Regression test for #12560 # Regression test for #12560
query_count = len(connection.queries) def test():
mtv = ModelToValidate(number=10, name='Some Name') mtv = ModelToValidate(number=10, name='Some Name')
setattr(mtv, '_adding', True) setattr(mtv, '_adding', True)
mtv.full_clean() mtv.full_clean()
self.assertEqual(query_count, len(connection.queries)) self.assertNumQueries(0, test)
def test_primary_key_unique_check_performed_when_adding_and_pk_specified(self): def test_primary_key_unique_check_performed_when_adding_and_pk_specified(self):
# Regression test for #12560 # Regression test for #12560
query_count = len(connection.queries) def test():
mtv = ModelToValidate(number=10, name='Some Name', id=123) mtv = ModelToValidate(number=10, name='Some Name', id=123)
setattr(mtv, '_adding', True) setattr(mtv, '_adding', True)
mtv.full_clean() mtv.full_clean()
self.assertEqual(query_count + 1, len(connection.queries)) self.assertNumQueries(1, test)
def test_primary_key_unique_check_not_performed_when_not_adding(self): def test_primary_key_unique_check_not_performed_when_not_adding(self):
# Regression test for #12132 # Regression test for #12132
query_count= len(connection.queries) def test():
mtv = ModelToValidate(number=10, name='Some Name') mtv = ModelToValidate(number=10, name='Some Name')
mtv.full_clean() mtv.full_clean()
self.assertEqual(query_count, len(connection.queries)) self.assertNumQueries(0, test)

View File

@ -6,7 +6,8 @@ from modeltests.validation.models import Author, Article, ModelToValidate
# Import other tests for this package. # Import other tests for this package.
from modeltests.validation.validators import TestModelsWithValidators from modeltests.validation.validators import TestModelsWithValidators
from modeltests.validation.test_unique import GetUniqueCheckTests, PerformUniqueChecksTest from modeltests.validation.test_unique import (GetUniqueCheckTests,
PerformUniqueChecksTest)
from modeltests.validation.test_custom_messages import CustomMessagesTest from modeltests.validation.test_custom_messages import CustomMessagesTest
@ -111,4 +112,3 @@ class ModelFormsTests(TestCase):
article = Article(author_id=self.author.id) article = Article(author_id=self.author.id)
form = ArticleForm(data, instance=article) form = ArticleForm(data, instance=article)
self.assertEqual(form.errors.keys(), ['pub_date']) self.assertEqual(form.errors.keys(), ['pub_date'])

View File

@ -11,17 +11,6 @@ from models import ResolveThis, Item, RelatedItem, Child, Leaf
class DeferRegressionTest(TestCase): class DeferRegressionTest(TestCase):
def assert_num_queries(self, n, func, *args, **kwargs):
old_DEBUG = settings.DEBUG
settings.DEBUG = True
starting_queries = len(connection.queries)
try:
func(*args, **kwargs)
finally:
settings.DEBUG = old_DEBUG
self.assertEqual(starting_queries + n, len(connection.queries))
def test_basic(self): def test_basic(self):
# Deferred fields should really be deferred and not accidentally use # Deferred fields should really be deferred and not accidentally use
# the field's default value just because they aren't passed to __init__ # the field's default value just because they aren't passed to __init__
@ -33,19 +22,19 @@ class DeferRegressionTest(TestCase):
def test(): def test():
self.assertEqual(obj.name, "first") self.assertEqual(obj.name, "first")
self.assertEqual(obj.other_value, 0) self.assertEqual(obj.other_value, 0)
self.assert_num_queries(0, test) self.assertNumQueries(0, test)
def test(): def test():
self.assertEqual(obj.value, 42) self.assertEqual(obj.value, 42)
self.assert_num_queries(1, test) self.assertNumQueries(1, test)
def test(): def test():
self.assertEqual(obj.text, "xyzzy") self.assertEqual(obj.text, "xyzzy")
self.assert_num_queries(1, test) self.assertNumQueries(1, test)
def test(): def test():
self.assertEqual(obj.text, "xyzzy") self.assertEqual(obj.text, "xyzzy")
self.assert_num_queries(0, test) self.assertNumQueries(0, test)
# Regression test for #10695. Make sure different instances don't # Regression test for #10695. Make sure different instances don't
# inadvertently share data in the deferred descriptor objects. # inadvertently share data in the deferred descriptor objects.

View File

@ -1,10 +1,9 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
import datetime import datetime
import tempfile
import shutil import shutil
import tempfile
from django.db import models, connection from django.db import models
from django.conf import settings
# Can't import as "forms" due to implementation details in the test suite (the # Can't import as "forms" due to implementation details in the test suite (the
# current file is called "forms" and is already imported). # current file is called "forms" and is already imported).
from django import forms as django_forms from django import forms as django_forms
@ -77,19 +76,13 @@ class TestTicket12510(TestCase):
''' It is not necessary to generate choices for ModelChoiceField (regression test for #12510). ''' ''' It is not necessary to generate choices for ModelChoiceField (regression test for #12510). '''
def setUp(self): def setUp(self):
self.groups = [Group.objects.create(name=name) for name in 'abc'] self.groups = [Group.objects.create(name=name) for name in 'abc']
self.old_debug = settings.DEBUG
# turn debug on to get access to connection.queries
settings.DEBUG = True
def tearDown(self):
settings.DEBUG = self.old_debug
def test_choices_not_fetched_when_not_rendering(self): def test_choices_not_fetched_when_not_rendering(self):
initial_queries = len(connection.queries) def test():
field = django_forms.ModelChoiceField(Group.objects.order_by('-name')) field = django_forms.ModelChoiceField(Group.objects.order_by('-name'))
self.assertEqual('a', field.clean(self.groups[0].pk).name) self.assertEqual('a', field.clean(self.groups[0].pk).name)
# only one query is required to pull the model from DB # only one query is required to pull the model from DB
self.assertEqual(initial_queries+1, len(connection.queries)) self.assertNumQueries(1, test)
class ModelFormCallableModelDefault(TestCase): class ModelFormCallableModelDefault(TestCase):
def test_no_empty_option(self): def test_no_empty_option(self):

View File

@ -1,10 +1,8 @@
import unittest import unittest
from datetime import date from datetime import date
from django import db
from django import forms from django import forms
from django.forms.models import modelform_factory, ModelChoiceField from django.forms.models import modelform_factory, ModelChoiceField
from django.conf import settings
from django.test import TestCase from django.test import TestCase
from django.core.exceptions import FieldError, ValidationError from django.core.exceptions import FieldError, ValidationError
from django.core.files.uploadedfile import SimpleUploadedFile from django.core.files.uploadedfile import SimpleUploadedFile
@ -14,14 +12,6 @@ from models import Person, RealPerson, Triple, FilePathModel, Article, \
class ModelMultipleChoiceFieldTests(TestCase): class ModelMultipleChoiceFieldTests(TestCase):
def setUp(self):
self.old_debug = settings.DEBUG
settings.DEBUG = True
def tearDown(self):
settings.DEBUG = self.old_debug
def test_model_multiple_choice_number_of_queries(self): def test_model_multiple_choice_number_of_queries(self):
""" """
Test that ModelMultipleChoiceField does O(1) queries instead of Test that ModelMultipleChoiceField does O(1) queries instead of
@ -30,10 +20,8 @@ class ModelMultipleChoiceFieldTests(TestCase):
for i in range(30): for i in range(30):
Person.objects.create(name="Person %s" % i) Person.objects.create(name="Person %s" % i)
db.reset_queries()
f = forms.ModelMultipleChoiceField(queryset=Person.objects.all()) f = forms.ModelMultipleChoiceField(queryset=Person.objects.all())
selected = f.clean([1, 3, 5, 7, 9]) self.assertNumQueries(1, f.clean, [1, 3, 5, 7, 9])
self.assertEquals(len(db.connection.queries), 1)
class TripleForm(forms.ModelForm): class TripleForm(forms.ModelForm):
class Meta: class Meta:

View File

@ -7,11 +7,6 @@ from models import (User, UserProfile, UserStat, UserStatResult, StatDetails,
class ReverseSelectRelatedTestCase(TestCase): class ReverseSelectRelatedTestCase(TestCase):
def setUp(self): def setUp(self):
# Explicitly enable debug for these tests - we need to count
# the queries that have been issued.
self.old_debug = settings.DEBUG
settings.DEBUG = True
user = User.objects.create(username="test") user = User.objects.create(username="test")
userprofile = UserProfile.objects.create(user=user, state="KS", userprofile = UserProfile.objects.create(user=user, state="KS",
city="Lawrence") city="Lawrence")
@ -26,65 +21,66 @@ class ReverseSelectRelatedTestCase(TestCase):
results=results2) results=results2)
StatDetails.objects.create(base_stats=advstat, comments=250) StatDetails.objects.create(base_stats=advstat, comments=250)
db.reset_queries()
def assertQueries(self, queries):
self.assertEqual(len(db.connection.queries), queries)
def tearDown(self):
settings.DEBUG = self.old_debug
def test_basic(self): def test_basic(self):
u = User.objects.select_related("userprofile").get(username="test") def test():
self.assertEqual(u.userprofile.state, "KS") u = User.objects.select_related("userprofile").get(username="test")
self.assertQueries(1) self.assertEqual(u.userprofile.state, "KS")
self.assertNumQueries(1, test)
def test_follow_next_level(self): def test_follow_next_level(self):
u = User.objects.select_related("userstat__results").get(username="test") def test():
self.assertEqual(u.userstat.posts, 150) u = User.objects.select_related("userstat__results").get(username="test")
self.assertEqual(u.userstat.results.results, 'first results') self.assertEqual(u.userstat.posts, 150)
self.assertQueries(1) self.assertEqual(u.userstat.results.results, 'first results')
self.assertNumQueries(1, test)
def test_follow_two(self): def test_follow_two(self):
u = User.objects.select_related("userprofile", "userstat").get(username="test") def test():
self.assertEqual(u.userprofile.state, "KS") u = User.objects.select_related("userprofile", "userstat").get(username="test")
self.assertEqual(u.userstat.posts, 150) self.assertEqual(u.userprofile.state, "KS")
self.assertQueries(1) self.assertEqual(u.userstat.posts, 150)
self.assertNumQueries(1, test)
def test_follow_two_next_level(self): def test_follow_two_next_level(self):
u = User.objects.select_related("userstat__results", "userstat__statdetails").get(username="test") def test():
self.assertEqual(u.userstat.results.results, 'first results') u = User.objects.select_related("userstat__results", "userstat__statdetails").get(username="test")
self.assertEqual(u.userstat.statdetails.comments, 259) self.assertEqual(u.userstat.results.results, 'first results')
self.assertQueries(1) self.assertEqual(u.userstat.statdetails.comments, 259)
self.assertNumQueries(1, test)
def test_forward_and_back(self): def test_forward_and_back(self):
stat = UserStat.objects.select_related("user__userprofile").get(user__username="test") def test():
self.assertEqual(stat.user.userprofile.state, 'KS') stat = UserStat.objects.select_related("user__userprofile").get(user__username="test")
self.assertEqual(stat.user.userstat.posts, 150) self.assertEqual(stat.user.userprofile.state, 'KS')
self.assertQueries(1) self.assertEqual(stat.user.userstat.posts, 150)
self.assertNumQueries(1, test)
def test_back_and_forward(self): def test_back_and_forward(self):
u = User.objects.select_related("userstat").get(username="test") def test():
self.assertEqual(u.userstat.user.username, 'test') u = User.objects.select_related("userstat").get(username="test")
self.assertQueries(1) self.assertEqual(u.userstat.user.username, 'test')
self.assertNumQueries(1, test)
def test_not_followed_by_default(self): def test_not_followed_by_default(self):
u = User.objects.select_related().get(username="test") def test():
self.assertEqual(u.userstat.posts, 150) u = User.objects.select_related().get(username="test")
self.assertQueries(2) self.assertEqual(u.userstat.posts, 150)
self.assertNumQueries(2, test)
def test_follow_from_child_class(self): def test_follow_from_child_class(self):
stat = AdvancedUserStat.objects.select_related('user', 'statdetails').get(posts=200) def test():
self.assertEqual(stat.statdetails.comments, 250) stat = AdvancedUserStat.objects.select_related('user', 'statdetails').get(posts=200)
self.assertEqual(stat.user.username, 'bob') self.assertEqual(stat.statdetails.comments, 250)
self.assertQueries(1) self.assertEqual(stat.user.username, 'bob')
self.assertNumQueries(1, test)
def test_follow_inheritance(self): def test_follow_inheritance(self):
stat = UserStat.objects.select_related('user', 'advanceduserstat').get(posts=200) def test():
self.assertEqual(stat.advanceduserstat.posts, 200) stat = UserStat.objects.select_related('user', 'advanceduserstat').get(posts=200)
self.assertEqual(stat.user.username, 'bob') self.assertEqual(stat.advanceduserstat.posts, 200)
self.assertEqual(stat.advanceduserstat.user.username, 'bob') self.assertEqual(stat.user.username, 'bob')
self.assertQueries(1) self.assertEqual(stat.advanceduserstat.user.username, 'bob')
self.assertNumQueries(1, test)
def test_nullable_relation(self): def test_nullable_relation(self):
im = Image.objects.create(name="imag1") im = Image.objects.create(name="imag1")

View File

@ -0,0 +1,5 @@
from django.db import models
class Person(models.Model):
name = models.CharField(max_length=100)

View File

@ -0,0 +1,30 @@
from __future__ import with_statement
from django.test import TestCase
from models import Person
class AssertNumQueriesTests(TestCase):
def test_simple(self):
with self.assertNumQueries(0):
pass
with self.assertNumQueries(1):
# Guy who wrote Linux
Person.objects.create(name="Linus Torvalds")
with self.assertNumQueries(2):
# Guy who owns the bagel place I like
Person.objects.create(name="Uncle Ricky")
self.assertEqual(Person.objects.count(), 2)
def test_failure(self):
with self.assertRaises(AssertionError) as exc_info:
with self.assertNumQueries(2):
Person.objects.count()
self.assertEqual(str(exc_info.exception), "1 != 2 : 1 queries executed, 2 expected")
with self.assertRaises(TypeError):
with self.assertNumQueries(4000):
raise TypeError

View File

@ -1,4 +1,10 @@
r""" import sys
if sys.version_info >= (2, 5):
from python_25 import AssertNumQueriesTests
__test__ = {"API_TEST": r"""
# Some checks of the doctest output normalizer. # Some checks of the doctest output normalizer.
# Standard doctests do fairly # Standard doctests do fairly
>>> from django.utils import simplejson >>> from django.utils import simplejson
@ -69,4 +75,4 @@ r"""
>>> produce_xml_fragment() >>> produce_xml_fragment()
'<foo bbb="2.0" aaa="1.0">Hello</foo><bar ddd="4.0" ccc="3.0"></bar>' '<foo bbb="2.0" aaa="1.0">Hello</foo><bar ddd="4.0" ccc="3.0"></bar>'
""" """}