Fixed #5416 -- Added TestCase.assertNumQueries, which tests that a given function executes the correct number of queries.

git-svn-id: http://code.djangoproject.com/svn/django/trunk@14183 bcc190cf-cafb-0310-a4f2-bffc1f526a37
This commit is contained in:
Alex Gaynor 2010-10-12 03:33:19 +00:00
parent ceef628c19
commit 5506653b77
13 changed files with 253 additions and 181 deletions

View File

@ -21,6 +21,7 @@ class BaseDatabaseWrapper(local):
self.settings_dict = settings_dict
self.alias = alias
self.vendor = 'unknown'
self.use_debug_cursor = None
def __eq__(self, other):
return self.settings_dict == other.settings_dict
@ -74,7 +75,8 @@ class BaseDatabaseWrapper(local):
def cursor(self):
from django.conf import settings
cursor = self._cursor()
if settings.DEBUG:
if (self.use_debug_cursor or
(self.use_debug_cursor is None and settings.DEBUG)):
return self.make_debug_cursor(cursor)
return cursor

View File

@ -1,4 +1,5 @@
import re
import sys
from urlparse import urlsplit, urlunsplit
from xml.dom.minidom import parseString, Node
@ -205,6 +206,33 @@ class DocTestRunner(doctest.DocTestRunner):
for conn in connections:
transaction.rollback_unless_managed(using=conn)
class _AssertNumQueriesContext(object):
def __init__(self, test_case, num, connection):
self.test_case = test_case
self.num = num
self.connection = connection
def __enter__(self):
self.old_debug_cursor = self.connection.use_debug_cursor
self.connection.use_debug_cursor = True
self.starting_queries = len(self.connection.queries)
return self
def __exit__(self, exc_type, exc_value, traceback):
if exc_type is not None:
return
self.connection.use_debug_cursor = self.old_debug_cursor
final_queries = len(self.connection.queries)
executed = final_queries - self.starting_queries
self.test_case.assertEqual(
executed, self.num, "%d queries executed, %d expected" % (
executed, self.num
)
)
class TransactionTestCase(unittest.TestCase):
# The class we'll use for the test client self.client.
# Can be overridden in derived classes.
@ -469,6 +497,22 @@ class TransactionTestCase(unittest.TestCase):
def assertQuerysetEqual(self, qs, values, transform=repr):
return self.assertEqual(map(transform, qs), values)
def assertNumQueries(self, num, func=None, *args, **kwargs):
using = kwargs.pop("using", DEFAULT_DB_ALIAS)
connection = connections[using]
context = _AssertNumQueriesContext(self, num, connection)
if func is None:
return context
# Basically emulate the `with` statement here.
context.__enter__()
try:
func(*args, **kwargs)
finally:
context.__exit__(*sys.exc_info())
def connections_support_transactions():
"""
Returns True if all connections support transactions. This is messy

View File

@ -1372,6 +1372,32 @@ cause of an failure in your test suite.
implicit ordering, you will need to apply a ``order_by()`` clause to your
queryset to ensure that the test will pass reliably.
.. method:: TestCase.assertNumQueries(num, func, *args, **kwargs):
.. versionadded:: 1.3
Asserts that when ``func`` is called with ``*args`` and ``**kwargs`` that
``num`` database queries are executed.
If a ``"using"`` key is present in ``kwargs`` it is used as the database
alias for which to check the number of queries. If you wish to call a
function with a ``using`` parameter you can do it by wrapping the call with
a ``lambda`` to add an extra parameter::
self.assertNumQueries(7, lambda: my_function(using=7))
If you're using Python 2.5 or greater you can also use this as a context
manager::
# This is necessary in Python 2.5 to enable the with statement, in 2.6
# and up it is no longer necessary.
from __future__ import with_statement
with self.assertNumQueries(2):
Person.objects.create(name="Aaron")
Person.objects.create(name="Daniel")
.. _topics-testing-email:
E-mail services

View File

@ -1,6 +1,4 @@
from django.test import TestCase
from django.conf import settings
from django import db
from models import Domain, Kingdom, Phylum, Klass, Order, Family, Genus, Species
@ -36,73 +34,73 @@ class SelectRelatedTests(TestCase):
# queries so we'll set it to True here and reset it at the end of the
# test case.
self.create_base_data()
settings.DEBUG = True
db.reset_queries()
def tearDown(self):
settings.DEBUG = False
def test_access_fks_without_select_related(self):
"""
Normally, accessing FKs doesn't fill in related objects
"""
fly = Species.objects.get(name="melanogaster")
domain = fly.genus.family.order.klass.phylum.kingdom.domain
self.assertEqual(domain.name, 'Eukaryota')
self.assertEqual(len(db.connection.queries), 8)
def test():
fly = Species.objects.get(name="melanogaster")
domain = fly.genus.family.order.klass.phylum.kingdom.domain
self.assertEqual(domain.name, 'Eukaryota')
self.assertNumQueries(8, test)
def test_access_fks_with_select_related(self):
"""
A select_related() call will fill in those related objects without any
extra queries
"""
person = Species.objects.select_related(depth=10).get(name="sapiens")
domain = person.genus.family.order.klass.phylum.kingdom.domain
self.assertEqual(domain.name, 'Eukaryota')
self.assertEqual(len(db.connection.queries), 1)
def test():
person = Species.objects.select_related(depth=10).get(name="sapiens")
domain = person.genus.family.order.klass.phylum.kingdom.domain
self.assertEqual(domain.name, 'Eukaryota')
self.assertNumQueries(1, test)
def test_list_without_select_related(self):
"""
select_related() also of course applies to entire lists, not just
items. This test verifies the expected behavior without select_related.
"""
world = Species.objects.all()
families = [o.genus.family.name for o in world]
self.assertEqual(families, [
'Drosophilidae',
'Hominidae',
'Fabaceae',
'Amanitacae',
])
self.assertEqual(len(db.connection.queries), 9)
def test():
world = Species.objects.all()
families = [o.genus.family.name for o in world]
self.assertEqual(families, [
'Drosophilidae',
'Hominidae',
'Fabaceae',
'Amanitacae',
])
self.assertNumQueries(9, test)
def test_list_with_select_related(self):
"""
select_related() also of course applies to entire lists, not just
items. This test verifies the expected behavior with select_related.
"""
world = Species.objects.all().select_related()
families = [o.genus.family.name for o in world]
self.assertEqual(families, [
'Drosophilidae',
'Hominidae',
'Fabaceae',
'Amanitacae',
])
self.assertEqual(len(db.connection.queries), 1)
def test():
world = Species.objects.all().select_related()
families = [o.genus.family.name for o in world]
self.assertEqual(families, [
'Drosophilidae',
'Hominidae',
'Fabaceae',
'Amanitacae',
])
self.assertNumQueries(1, test)
def test_depth(self, depth=1, expected=7):
"""
The "depth" argument to select_related() will stop the descent at a
particular level.
"""
pea = Species.objects.select_related(depth=depth).get(name="sativum")
self.assertEqual(
pea.genus.family.order.klass.phylum.kingdom.domain.name,
'Eukaryota'
)
def test():
pea = Species.objects.select_related(depth=depth).get(name="sativum")
self.assertEqual(
pea.genus.family.order.klass.phylum.kingdom.domain.name,
'Eukaryota'
)
# Notice: one fewer queries than above because of depth=1
self.assertEqual(len(db.connection.queries), expected)
self.assertNumQueries(expected, test)
def test_larger_depth(self):
"""
@ -116,11 +114,12 @@ class SelectRelatedTests(TestCase):
The "depth" argument to select_related() will stop the descent at a
particular level. This can be used on lists as well.
"""
world = Species.objects.all().select_related(depth=2)
orders = [o.genus.family.order.name for o in world]
self.assertEqual(orders,
['Diptera', 'Primates', 'Fabales', 'Agaricales'])
self.assertEqual(len(db.connection.queries), 5)
def test():
world = Species.objects.all().select_related(depth=2)
orders = [o.genus.family.order.name for o in world]
self.assertEqual(orders,
['Diptera', 'Primates', 'Fabales', 'Agaricales'])
self.assertNumQueries(5, test)
def test_select_related_with_extra(self):
s = Species.objects.all().select_related(depth=1)\
@ -136,28 +135,31 @@ class SelectRelatedTests(TestCase):
In this case, we explicitly say to select the 'genus' and
'genus.family' models, leading to the same number of queries as before.
"""
world = Species.objects.select_related('genus__family')
families = [o.genus.family.name for o in world]
self.assertEqual(families,
['Drosophilidae', 'Hominidae', 'Fabaceae', 'Amanitacae'])
self.assertEqual(len(db.connection.queries), 1)
def test():
world = Species.objects.select_related('genus__family')
families = [o.genus.family.name for o in world]
self.assertEqual(families,
['Drosophilidae', 'Hominidae', 'Fabaceae', 'Amanitacae'])
self.assertNumQueries(1, test)
def test_more_certain_fields(self):
"""
In this case, we explicitly say to select the 'genus' and
'genus.family' models, leading to the same number of queries as before.
"""
world = Species.objects.filter(genus__name='Amanita')\
.select_related('genus__family')
orders = [o.genus.family.order.name for o in world]
self.assertEqual(orders, [u'Agaricales'])
self.assertEqual(len(db.connection.queries), 2)
def test():
world = Species.objects.filter(genus__name='Amanita')\
.select_related('genus__family')
orders = [o.genus.family.order.name for o in world]
self.assertEqual(orders, [u'Agaricales'])
self.assertNumQueries(2, test)
def test_field_traversal(self):
s = Species.objects.all().select_related('genus__family__order'
).order_by('id')[0:1].get().genus.family.order.name
self.assertEqual(s, u'Diptera')
self.assertEqual(len(db.connection.queries), 1)
def test():
s = Species.objects.all().select_related('genus__family__order'
).order_by('id')[0:1].get().genus.family.order.name
self.assertEqual(s, u'Diptera')
self.assertNumQueries(1, test)
def test_depth_fields_fails(self):
self.assertRaises(TypeError,

View File

@ -2,9 +2,11 @@ import datetime
from django.conf import settings
from django.db import connection
from django.test import TestCase
from django.utils import unittest
from models import CustomPKModel, UniqueTogetherModel, UniqueFieldsModel, UniqueForDateModel, ModelToValidate
from models import (CustomPKModel, UniqueTogetherModel, UniqueFieldsModel,
UniqueForDateModel, ModelToValidate)
class GetUniqueCheckTests(unittest.TestCase):
@ -51,37 +53,26 @@ class GetUniqueCheckTests(unittest.TestCase):
), m._get_unique_checks(exclude='start_date')
)
class PerformUniqueChecksTest(unittest.TestCase):
def setUp(self):
# Set debug to True to gain access to connection.queries.
self._old_debug, settings.DEBUG = settings.DEBUG, True
super(PerformUniqueChecksTest, self).setUp()
def tearDown(self):
# Restore old debug value.
settings.DEBUG = self._old_debug
super(PerformUniqueChecksTest, self).tearDown()
class PerformUniqueChecksTest(TestCase):
def test_primary_key_unique_check_not_performed_when_adding_and_pk_not_specified(self):
# Regression test for #12560
query_count = len(connection.queries)
mtv = ModelToValidate(number=10, name='Some Name')
setattr(mtv, '_adding', True)
mtv.full_clean()
self.assertEqual(query_count, len(connection.queries))
def test():
mtv = ModelToValidate(number=10, name='Some Name')
setattr(mtv, '_adding', True)
mtv.full_clean()
self.assertNumQueries(0, test)
def test_primary_key_unique_check_performed_when_adding_and_pk_specified(self):
# Regression test for #12560
query_count = len(connection.queries)
mtv = ModelToValidate(number=10, name='Some Name', id=123)
setattr(mtv, '_adding', True)
mtv.full_clean()
self.assertEqual(query_count + 1, len(connection.queries))
def test():
mtv = ModelToValidate(number=10, name='Some Name', id=123)
setattr(mtv, '_adding', True)
mtv.full_clean()
self.assertNumQueries(1, test)
def test_primary_key_unique_check_not_performed_when_not_adding(self):
# Regression test for #12132
query_count= len(connection.queries)
mtv = ModelToValidate(number=10, name='Some Name')
mtv.full_clean()
self.assertEqual(query_count, len(connection.queries))
def test():
mtv = ModelToValidate(number=10, name='Some Name')
mtv.full_clean()
self.assertNumQueries(0, test)

View File

@ -6,7 +6,8 @@ from modeltests.validation.models import Author, Article, ModelToValidate
# Import other tests for this package.
from modeltests.validation.validators import TestModelsWithValidators
from modeltests.validation.test_unique import GetUniqueCheckTests, PerformUniqueChecksTest
from modeltests.validation.test_unique import (GetUniqueCheckTests,
PerformUniqueChecksTest)
from modeltests.validation.test_custom_messages import CustomMessagesTest
@ -111,4 +112,3 @@ class ModelFormsTests(TestCase):
article = Article(author_id=self.author.id)
form = ArticleForm(data, instance=article)
self.assertEqual(form.errors.keys(), ['pub_date'])

View File

@ -11,17 +11,6 @@ from models import ResolveThis, Item, RelatedItem, Child, Leaf
class DeferRegressionTest(TestCase):
def assert_num_queries(self, n, func, *args, **kwargs):
old_DEBUG = settings.DEBUG
settings.DEBUG = True
starting_queries = len(connection.queries)
try:
func(*args, **kwargs)
finally:
settings.DEBUG = old_DEBUG
self.assertEqual(starting_queries + n, len(connection.queries))
def test_basic(self):
# Deferred fields should really be deferred and not accidentally use
# the field's default value just because they aren't passed to __init__
@ -33,19 +22,19 @@ class DeferRegressionTest(TestCase):
def test():
self.assertEqual(obj.name, "first")
self.assertEqual(obj.other_value, 0)
self.assert_num_queries(0, test)
self.assertNumQueries(0, test)
def test():
self.assertEqual(obj.value, 42)
self.assert_num_queries(1, test)
self.assertNumQueries(1, test)
def test():
self.assertEqual(obj.text, "xyzzy")
self.assert_num_queries(1, test)
self.assertNumQueries(1, test)
def test():
self.assertEqual(obj.text, "xyzzy")
self.assert_num_queries(0, test)
self.assertNumQueries(0, test)
# Regression test for #10695. Make sure different instances don't
# inadvertently share data in the deferred descriptor objects.

View File

@ -1,10 +1,9 @@
# -*- coding: utf-8 -*-
import datetime
import tempfile
import shutil
import tempfile
from django.db import models, connection
from django.conf import settings
from django.db import models
# Can't import as "forms" due to implementation details in the test suite (the
# current file is called "forms" and is already imported).
from django import forms as django_forms
@ -77,19 +76,13 @@ class TestTicket12510(TestCase):
''' It is not necessary to generate choices for ModelChoiceField (regression test for #12510). '''
def setUp(self):
self.groups = [Group.objects.create(name=name) for name in 'abc']
self.old_debug = settings.DEBUG
# turn debug on to get access to connection.queries
settings.DEBUG = True
def tearDown(self):
settings.DEBUG = self.old_debug
def test_choices_not_fetched_when_not_rendering(self):
initial_queries = len(connection.queries)
field = django_forms.ModelChoiceField(Group.objects.order_by('-name'))
self.assertEqual('a', field.clean(self.groups[0].pk).name)
def test():
field = django_forms.ModelChoiceField(Group.objects.order_by('-name'))
self.assertEqual('a', field.clean(self.groups[0].pk).name)
# only one query is required to pull the model from DB
self.assertEqual(initial_queries+1, len(connection.queries))
self.assertNumQueries(1, test)
class ModelFormCallableModelDefault(TestCase):
def test_no_empty_option(self):

View File

@ -1,10 +1,8 @@
import unittest
from datetime import date
from django import db
from django import forms
from django.forms.models import modelform_factory, ModelChoiceField
from django.conf import settings
from django.test import TestCase
from django.core.exceptions import FieldError, ValidationError
from django.core.files.uploadedfile import SimpleUploadedFile
@ -14,14 +12,6 @@ from models import Person, RealPerson, Triple, FilePathModel, Article, \
class ModelMultipleChoiceFieldTests(TestCase):
def setUp(self):
self.old_debug = settings.DEBUG
settings.DEBUG = True
def tearDown(self):
settings.DEBUG = self.old_debug
def test_model_multiple_choice_number_of_queries(self):
"""
Test that ModelMultipleChoiceField does O(1) queries instead of
@ -30,10 +20,8 @@ class ModelMultipleChoiceFieldTests(TestCase):
for i in range(30):
Person.objects.create(name="Person %s" % i)
db.reset_queries()
f = forms.ModelMultipleChoiceField(queryset=Person.objects.all())
selected = f.clean([1, 3, 5, 7, 9])
self.assertEquals(len(db.connection.queries), 1)
self.assertNumQueries(1, f.clean, [1, 3, 5, 7, 9])
class TripleForm(forms.ModelForm):
class Meta:
@ -312,7 +300,7 @@ class InvalidFieldAndFactory(TestCase):
model = Person
fields = ('name', 'no-field')
except FieldError, e:
# Make sure the exception contains some reference to the
# Make sure the exception contains some reference to the
# field responsible for the problem.
self.assertTrue('no-field' in e.args[0])
else:

View File

@ -7,11 +7,6 @@ from models import (User, UserProfile, UserStat, UserStatResult, StatDetails,
class ReverseSelectRelatedTestCase(TestCase):
def setUp(self):
# Explicitly enable debug for these tests - we need to count
# the queries that have been issued.
self.old_debug = settings.DEBUG
settings.DEBUG = True
user = User.objects.create(username="test")
userprofile = UserProfile.objects.create(user=user, state="KS",
city="Lawrence")
@ -26,65 +21,66 @@ class ReverseSelectRelatedTestCase(TestCase):
results=results2)
StatDetails.objects.create(base_stats=advstat, comments=250)
db.reset_queries()
def assertQueries(self, queries):
self.assertEqual(len(db.connection.queries), queries)
def tearDown(self):
settings.DEBUG = self.old_debug
def test_basic(self):
u = User.objects.select_related("userprofile").get(username="test")
self.assertEqual(u.userprofile.state, "KS")
self.assertQueries(1)
def test():
u = User.objects.select_related("userprofile").get(username="test")
self.assertEqual(u.userprofile.state, "KS")
self.assertNumQueries(1, test)
def test_follow_next_level(self):
u = User.objects.select_related("userstat__results").get(username="test")
self.assertEqual(u.userstat.posts, 150)
self.assertEqual(u.userstat.results.results, 'first results')
self.assertQueries(1)
def test():
u = User.objects.select_related("userstat__results").get(username="test")
self.assertEqual(u.userstat.posts, 150)
self.assertEqual(u.userstat.results.results, 'first results')
self.assertNumQueries(1, test)
def test_follow_two(self):
u = User.objects.select_related("userprofile", "userstat").get(username="test")
self.assertEqual(u.userprofile.state, "KS")
self.assertEqual(u.userstat.posts, 150)
self.assertQueries(1)
def test():
u = User.objects.select_related("userprofile", "userstat").get(username="test")
self.assertEqual(u.userprofile.state, "KS")
self.assertEqual(u.userstat.posts, 150)
self.assertNumQueries(1, test)
def test_follow_two_next_level(self):
u = User.objects.select_related("userstat__results", "userstat__statdetails").get(username="test")
self.assertEqual(u.userstat.results.results, 'first results')
self.assertEqual(u.userstat.statdetails.comments, 259)
self.assertQueries(1)
def test():
u = User.objects.select_related("userstat__results", "userstat__statdetails").get(username="test")
self.assertEqual(u.userstat.results.results, 'first results')
self.assertEqual(u.userstat.statdetails.comments, 259)
self.assertNumQueries(1, test)
def test_forward_and_back(self):
stat = UserStat.objects.select_related("user__userprofile").get(user__username="test")
self.assertEqual(stat.user.userprofile.state, 'KS')
self.assertEqual(stat.user.userstat.posts, 150)
self.assertQueries(1)
def test():
stat = UserStat.objects.select_related("user__userprofile").get(user__username="test")
self.assertEqual(stat.user.userprofile.state, 'KS')
self.assertEqual(stat.user.userstat.posts, 150)
self.assertNumQueries(1, test)
def test_back_and_forward(self):
u = User.objects.select_related("userstat").get(username="test")
self.assertEqual(u.userstat.user.username, 'test')
self.assertQueries(1)
def test():
u = User.objects.select_related("userstat").get(username="test")
self.assertEqual(u.userstat.user.username, 'test')
self.assertNumQueries(1, test)
def test_not_followed_by_default(self):
u = User.objects.select_related().get(username="test")
self.assertEqual(u.userstat.posts, 150)
self.assertQueries(2)
def test():
u = User.objects.select_related().get(username="test")
self.assertEqual(u.userstat.posts, 150)
self.assertNumQueries(2, test)
def test_follow_from_child_class(self):
stat = AdvancedUserStat.objects.select_related('user', 'statdetails').get(posts=200)
self.assertEqual(stat.statdetails.comments, 250)
self.assertEqual(stat.user.username, 'bob')
self.assertQueries(1)
def test():
stat = AdvancedUserStat.objects.select_related('user', 'statdetails').get(posts=200)
self.assertEqual(stat.statdetails.comments, 250)
self.assertEqual(stat.user.username, 'bob')
self.assertNumQueries(1, test)
def test_follow_inheritance(self):
stat = UserStat.objects.select_related('user', 'advanceduserstat').get(posts=200)
self.assertEqual(stat.advanceduserstat.posts, 200)
self.assertEqual(stat.user.username, 'bob')
self.assertEqual(stat.advanceduserstat.user.username, 'bob')
self.assertQueries(1)
def test():
stat = UserStat.objects.select_related('user', 'advanceduserstat').get(posts=200)
self.assertEqual(stat.advanceduserstat.posts, 200)
self.assertEqual(stat.user.username, 'bob')
self.assertEqual(stat.advanceduserstat.user.username, 'bob')
self.assertNumQueries(1, test)
def test_nullable_relation(self):
im = Image.objects.create(name="imag1")

View File

@ -0,0 +1,5 @@
from django.db import models
class Person(models.Model):
name = models.CharField(max_length=100)

View File

@ -0,0 +1,30 @@
from __future__ import with_statement
from django.test import TestCase
from models import Person
class AssertNumQueriesTests(TestCase):
def test_simple(self):
with self.assertNumQueries(0):
pass
with self.assertNumQueries(1):
# Guy who wrote Linux
Person.objects.create(name="Linus Torvalds")
with self.assertNumQueries(2):
# Guy who owns the bagel place I like
Person.objects.create(name="Uncle Ricky")
self.assertEqual(Person.objects.count(), 2)
def test_failure(self):
with self.assertRaises(AssertionError) as exc_info:
with self.assertNumQueries(2):
Person.objects.count()
self.assertEqual(str(exc_info.exception), "1 != 2 : 1 queries executed, 2 expected")
with self.assertRaises(TypeError):
with self.assertNumQueries(4000):
raise TypeError

View File

@ -1,6 +1,12 @@
r"""
import sys
if sys.version_info >= (2, 5):
from python_25 import AssertNumQueriesTests
__test__ = {"API_TEST": r"""
# Some checks of the doctest output normalizer.
# Standard doctests do fairly
# Standard doctests do fairly
>>> from django.utils import simplejson
>>> from django.utils.xmlutils import SimplerXMLGenerator
>>> from StringIO import StringIO
@ -55,7 +61,7 @@ r"""
>>> produce_json()
'["foo", {"whiz": 42, "bar": ["baz", null, 1.0, 2]}]'
# XML output is normalized for attribute order, so it doesn't matter
# XML output is normalized for attribute order, so it doesn't matter
# which order XML element attributes are listed in output
>>> produce_xml()
'<?xml version="1.0" encoding="UTF-8"?>\n<foo aaa="1.0" bbb="2.0"><bar ccc="3.0">Hello</bar><whiz>Goodbye</whiz></foo>'
@ -69,4 +75,4 @@ r"""
>>> produce_xml_fragment()
'<foo bbb="2.0" aaa="1.0">Hello</foo><bar ddd="4.0" ccc="3.0"></bar>'
"""
"""}