From 5506653b777d7547d21ea2d74e9588fb94314b77 Mon Sep 17 00:00:00 2001 From: Alex Gaynor Date: Tue, 12 Oct 2010 03:33:19 +0000 Subject: [PATCH] Fixed #5416 -- Added TestCase.assertNumQueries, which tests that a given function executes the correct number of queries. git-svn-id: http://code.djangoproject.com/svn/django/trunk@14183 bcc190cf-cafb-0310-a4f2-bffc1f526a37 --- django/db/backends/__init__.py | 4 +- django/test/testcases.py | 44 +++++++ docs/topics/testing.txt | 26 ++++ tests/modeltests/select_related/tests.py | 118 +++++++++--------- tests/modeltests/validation/test_unique.py | 45 +++---- tests/modeltests/validation/tests.py | 4 +- tests/regressiontests/defer_regress/tests.py | 19 +-- tests/regressiontests/forms/models.py | 19 +-- .../model_forms_regress/tests.py | 16 +-- .../select_related_onetoone/tests.py | 90 +++++++------ tests/regressiontests/test_utils/models.py | 5 + tests/regressiontests/test_utils/python_25.py | 30 +++++ tests/regressiontests/test_utils/tests.py | 14 ++- 13 files changed, 253 insertions(+), 181 deletions(-) create mode 100644 tests/regressiontests/test_utils/python_25.py diff --git a/django/db/backends/__init__.py b/django/db/backends/__init__.py index 4bbaa1cd43..4883e0bfc8 100644 --- a/django/db/backends/__init__.py +++ b/django/db/backends/__init__.py @@ -21,6 +21,7 @@ class BaseDatabaseWrapper(local): self.settings_dict = settings_dict self.alias = alias self.vendor = 'unknown' + self.use_debug_cursor = None def __eq__(self, other): return self.settings_dict == other.settings_dict @@ -74,7 +75,8 @@ class BaseDatabaseWrapper(local): def cursor(self): from django.conf import settings cursor = self._cursor() - if settings.DEBUG: + if (self.use_debug_cursor or + (self.use_debug_cursor is None and settings.DEBUG)): return self.make_debug_cursor(cursor) return cursor diff --git a/django/test/testcases.py b/django/test/testcases.py index 06b6eb39f4..65664a1f50 100644 --- a/django/test/testcases.py +++ b/django/test/testcases.py @@ -1,4 +1,5 @@ import re +import sys from urlparse import urlsplit, urlunsplit from xml.dom.minidom import parseString, Node @@ -205,6 +206,33 @@ class DocTestRunner(doctest.DocTestRunner): for conn in connections: transaction.rollback_unless_managed(using=conn) +class _AssertNumQueriesContext(object): + def __init__(self, test_case, num, connection): + self.test_case = test_case + self.num = num + self.connection = connection + + def __enter__(self): + self.old_debug_cursor = self.connection.use_debug_cursor + self.connection.use_debug_cursor = True + self.starting_queries = len(self.connection.queries) + return self + + def __exit__(self, exc_type, exc_value, traceback): + if exc_type is not None: + return + + self.connection.use_debug_cursor = self.old_debug_cursor + final_queries = len(self.connection.queries) + executed = final_queries - self.starting_queries + + self.test_case.assertEqual( + executed, self.num, "%d queries executed, %d expected" % ( + executed, self.num + ) + ) + + class TransactionTestCase(unittest.TestCase): # The class we'll use for the test client self.client. # Can be overridden in derived classes. @@ -469,6 +497,22 @@ class TransactionTestCase(unittest.TestCase): def assertQuerysetEqual(self, qs, values, transform=repr): return self.assertEqual(map(transform, qs), values) + def assertNumQueries(self, num, func=None, *args, **kwargs): + using = kwargs.pop("using", DEFAULT_DB_ALIAS) + connection = connections[using] + + context = _AssertNumQueriesContext(self, num, connection) + if func is None: + return context + + # Basically emulate the `with` statement here. + + context.__enter__() + try: + func(*args, **kwargs) + finally: + context.__exit__(*sys.exc_info()) + def connections_support_transactions(): """ Returns True if all connections support transactions. This is messy diff --git a/docs/topics/testing.txt b/docs/topics/testing.txt index 81cc4809d5..465807021a 100644 --- a/docs/topics/testing.txt +++ b/docs/topics/testing.txt @@ -1372,6 +1372,32 @@ cause of an failure in your test suite. implicit ordering, you will need to apply a ``order_by()`` clause to your queryset to ensure that the test will pass reliably. +.. method:: TestCase.assertNumQueries(num, func, *args, **kwargs): + + .. versionadded:: 1.3 + + Asserts that when ``func`` is called with ``*args`` and ``**kwargs`` that + ``num`` database queries are executed. + + If a ``"using"`` key is present in ``kwargs`` it is used as the database + alias for which to check the number of queries. If you wish to call a + function with a ``using`` parameter you can do it by wrapping the call with + a ``lambda`` to add an extra parameter:: + + self.assertNumQueries(7, lambda: my_function(using=7)) + + If you're using Python 2.5 or greater you can also use this as a context + manager:: + + # This is necessary in Python 2.5 to enable the with statement, in 2.6 + # and up it is no longer necessary. + from __future__ import with_statement + + with self.assertNumQueries(2): + Person.objects.create(name="Aaron") + Person.objects.create(name="Daniel") + + .. _topics-testing-email: E-mail services diff --git a/tests/modeltests/select_related/tests.py b/tests/modeltests/select_related/tests.py index a2111026cd..301ce93b39 100644 --- a/tests/modeltests/select_related/tests.py +++ b/tests/modeltests/select_related/tests.py @@ -1,6 +1,4 @@ from django.test import TestCase -from django.conf import settings -from django import db from models import Domain, Kingdom, Phylum, Klass, Order, Family, Genus, Species @@ -36,73 +34,73 @@ class SelectRelatedTests(TestCase): # queries so we'll set it to True here and reset it at the end of the # test case. self.create_base_data() - settings.DEBUG = True - db.reset_queries() - - def tearDown(self): - settings.DEBUG = False def test_access_fks_without_select_related(self): """ Normally, accessing FKs doesn't fill in related objects """ - fly = Species.objects.get(name="melanogaster") - domain = fly.genus.family.order.klass.phylum.kingdom.domain - self.assertEqual(domain.name, 'Eukaryota') - self.assertEqual(len(db.connection.queries), 8) + def test(): + fly = Species.objects.get(name="melanogaster") + domain = fly.genus.family.order.klass.phylum.kingdom.domain + self.assertEqual(domain.name, 'Eukaryota') + self.assertNumQueries(8, test) def test_access_fks_with_select_related(self): """ A select_related() call will fill in those related objects without any extra queries """ - person = Species.objects.select_related(depth=10).get(name="sapiens") - domain = person.genus.family.order.klass.phylum.kingdom.domain - self.assertEqual(domain.name, 'Eukaryota') - self.assertEqual(len(db.connection.queries), 1) + def test(): + person = Species.objects.select_related(depth=10).get(name="sapiens") + domain = person.genus.family.order.klass.phylum.kingdom.domain + self.assertEqual(domain.name, 'Eukaryota') + self.assertNumQueries(1, test) def test_list_without_select_related(self): """ select_related() also of course applies to entire lists, not just items. This test verifies the expected behavior without select_related. """ - world = Species.objects.all() - families = [o.genus.family.name for o in world] - self.assertEqual(families, [ - 'Drosophilidae', - 'Hominidae', - 'Fabaceae', - 'Amanitacae', - ]) - self.assertEqual(len(db.connection.queries), 9) + def test(): + world = Species.objects.all() + families = [o.genus.family.name for o in world] + self.assertEqual(families, [ + 'Drosophilidae', + 'Hominidae', + 'Fabaceae', + 'Amanitacae', + ]) + self.assertNumQueries(9, test) def test_list_with_select_related(self): """ select_related() also of course applies to entire lists, not just items. This test verifies the expected behavior with select_related. """ - world = Species.objects.all().select_related() - families = [o.genus.family.name for o in world] - self.assertEqual(families, [ - 'Drosophilidae', - 'Hominidae', - 'Fabaceae', - 'Amanitacae', - ]) - self.assertEqual(len(db.connection.queries), 1) + def test(): + world = Species.objects.all().select_related() + families = [o.genus.family.name for o in world] + self.assertEqual(families, [ + 'Drosophilidae', + 'Hominidae', + 'Fabaceae', + 'Amanitacae', + ]) + self.assertNumQueries(1, test) def test_depth(self, depth=1, expected=7): """ The "depth" argument to select_related() will stop the descent at a particular level. """ - pea = Species.objects.select_related(depth=depth).get(name="sativum") - self.assertEqual( - pea.genus.family.order.klass.phylum.kingdom.domain.name, - 'Eukaryota' - ) + def test(): + pea = Species.objects.select_related(depth=depth).get(name="sativum") + self.assertEqual( + pea.genus.family.order.klass.phylum.kingdom.domain.name, + 'Eukaryota' + ) # Notice: one fewer queries than above because of depth=1 - self.assertEqual(len(db.connection.queries), expected) + self.assertNumQueries(expected, test) def test_larger_depth(self): """ @@ -116,11 +114,12 @@ class SelectRelatedTests(TestCase): The "depth" argument to select_related() will stop the descent at a particular level. This can be used on lists as well. """ - world = Species.objects.all().select_related(depth=2) - orders = [o.genus.family.order.name for o in world] - self.assertEqual(orders, - ['Diptera', 'Primates', 'Fabales', 'Agaricales']) - self.assertEqual(len(db.connection.queries), 5) + def test(): + world = Species.objects.all().select_related(depth=2) + orders = [o.genus.family.order.name for o in world] + self.assertEqual(orders, + ['Diptera', 'Primates', 'Fabales', 'Agaricales']) + self.assertNumQueries(5, test) def test_select_related_with_extra(self): s = Species.objects.all().select_related(depth=1)\ @@ -136,28 +135,31 @@ class SelectRelatedTests(TestCase): In this case, we explicitly say to select the 'genus' and 'genus.family' models, leading to the same number of queries as before. """ - world = Species.objects.select_related('genus__family') - families = [o.genus.family.name for o in world] - self.assertEqual(families, - ['Drosophilidae', 'Hominidae', 'Fabaceae', 'Amanitacae']) - self.assertEqual(len(db.connection.queries), 1) + def test(): + world = Species.objects.select_related('genus__family') + families = [o.genus.family.name for o in world] + self.assertEqual(families, + ['Drosophilidae', 'Hominidae', 'Fabaceae', 'Amanitacae']) + self.assertNumQueries(1, test) def test_more_certain_fields(self): """ In this case, we explicitly say to select the 'genus' and 'genus.family' models, leading to the same number of queries as before. """ - world = Species.objects.filter(genus__name='Amanita')\ - .select_related('genus__family') - orders = [o.genus.family.order.name for o in world] - self.assertEqual(orders, [u'Agaricales']) - self.assertEqual(len(db.connection.queries), 2) + def test(): + world = Species.objects.filter(genus__name='Amanita')\ + .select_related('genus__family') + orders = [o.genus.family.order.name for o in world] + self.assertEqual(orders, [u'Agaricales']) + self.assertNumQueries(2, test) def test_field_traversal(self): - s = Species.objects.all().select_related('genus__family__order' - ).order_by('id')[0:1].get().genus.family.order.name - self.assertEqual(s, u'Diptera') - self.assertEqual(len(db.connection.queries), 1) + def test(): + s = Species.objects.all().select_related('genus__family__order' + ).order_by('id')[0:1].get().genus.family.order.name + self.assertEqual(s, u'Diptera') + self.assertNumQueries(1, test) def test_depth_fields_fails(self): self.assertRaises(TypeError, diff --git a/tests/modeltests/validation/test_unique.py b/tests/modeltests/validation/test_unique.py index 2b824ae3b2..d239b8480c 100644 --- a/tests/modeltests/validation/test_unique.py +++ b/tests/modeltests/validation/test_unique.py @@ -2,9 +2,11 @@ import datetime from django.conf import settings from django.db import connection +from django.test import TestCase from django.utils import unittest -from models import CustomPKModel, UniqueTogetherModel, UniqueFieldsModel, UniqueForDateModel, ModelToValidate +from models import (CustomPKModel, UniqueTogetherModel, UniqueFieldsModel, + UniqueForDateModel, ModelToValidate) class GetUniqueCheckTests(unittest.TestCase): @@ -51,37 +53,26 @@ class GetUniqueCheckTests(unittest.TestCase): ), m._get_unique_checks(exclude='start_date') ) -class PerformUniqueChecksTest(unittest.TestCase): - def setUp(self): - # Set debug to True to gain access to connection.queries. - self._old_debug, settings.DEBUG = settings.DEBUG, True - super(PerformUniqueChecksTest, self).setUp() - - def tearDown(self): - # Restore old debug value. - settings.DEBUG = self._old_debug - super(PerformUniqueChecksTest, self).tearDown() - +class PerformUniqueChecksTest(TestCase): def test_primary_key_unique_check_not_performed_when_adding_and_pk_not_specified(self): # Regression test for #12560 - query_count = len(connection.queries) - mtv = ModelToValidate(number=10, name='Some Name') - setattr(mtv, '_adding', True) - mtv.full_clean() - self.assertEqual(query_count, len(connection.queries)) + def test(): + mtv = ModelToValidate(number=10, name='Some Name') + setattr(mtv, '_adding', True) + mtv.full_clean() + self.assertNumQueries(0, test) def test_primary_key_unique_check_performed_when_adding_and_pk_specified(self): # Regression test for #12560 - query_count = len(connection.queries) - mtv = ModelToValidate(number=10, name='Some Name', id=123) - setattr(mtv, '_adding', True) - mtv.full_clean() - self.assertEqual(query_count + 1, len(connection.queries)) + def test(): + mtv = ModelToValidate(number=10, name='Some Name', id=123) + setattr(mtv, '_adding', True) + mtv.full_clean() + self.assertNumQueries(1, test) def test_primary_key_unique_check_not_performed_when_not_adding(self): # Regression test for #12132 - query_count= len(connection.queries) - mtv = ModelToValidate(number=10, name='Some Name') - mtv.full_clean() - self.assertEqual(query_count, len(connection.queries)) - + def test(): + mtv = ModelToValidate(number=10, name='Some Name') + mtv.full_clean() + self.assertNumQueries(0, test) diff --git a/tests/modeltests/validation/tests.py b/tests/modeltests/validation/tests.py index 4dff410d54..142688f3d3 100644 --- a/tests/modeltests/validation/tests.py +++ b/tests/modeltests/validation/tests.py @@ -6,7 +6,8 @@ from modeltests.validation.models import Author, Article, ModelToValidate # Import other tests for this package. from modeltests.validation.validators import TestModelsWithValidators -from modeltests.validation.test_unique import GetUniqueCheckTests, PerformUniqueChecksTest +from modeltests.validation.test_unique import (GetUniqueCheckTests, + PerformUniqueChecksTest) from modeltests.validation.test_custom_messages import CustomMessagesTest @@ -111,4 +112,3 @@ class ModelFormsTests(TestCase): article = Article(author_id=self.author.id) form = ArticleForm(data, instance=article) self.assertEqual(form.errors.keys(), ['pub_date']) - diff --git a/tests/regressiontests/defer_regress/tests.py b/tests/regressiontests/defer_regress/tests.py index affb0e2405..6f1b023b5b 100644 --- a/tests/regressiontests/defer_regress/tests.py +++ b/tests/regressiontests/defer_regress/tests.py @@ -11,17 +11,6 @@ from models import ResolveThis, Item, RelatedItem, Child, Leaf class DeferRegressionTest(TestCase): - def assert_num_queries(self, n, func, *args, **kwargs): - old_DEBUG = settings.DEBUG - settings.DEBUG = True - starting_queries = len(connection.queries) - try: - func(*args, **kwargs) - finally: - settings.DEBUG = old_DEBUG - self.assertEqual(starting_queries + n, len(connection.queries)) - - def test_basic(self): # Deferred fields should really be deferred and not accidentally use # the field's default value just because they aren't passed to __init__ @@ -33,19 +22,19 @@ class DeferRegressionTest(TestCase): def test(): self.assertEqual(obj.name, "first") self.assertEqual(obj.other_value, 0) - self.assert_num_queries(0, test) + self.assertNumQueries(0, test) def test(): self.assertEqual(obj.value, 42) - self.assert_num_queries(1, test) + self.assertNumQueries(1, test) def test(): self.assertEqual(obj.text, "xyzzy") - self.assert_num_queries(1, test) + self.assertNumQueries(1, test) def test(): self.assertEqual(obj.text, "xyzzy") - self.assert_num_queries(0, test) + self.assertNumQueries(0, test) # Regression test for #10695. Make sure different instances don't # inadvertently share data in the deferred descriptor objects. diff --git a/tests/regressiontests/forms/models.py b/tests/regressiontests/forms/models.py index 028ff9bad2..a4891df06e 100644 --- a/tests/regressiontests/forms/models.py +++ b/tests/regressiontests/forms/models.py @@ -1,10 +1,9 @@ # -*- coding: utf-8 -*- import datetime -import tempfile import shutil +import tempfile -from django.db import models, connection -from django.conf import settings +from django.db import models # Can't import as "forms" due to implementation details in the test suite (the # current file is called "forms" and is already imported). from django import forms as django_forms @@ -77,19 +76,13 @@ class TestTicket12510(TestCase): ''' It is not necessary to generate choices for ModelChoiceField (regression test for #12510). ''' def setUp(self): self.groups = [Group.objects.create(name=name) for name in 'abc'] - self.old_debug = settings.DEBUG - # turn debug on to get access to connection.queries - settings.DEBUG = True - - def tearDown(self): - settings.DEBUG = self.old_debug def test_choices_not_fetched_when_not_rendering(self): - initial_queries = len(connection.queries) - field = django_forms.ModelChoiceField(Group.objects.order_by('-name')) - self.assertEqual('a', field.clean(self.groups[0].pk).name) + def test(): + field = django_forms.ModelChoiceField(Group.objects.order_by('-name')) + self.assertEqual('a', field.clean(self.groups[0].pk).name) # only one query is required to pull the model from DB - self.assertEqual(initial_queries+1, len(connection.queries)) + self.assertNumQueries(1, test) class ModelFormCallableModelDefault(TestCase): def test_no_empty_option(self): diff --git a/tests/regressiontests/model_forms_regress/tests.py b/tests/regressiontests/model_forms_regress/tests.py index 397651a6b8..d695104d50 100644 --- a/tests/regressiontests/model_forms_regress/tests.py +++ b/tests/regressiontests/model_forms_regress/tests.py @@ -1,10 +1,8 @@ import unittest from datetime import date -from django import db from django import forms from django.forms.models import modelform_factory, ModelChoiceField -from django.conf import settings from django.test import TestCase from django.core.exceptions import FieldError, ValidationError from django.core.files.uploadedfile import SimpleUploadedFile @@ -14,14 +12,6 @@ from models import Person, RealPerson, Triple, FilePathModel, Article, \ class ModelMultipleChoiceFieldTests(TestCase): - - def setUp(self): - self.old_debug = settings.DEBUG - settings.DEBUG = True - - def tearDown(self): - settings.DEBUG = self.old_debug - def test_model_multiple_choice_number_of_queries(self): """ Test that ModelMultipleChoiceField does O(1) queries instead of @@ -30,10 +20,8 @@ class ModelMultipleChoiceFieldTests(TestCase): for i in range(30): Person.objects.create(name="Person %s" % i) - db.reset_queries() f = forms.ModelMultipleChoiceField(queryset=Person.objects.all()) - selected = f.clean([1, 3, 5, 7, 9]) - self.assertEquals(len(db.connection.queries), 1) + self.assertNumQueries(1, f.clean, [1, 3, 5, 7, 9]) class TripleForm(forms.ModelForm): class Meta: @@ -312,7 +300,7 @@ class InvalidFieldAndFactory(TestCase): model = Person fields = ('name', 'no-field') except FieldError, e: - # Make sure the exception contains some reference to the + # Make sure the exception contains some reference to the # field responsible for the problem. self.assertTrue('no-field' in e.args[0]) else: diff --git a/tests/regressiontests/select_related_onetoone/tests.py b/tests/regressiontests/select_related_onetoone/tests.py index 4ccb58440a..ab35feccf1 100644 --- a/tests/regressiontests/select_related_onetoone/tests.py +++ b/tests/regressiontests/select_related_onetoone/tests.py @@ -7,11 +7,6 @@ from models import (User, UserProfile, UserStat, UserStatResult, StatDetails, class ReverseSelectRelatedTestCase(TestCase): def setUp(self): - # Explicitly enable debug for these tests - we need to count - # the queries that have been issued. - self.old_debug = settings.DEBUG - settings.DEBUG = True - user = User.objects.create(username="test") userprofile = UserProfile.objects.create(user=user, state="KS", city="Lawrence") @@ -26,65 +21,66 @@ class ReverseSelectRelatedTestCase(TestCase): results=results2) StatDetails.objects.create(base_stats=advstat, comments=250) - db.reset_queries() - - def assertQueries(self, queries): - self.assertEqual(len(db.connection.queries), queries) - - def tearDown(self): - settings.DEBUG = self.old_debug - def test_basic(self): - u = User.objects.select_related("userprofile").get(username="test") - self.assertEqual(u.userprofile.state, "KS") - self.assertQueries(1) + def test(): + u = User.objects.select_related("userprofile").get(username="test") + self.assertEqual(u.userprofile.state, "KS") + self.assertNumQueries(1, test) def test_follow_next_level(self): - u = User.objects.select_related("userstat__results").get(username="test") - self.assertEqual(u.userstat.posts, 150) - self.assertEqual(u.userstat.results.results, 'first results') - self.assertQueries(1) + def test(): + u = User.objects.select_related("userstat__results").get(username="test") + self.assertEqual(u.userstat.posts, 150) + self.assertEqual(u.userstat.results.results, 'first results') + self.assertNumQueries(1, test) def test_follow_two(self): - u = User.objects.select_related("userprofile", "userstat").get(username="test") - self.assertEqual(u.userprofile.state, "KS") - self.assertEqual(u.userstat.posts, 150) - self.assertQueries(1) + def test(): + u = User.objects.select_related("userprofile", "userstat").get(username="test") + self.assertEqual(u.userprofile.state, "KS") + self.assertEqual(u.userstat.posts, 150) + self.assertNumQueries(1, test) def test_follow_two_next_level(self): - u = User.objects.select_related("userstat__results", "userstat__statdetails").get(username="test") - self.assertEqual(u.userstat.results.results, 'first results') - self.assertEqual(u.userstat.statdetails.comments, 259) - self.assertQueries(1) + def test(): + u = User.objects.select_related("userstat__results", "userstat__statdetails").get(username="test") + self.assertEqual(u.userstat.results.results, 'first results') + self.assertEqual(u.userstat.statdetails.comments, 259) + self.assertNumQueries(1, test) def test_forward_and_back(self): - stat = UserStat.objects.select_related("user__userprofile").get(user__username="test") - self.assertEqual(stat.user.userprofile.state, 'KS') - self.assertEqual(stat.user.userstat.posts, 150) - self.assertQueries(1) + def test(): + stat = UserStat.objects.select_related("user__userprofile").get(user__username="test") + self.assertEqual(stat.user.userprofile.state, 'KS') + self.assertEqual(stat.user.userstat.posts, 150) + self.assertNumQueries(1, test) def test_back_and_forward(self): - u = User.objects.select_related("userstat").get(username="test") - self.assertEqual(u.userstat.user.username, 'test') - self.assertQueries(1) + def test(): + u = User.objects.select_related("userstat").get(username="test") + self.assertEqual(u.userstat.user.username, 'test') + self.assertNumQueries(1, test) def test_not_followed_by_default(self): - u = User.objects.select_related().get(username="test") - self.assertEqual(u.userstat.posts, 150) - self.assertQueries(2) + def test(): + u = User.objects.select_related().get(username="test") + self.assertEqual(u.userstat.posts, 150) + self.assertNumQueries(2, test) def test_follow_from_child_class(self): - stat = AdvancedUserStat.objects.select_related('user', 'statdetails').get(posts=200) - self.assertEqual(stat.statdetails.comments, 250) - self.assertEqual(stat.user.username, 'bob') - self.assertQueries(1) + def test(): + stat = AdvancedUserStat.objects.select_related('user', 'statdetails').get(posts=200) + self.assertEqual(stat.statdetails.comments, 250) + self.assertEqual(stat.user.username, 'bob') + self.assertNumQueries(1, test) def test_follow_inheritance(self): - stat = UserStat.objects.select_related('user', 'advanceduserstat').get(posts=200) - self.assertEqual(stat.advanceduserstat.posts, 200) - self.assertEqual(stat.user.username, 'bob') - self.assertEqual(stat.advanceduserstat.user.username, 'bob') - self.assertQueries(1) + def test(): + stat = UserStat.objects.select_related('user', 'advanceduserstat').get(posts=200) + self.assertEqual(stat.advanceduserstat.posts, 200) + self.assertEqual(stat.user.username, 'bob') + self.assertEqual(stat.advanceduserstat.user.username, 'bob') + self.assertNumQueries(1, test) def test_nullable_relation(self): im = Image.objects.create(name="imag1") diff --git a/tests/regressiontests/test_utils/models.py b/tests/regressiontests/test_utils/models.py index e69de29bb2..4da7a07bbf 100644 --- a/tests/regressiontests/test_utils/models.py +++ b/tests/regressiontests/test_utils/models.py @@ -0,0 +1,5 @@ +from django.db import models + + +class Person(models.Model): + name = models.CharField(max_length=100) diff --git a/tests/regressiontests/test_utils/python_25.py b/tests/regressiontests/test_utils/python_25.py new file mode 100644 index 0000000000..a1e8a94d1e --- /dev/null +++ b/tests/regressiontests/test_utils/python_25.py @@ -0,0 +1,30 @@ +from __future__ import with_statement + +from django.test import TestCase + +from models import Person + + +class AssertNumQueriesTests(TestCase): + def test_simple(self): + with self.assertNumQueries(0): + pass + + with self.assertNumQueries(1): + # Guy who wrote Linux + Person.objects.create(name="Linus Torvalds") + + with self.assertNumQueries(2): + # Guy who owns the bagel place I like + Person.objects.create(name="Uncle Ricky") + self.assertEqual(Person.objects.count(), 2) + + def test_failure(self): + with self.assertRaises(AssertionError) as exc_info: + with self.assertNumQueries(2): + Person.objects.count() + self.assertEqual(str(exc_info.exception), "1 != 2 : 1 queries executed, 2 expected") + + with self.assertRaises(TypeError): + with self.assertNumQueries(4000): + raise TypeError diff --git a/tests/regressiontests/test_utils/tests.py b/tests/regressiontests/test_utils/tests.py index a2539bf8c6..4f92a402cc 100644 --- a/tests/regressiontests/test_utils/tests.py +++ b/tests/regressiontests/test_utils/tests.py @@ -1,6 +1,12 @@ -r""" +import sys + +if sys.version_info >= (2, 5): + from python_25 import AssertNumQueriesTests + + +__test__ = {"API_TEST": r""" # Some checks of the doctest output normalizer. -# Standard doctests do fairly +# Standard doctests do fairly >>> from django.utils import simplejson >>> from django.utils.xmlutils import SimplerXMLGenerator >>> from StringIO import StringIO @@ -55,7 +61,7 @@ r""" >>> produce_json() '["foo", {"whiz": 42, "bar": ["baz", null, 1.0, 2]}]' -# XML output is normalized for attribute order, so it doesn't matter +# XML output is normalized for attribute order, so it doesn't matter # which order XML element attributes are listed in output >>> produce_xml() '\nHelloGoodbye' @@ -69,4 +75,4 @@ r""" >>> produce_xml_fragment() 'Hello' -""" \ No newline at end of file +"""}