diff --git a/django/db/backends/__init__.py b/django/db/backends/__init__.py
index 4bbaa1cd43..4883e0bfc8 100644
--- a/django/db/backends/__init__.py
+++ b/django/db/backends/__init__.py
@@ -21,6 +21,7 @@ class BaseDatabaseWrapper(local):
self.settings_dict = settings_dict
self.alias = alias
self.vendor = 'unknown'
+ self.use_debug_cursor = None
def __eq__(self, other):
return self.settings_dict == other.settings_dict
@@ -74,7 +75,8 @@ class BaseDatabaseWrapper(local):
def cursor(self):
from django.conf import settings
cursor = self._cursor()
- if settings.DEBUG:
+ if (self.use_debug_cursor or
+ (self.use_debug_cursor is None and settings.DEBUG)):
return self.make_debug_cursor(cursor)
return cursor
diff --git a/django/test/testcases.py b/django/test/testcases.py
index 06b6eb39f4..65664a1f50 100644
--- a/django/test/testcases.py
+++ b/django/test/testcases.py
@@ -1,4 +1,5 @@
import re
+import sys
from urlparse import urlsplit, urlunsplit
from xml.dom.minidom import parseString, Node
@@ -205,6 +206,33 @@ class DocTestRunner(doctest.DocTestRunner):
for conn in connections:
transaction.rollback_unless_managed(using=conn)
+class _AssertNumQueriesContext(object):
+ def __init__(self, test_case, num, connection):
+ self.test_case = test_case
+ self.num = num
+ self.connection = connection
+
+ def __enter__(self):
+ self.old_debug_cursor = self.connection.use_debug_cursor
+ self.connection.use_debug_cursor = True
+ self.starting_queries = len(self.connection.queries)
+ return self
+
+ def __exit__(self, exc_type, exc_value, traceback):
+ if exc_type is not None:
+ return
+
+ self.connection.use_debug_cursor = self.old_debug_cursor
+ final_queries = len(self.connection.queries)
+ executed = final_queries - self.starting_queries
+
+ self.test_case.assertEqual(
+ executed, self.num, "%d queries executed, %d expected" % (
+ executed, self.num
+ )
+ )
+
+
class TransactionTestCase(unittest.TestCase):
# The class we'll use for the test client self.client.
# Can be overridden in derived classes.
@@ -469,6 +497,22 @@ class TransactionTestCase(unittest.TestCase):
def assertQuerysetEqual(self, qs, values, transform=repr):
return self.assertEqual(map(transform, qs), values)
+ def assertNumQueries(self, num, func=None, *args, **kwargs):
+ using = kwargs.pop("using", DEFAULT_DB_ALIAS)
+ connection = connections[using]
+
+ context = _AssertNumQueriesContext(self, num, connection)
+ if func is None:
+ return context
+
+ # Basically emulate the `with` statement here.
+
+ context.__enter__()
+ try:
+ func(*args, **kwargs)
+ finally:
+ context.__exit__(*sys.exc_info())
+
def connections_support_transactions():
"""
Returns True if all connections support transactions. This is messy
diff --git a/docs/topics/testing.txt b/docs/topics/testing.txt
index 81cc4809d5..465807021a 100644
--- a/docs/topics/testing.txt
+++ b/docs/topics/testing.txt
@@ -1372,6 +1372,32 @@ cause of an failure in your test suite.
implicit ordering, you will need to apply a ``order_by()`` clause to your
queryset to ensure that the test will pass reliably.
+.. method:: TestCase.assertNumQueries(num, func, *args, **kwargs):
+
+ .. versionadded:: 1.3
+
+ Asserts that when ``func`` is called with ``*args`` and ``**kwargs`` that
+ ``num`` database queries are executed.
+
+ If a ``"using"`` key is present in ``kwargs`` it is used as the database
+ alias for which to check the number of queries. If you wish to call a
+ function with a ``using`` parameter you can do it by wrapping the call with
+ a ``lambda`` to add an extra parameter::
+
+ self.assertNumQueries(7, lambda: my_function(using=7))
+
+ If you're using Python 2.5 or greater you can also use this as a context
+ manager::
+
+ # This is necessary in Python 2.5 to enable the with statement, in 2.6
+ # and up it is no longer necessary.
+ from __future__ import with_statement
+
+ with self.assertNumQueries(2):
+ Person.objects.create(name="Aaron")
+ Person.objects.create(name="Daniel")
+
+
.. _topics-testing-email:
E-mail services
diff --git a/tests/modeltests/select_related/tests.py b/tests/modeltests/select_related/tests.py
index a2111026cd..301ce93b39 100644
--- a/tests/modeltests/select_related/tests.py
+++ b/tests/modeltests/select_related/tests.py
@@ -1,6 +1,4 @@
from django.test import TestCase
-from django.conf import settings
-from django import db
from models import Domain, Kingdom, Phylum, Klass, Order, Family, Genus, Species
@@ -36,73 +34,73 @@ class SelectRelatedTests(TestCase):
# queries so we'll set it to True here and reset it at the end of the
# test case.
self.create_base_data()
- settings.DEBUG = True
- db.reset_queries()
-
- def tearDown(self):
- settings.DEBUG = False
def test_access_fks_without_select_related(self):
"""
Normally, accessing FKs doesn't fill in related objects
"""
- fly = Species.objects.get(name="melanogaster")
- domain = fly.genus.family.order.klass.phylum.kingdom.domain
- self.assertEqual(domain.name, 'Eukaryota')
- self.assertEqual(len(db.connection.queries), 8)
+ def test():
+ fly = Species.objects.get(name="melanogaster")
+ domain = fly.genus.family.order.klass.phylum.kingdom.domain
+ self.assertEqual(domain.name, 'Eukaryota')
+ self.assertNumQueries(8, test)
def test_access_fks_with_select_related(self):
"""
A select_related() call will fill in those related objects without any
extra queries
"""
- person = Species.objects.select_related(depth=10).get(name="sapiens")
- domain = person.genus.family.order.klass.phylum.kingdom.domain
- self.assertEqual(domain.name, 'Eukaryota')
- self.assertEqual(len(db.connection.queries), 1)
+ def test():
+ person = Species.objects.select_related(depth=10).get(name="sapiens")
+ domain = person.genus.family.order.klass.phylum.kingdom.domain
+ self.assertEqual(domain.name, 'Eukaryota')
+ self.assertNumQueries(1, test)
def test_list_without_select_related(self):
"""
select_related() also of course applies to entire lists, not just
items. This test verifies the expected behavior without select_related.
"""
- world = Species.objects.all()
- families = [o.genus.family.name for o in world]
- self.assertEqual(families, [
- 'Drosophilidae',
- 'Hominidae',
- 'Fabaceae',
- 'Amanitacae',
- ])
- self.assertEqual(len(db.connection.queries), 9)
+ def test():
+ world = Species.objects.all()
+ families = [o.genus.family.name for o in world]
+ self.assertEqual(families, [
+ 'Drosophilidae',
+ 'Hominidae',
+ 'Fabaceae',
+ 'Amanitacae',
+ ])
+ self.assertNumQueries(9, test)
def test_list_with_select_related(self):
"""
select_related() also of course applies to entire lists, not just
items. This test verifies the expected behavior with select_related.
"""
- world = Species.objects.all().select_related()
- families = [o.genus.family.name for o in world]
- self.assertEqual(families, [
- 'Drosophilidae',
- 'Hominidae',
- 'Fabaceae',
- 'Amanitacae',
- ])
- self.assertEqual(len(db.connection.queries), 1)
+ def test():
+ world = Species.objects.all().select_related()
+ families = [o.genus.family.name for o in world]
+ self.assertEqual(families, [
+ 'Drosophilidae',
+ 'Hominidae',
+ 'Fabaceae',
+ 'Amanitacae',
+ ])
+ self.assertNumQueries(1, test)
def test_depth(self, depth=1, expected=7):
"""
The "depth" argument to select_related() will stop the descent at a
particular level.
"""
- pea = Species.objects.select_related(depth=depth).get(name="sativum")
- self.assertEqual(
- pea.genus.family.order.klass.phylum.kingdom.domain.name,
- 'Eukaryota'
- )
+ def test():
+ pea = Species.objects.select_related(depth=depth).get(name="sativum")
+ self.assertEqual(
+ pea.genus.family.order.klass.phylum.kingdom.domain.name,
+ 'Eukaryota'
+ )
# Notice: one fewer queries than above because of depth=1
- self.assertEqual(len(db.connection.queries), expected)
+ self.assertNumQueries(expected, test)
def test_larger_depth(self):
"""
@@ -116,11 +114,12 @@ class SelectRelatedTests(TestCase):
The "depth" argument to select_related() will stop the descent at a
particular level. This can be used on lists as well.
"""
- world = Species.objects.all().select_related(depth=2)
- orders = [o.genus.family.order.name for o in world]
- self.assertEqual(orders,
- ['Diptera', 'Primates', 'Fabales', 'Agaricales'])
- self.assertEqual(len(db.connection.queries), 5)
+ def test():
+ world = Species.objects.all().select_related(depth=2)
+ orders = [o.genus.family.order.name for o in world]
+ self.assertEqual(orders,
+ ['Diptera', 'Primates', 'Fabales', 'Agaricales'])
+ self.assertNumQueries(5, test)
def test_select_related_with_extra(self):
s = Species.objects.all().select_related(depth=1)\
@@ -136,28 +135,31 @@ class SelectRelatedTests(TestCase):
In this case, we explicitly say to select the 'genus' and
'genus.family' models, leading to the same number of queries as before.
"""
- world = Species.objects.select_related('genus__family')
- families = [o.genus.family.name for o in world]
- self.assertEqual(families,
- ['Drosophilidae', 'Hominidae', 'Fabaceae', 'Amanitacae'])
- self.assertEqual(len(db.connection.queries), 1)
+ def test():
+ world = Species.objects.select_related('genus__family')
+ families = [o.genus.family.name for o in world]
+ self.assertEqual(families,
+ ['Drosophilidae', 'Hominidae', 'Fabaceae', 'Amanitacae'])
+ self.assertNumQueries(1, test)
def test_more_certain_fields(self):
"""
In this case, we explicitly say to select the 'genus' and
'genus.family' models, leading to the same number of queries as before.
"""
- world = Species.objects.filter(genus__name='Amanita')\
- .select_related('genus__family')
- orders = [o.genus.family.order.name for o in world]
- self.assertEqual(orders, [u'Agaricales'])
- self.assertEqual(len(db.connection.queries), 2)
+ def test():
+ world = Species.objects.filter(genus__name='Amanita')\
+ .select_related('genus__family')
+ orders = [o.genus.family.order.name for o in world]
+ self.assertEqual(orders, [u'Agaricales'])
+ self.assertNumQueries(2, test)
def test_field_traversal(self):
- s = Species.objects.all().select_related('genus__family__order'
- ).order_by('id')[0:1].get().genus.family.order.name
- self.assertEqual(s, u'Diptera')
- self.assertEqual(len(db.connection.queries), 1)
+ def test():
+ s = Species.objects.all().select_related('genus__family__order'
+ ).order_by('id')[0:1].get().genus.family.order.name
+ self.assertEqual(s, u'Diptera')
+ self.assertNumQueries(1, test)
def test_depth_fields_fails(self):
self.assertRaises(TypeError,
diff --git a/tests/modeltests/validation/test_unique.py b/tests/modeltests/validation/test_unique.py
index 2b824ae3b2..d239b8480c 100644
--- a/tests/modeltests/validation/test_unique.py
+++ b/tests/modeltests/validation/test_unique.py
@@ -2,9 +2,11 @@ import datetime
from django.conf import settings
from django.db import connection
+from django.test import TestCase
from django.utils import unittest
-from models import CustomPKModel, UniqueTogetherModel, UniqueFieldsModel, UniqueForDateModel, ModelToValidate
+from models import (CustomPKModel, UniqueTogetherModel, UniqueFieldsModel,
+ UniqueForDateModel, ModelToValidate)
class GetUniqueCheckTests(unittest.TestCase):
@@ -51,37 +53,26 @@ class GetUniqueCheckTests(unittest.TestCase):
), m._get_unique_checks(exclude='start_date')
)
-class PerformUniqueChecksTest(unittest.TestCase):
- def setUp(self):
- # Set debug to True to gain access to connection.queries.
- self._old_debug, settings.DEBUG = settings.DEBUG, True
- super(PerformUniqueChecksTest, self).setUp()
-
- def tearDown(self):
- # Restore old debug value.
- settings.DEBUG = self._old_debug
- super(PerformUniqueChecksTest, self).tearDown()
-
+class PerformUniqueChecksTest(TestCase):
def test_primary_key_unique_check_not_performed_when_adding_and_pk_not_specified(self):
# Regression test for #12560
- query_count = len(connection.queries)
- mtv = ModelToValidate(number=10, name='Some Name')
- setattr(mtv, '_adding', True)
- mtv.full_clean()
- self.assertEqual(query_count, len(connection.queries))
+ def test():
+ mtv = ModelToValidate(number=10, name='Some Name')
+ setattr(mtv, '_adding', True)
+ mtv.full_clean()
+ self.assertNumQueries(0, test)
def test_primary_key_unique_check_performed_when_adding_and_pk_specified(self):
# Regression test for #12560
- query_count = len(connection.queries)
- mtv = ModelToValidate(number=10, name='Some Name', id=123)
- setattr(mtv, '_adding', True)
- mtv.full_clean()
- self.assertEqual(query_count + 1, len(connection.queries))
+ def test():
+ mtv = ModelToValidate(number=10, name='Some Name', id=123)
+ setattr(mtv, '_adding', True)
+ mtv.full_clean()
+ self.assertNumQueries(1, test)
def test_primary_key_unique_check_not_performed_when_not_adding(self):
# Regression test for #12132
- query_count= len(connection.queries)
- mtv = ModelToValidate(number=10, name='Some Name')
- mtv.full_clean()
- self.assertEqual(query_count, len(connection.queries))
-
+ def test():
+ mtv = ModelToValidate(number=10, name='Some Name')
+ mtv.full_clean()
+ self.assertNumQueries(0, test)
diff --git a/tests/modeltests/validation/tests.py b/tests/modeltests/validation/tests.py
index 4dff410d54..142688f3d3 100644
--- a/tests/modeltests/validation/tests.py
+++ b/tests/modeltests/validation/tests.py
@@ -6,7 +6,8 @@ from modeltests.validation.models import Author, Article, ModelToValidate
# Import other tests for this package.
from modeltests.validation.validators import TestModelsWithValidators
-from modeltests.validation.test_unique import GetUniqueCheckTests, PerformUniqueChecksTest
+from modeltests.validation.test_unique import (GetUniqueCheckTests,
+ PerformUniqueChecksTest)
from modeltests.validation.test_custom_messages import CustomMessagesTest
@@ -111,4 +112,3 @@ class ModelFormsTests(TestCase):
article = Article(author_id=self.author.id)
form = ArticleForm(data, instance=article)
self.assertEqual(form.errors.keys(), ['pub_date'])
-
diff --git a/tests/regressiontests/defer_regress/tests.py b/tests/regressiontests/defer_regress/tests.py
index affb0e2405..6f1b023b5b 100644
--- a/tests/regressiontests/defer_regress/tests.py
+++ b/tests/regressiontests/defer_regress/tests.py
@@ -11,17 +11,6 @@ from models import ResolveThis, Item, RelatedItem, Child, Leaf
class DeferRegressionTest(TestCase):
- def assert_num_queries(self, n, func, *args, **kwargs):
- old_DEBUG = settings.DEBUG
- settings.DEBUG = True
- starting_queries = len(connection.queries)
- try:
- func(*args, **kwargs)
- finally:
- settings.DEBUG = old_DEBUG
- self.assertEqual(starting_queries + n, len(connection.queries))
-
-
def test_basic(self):
# Deferred fields should really be deferred and not accidentally use
# the field's default value just because they aren't passed to __init__
@@ -33,19 +22,19 @@ class DeferRegressionTest(TestCase):
def test():
self.assertEqual(obj.name, "first")
self.assertEqual(obj.other_value, 0)
- self.assert_num_queries(0, test)
+ self.assertNumQueries(0, test)
def test():
self.assertEqual(obj.value, 42)
- self.assert_num_queries(1, test)
+ self.assertNumQueries(1, test)
def test():
self.assertEqual(obj.text, "xyzzy")
- self.assert_num_queries(1, test)
+ self.assertNumQueries(1, test)
def test():
self.assertEqual(obj.text, "xyzzy")
- self.assert_num_queries(0, test)
+ self.assertNumQueries(0, test)
# Regression test for #10695. Make sure different instances don't
# inadvertently share data in the deferred descriptor objects.
diff --git a/tests/regressiontests/forms/models.py b/tests/regressiontests/forms/models.py
index 028ff9bad2..a4891df06e 100644
--- a/tests/regressiontests/forms/models.py
+++ b/tests/regressiontests/forms/models.py
@@ -1,10 +1,9 @@
# -*- coding: utf-8 -*-
import datetime
-import tempfile
import shutil
+import tempfile
-from django.db import models, connection
-from django.conf import settings
+from django.db import models
# Can't import as "forms" due to implementation details in the test suite (the
# current file is called "forms" and is already imported).
from django import forms as django_forms
@@ -77,19 +76,13 @@ class TestTicket12510(TestCase):
''' It is not necessary to generate choices for ModelChoiceField (regression test for #12510). '''
def setUp(self):
self.groups = [Group.objects.create(name=name) for name in 'abc']
- self.old_debug = settings.DEBUG
- # turn debug on to get access to connection.queries
- settings.DEBUG = True
-
- def tearDown(self):
- settings.DEBUG = self.old_debug
def test_choices_not_fetched_when_not_rendering(self):
- initial_queries = len(connection.queries)
- field = django_forms.ModelChoiceField(Group.objects.order_by('-name'))
- self.assertEqual('a', field.clean(self.groups[0].pk).name)
+ def test():
+ field = django_forms.ModelChoiceField(Group.objects.order_by('-name'))
+ self.assertEqual('a', field.clean(self.groups[0].pk).name)
# only one query is required to pull the model from DB
- self.assertEqual(initial_queries+1, len(connection.queries))
+ self.assertNumQueries(1, test)
class ModelFormCallableModelDefault(TestCase):
def test_no_empty_option(self):
diff --git a/tests/regressiontests/model_forms_regress/tests.py b/tests/regressiontests/model_forms_regress/tests.py
index 397651a6b8..d695104d50 100644
--- a/tests/regressiontests/model_forms_regress/tests.py
+++ b/tests/regressiontests/model_forms_regress/tests.py
@@ -1,10 +1,8 @@
import unittest
from datetime import date
-from django import db
from django import forms
from django.forms.models import modelform_factory, ModelChoiceField
-from django.conf import settings
from django.test import TestCase
from django.core.exceptions import FieldError, ValidationError
from django.core.files.uploadedfile import SimpleUploadedFile
@@ -14,14 +12,6 @@ from models import Person, RealPerson, Triple, FilePathModel, Article, \
class ModelMultipleChoiceFieldTests(TestCase):
-
- def setUp(self):
- self.old_debug = settings.DEBUG
- settings.DEBUG = True
-
- def tearDown(self):
- settings.DEBUG = self.old_debug
-
def test_model_multiple_choice_number_of_queries(self):
"""
Test that ModelMultipleChoiceField does O(1) queries instead of
@@ -30,10 +20,8 @@ class ModelMultipleChoiceFieldTests(TestCase):
for i in range(30):
Person.objects.create(name="Person %s" % i)
- db.reset_queries()
f = forms.ModelMultipleChoiceField(queryset=Person.objects.all())
- selected = f.clean([1, 3, 5, 7, 9])
- self.assertEquals(len(db.connection.queries), 1)
+ self.assertNumQueries(1, f.clean, [1, 3, 5, 7, 9])
class TripleForm(forms.ModelForm):
class Meta:
@@ -312,7 +300,7 @@ class InvalidFieldAndFactory(TestCase):
model = Person
fields = ('name', 'no-field')
except FieldError, e:
- # Make sure the exception contains some reference to the
+ # Make sure the exception contains some reference to the
# field responsible for the problem.
self.assertTrue('no-field' in e.args[0])
else:
diff --git a/tests/regressiontests/select_related_onetoone/tests.py b/tests/regressiontests/select_related_onetoone/tests.py
index 4ccb58440a..ab35feccf1 100644
--- a/tests/regressiontests/select_related_onetoone/tests.py
+++ b/tests/regressiontests/select_related_onetoone/tests.py
@@ -7,11 +7,6 @@ from models import (User, UserProfile, UserStat, UserStatResult, StatDetails,
class ReverseSelectRelatedTestCase(TestCase):
def setUp(self):
- # Explicitly enable debug for these tests - we need to count
- # the queries that have been issued.
- self.old_debug = settings.DEBUG
- settings.DEBUG = True
-
user = User.objects.create(username="test")
userprofile = UserProfile.objects.create(user=user, state="KS",
city="Lawrence")
@@ -26,65 +21,66 @@ class ReverseSelectRelatedTestCase(TestCase):
results=results2)
StatDetails.objects.create(base_stats=advstat, comments=250)
- db.reset_queries()
-
- def assertQueries(self, queries):
- self.assertEqual(len(db.connection.queries), queries)
-
- def tearDown(self):
- settings.DEBUG = self.old_debug
-
def test_basic(self):
- u = User.objects.select_related("userprofile").get(username="test")
- self.assertEqual(u.userprofile.state, "KS")
- self.assertQueries(1)
+ def test():
+ u = User.objects.select_related("userprofile").get(username="test")
+ self.assertEqual(u.userprofile.state, "KS")
+ self.assertNumQueries(1, test)
def test_follow_next_level(self):
- u = User.objects.select_related("userstat__results").get(username="test")
- self.assertEqual(u.userstat.posts, 150)
- self.assertEqual(u.userstat.results.results, 'first results')
- self.assertQueries(1)
+ def test():
+ u = User.objects.select_related("userstat__results").get(username="test")
+ self.assertEqual(u.userstat.posts, 150)
+ self.assertEqual(u.userstat.results.results, 'first results')
+ self.assertNumQueries(1, test)
def test_follow_two(self):
- u = User.objects.select_related("userprofile", "userstat").get(username="test")
- self.assertEqual(u.userprofile.state, "KS")
- self.assertEqual(u.userstat.posts, 150)
- self.assertQueries(1)
+ def test():
+ u = User.objects.select_related("userprofile", "userstat").get(username="test")
+ self.assertEqual(u.userprofile.state, "KS")
+ self.assertEqual(u.userstat.posts, 150)
+ self.assertNumQueries(1, test)
def test_follow_two_next_level(self):
- u = User.objects.select_related("userstat__results", "userstat__statdetails").get(username="test")
- self.assertEqual(u.userstat.results.results, 'first results')
- self.assertEqual(u.userstat.statdetails.comments, 259)
- self.assertQueries(1)
+ def test():
+ u = User.objects.select_related("userstat__results", "userstat__statdetails").get(username="test")
+ self.assertEqual(u.userstat.results.results, 'first results')
+ self.assertEqual(u.userstat.statdetails.comments, 259)
+ self.assertNumQueries(1, test)
def test_forward_and_back(self):
- stat = UserStat.objects.select_related("user__userprofile").get(user__username="test")
- self.assertEqual(stat.user.userprofile.state, 'KS')
- self.assertEqual(stat.user.userstat.posts, 150)
- self.assertQueries(1)
+ def test():
+ stat = UserStat.objects.select_related("user__userprofile").get(user__username="test")
+ self.assertEqual(stat.user.userprofile.state, 'KS')
+ self.assertEqual(stat.user.userstat.posts, 150)
+ self.assertNumQueries(1, test)
def test_back_and_forward(self):
- u = User.objects.select_related("userstat").get(username="test")
- self.assertEqual(u.userstat.user.username, 'test')
- self.assertQueries(1)
+ def test():
+ u = User.objects.select_related("userstat").get(username="test")
+ self.assertEqual(u.userstat.user.username, 'test')
+ self.assertNumQueries(1, test)
def test_not_followed_by_default(self):
- u = User.objects.select_related().get(username="test")
- self.assertEqual(u.userstat.posts, 150)
- self.assertQueries(2)
+ def test():
+ u = User.objects.select_related().get(username="test")
+ self.assertEqual(u.userstat.posts, 150)
+ self.assertNumQueries(2, test)
def test_follow_from_child_class(self):
- stat = AdvancedUserStat.objects.select_related('user', 'statdetails').get(posts=200)
- self.assertEqual(stat.statdetails.comments, 250)
- self.assertEqual(stat.user.username, 'bob')
- self.assertQueries(1)
+ def test():
+ stat = AdvancedUserStat.objects.select_related('user', 'statdetails').get(posts=200)
+ self.assertEqual(stat.statdetails.comments, 250)
+ self.assertEqual(stat.user.username, 'bob')
+ self.assertNumQueries(1, test)
def test_follow_inheritance(self):
- stat = UserStat.objects.select_related('user', 'advanceduserstat').get(posts=200)
- self.assertEqual(stat.advanceduserstat.posts, 200)
- self.assertEqual(stat.user.username, 'bob')
- self.assertEqual(stat.advanceduserstat.user.username, 'bob')
- self.assertQueries(1)
+ def test():
+ stat = UserStat.objects.select_related('user', 'advanceduserstat').get(posts=200)
+ self.assertEqual(stat.advanceduserstat.posts, 200)
+ self.assertEqual(stat.user.username, 'bob')
+ self.assertEqual(stat.advanceduserstat.user.username, 'bob')
+ self.assertNumQueries(1, test)
def test_nullable_relation(self):
im = Image.objects.create(name="imag1")
diff --git a/tests/regressiontests/test_utils/models.py b/tests/regressiontests/test_utils/models.py
index e69de29bb2..4da7a07bbf 100644
--- a/tests/regressiontests/test_utils/models.py
+++ b/tests/regressiontests/test_utils/models.py
@@ -0,0 +1,5 @@
+from django.db import models
+
+
+class Person(models.Model):
+ name = models.CharField(max_length=100)
diff --git a/tests/regressiontests/test_utils/python_25.py b/tests/regressiontests/test_utils/python_25.py
new file mode 100644
index 0000000000..a1e8a94d1e
--- /dev/null
+++ b/tests/regressiontests/test_utils/python_25.py
@@ -0,0 +1,30 @@
+from __future__ import with_statement
+
+from django.test import TestCase
+
+from models import Person
+
+
+class AssertNumQueriesTests(TestCase):
+ def test_simple(self):
+ with self.assertNumQueries(0):
+ pass
+
+ with self.assertNumQueries(1):
+ # Guy who wrote Linux
+ Person.objects.create(name="Linus Torvalds")
+
+ with self.assertNumQueries(2):
+ # Guy who owns the bagel place I like
+ Person.objects.create(name="Uncle Ricky")
+ self.assertEqual(Person.objects.count(), 2)
+
+ def test_failure(self):
+ with self.assertRaises(AssertionError) as exc_info:
+ with self.assertNumQueries(2):
+ Person.objects.count()
+ self.assertEqual(str(exc_info.exception), "1 != 2 : 1 queries executed, 2 expected")
+
+ with self.assertRaises(TypeError):
+ with self.assertNumQueries(4000):
+ raise TypeError
diff --git a/tests/regressiontests/test_utils/tests.py b/tests/regressiontests/test_utils/tests.py
index a2539bf8c6..4f92a402cc 100644
--- a/tests/regressiontests/test_utils/tests.py
+++ b/tests/regressiontests/test_utils/tests.py
@@ -1,6 +1,12 @@
-r"""
+import sys
+
+if sys.version_info >= (2, 5):
+ from python_25 import AssertNumQueriesTests
+
+
+__test__ = {"API_TEST": r"""
# Some checks of the doctest output normalizer.
-# Standard doctests do fairly
+# Standard doctests do fairly
>>> from django.utils import simplejson
>>> from django.utils.xmlutils import SimplerXMLGenerator
>>> from StringIO import StringIO
@@ -55,7 +61,7 @@ r"""
>>> produce_json()
'["foo", {"whiz": 42, "bar": ["baz", null, 1.0, 2]}]'
-# XML output is normalized for attribute order, so it doesn't matter
+# XML output is normalized for attribute order, so it doesn't matter
# which order XML element attributes are listed in output
>>> produce_xml()
'\nHelloGoodbye'
@@ -69,4 +75,4 @@ r"""
>>> produce_xml_fragment()
'Hello'
-"""
\ No newline at end of file
+"""}