Fixed #14774 -- the test client and assertNumQueries didn't work well together. Thanks to Jonas Obrist for the initial patch.

git-svn-id: http://code.djangoproject.com/svn/django/trunk@15251 bcc190cf-cafb-0310-a4f2-bffc1f526a37
This commit is contained in:
Alex Gaynor 2011-01-20 04:47:47 +00:00
parent 53dac996ef
commit 8308ad4f05
6 changed files with 70 additions and 12 deletions

View File

@ -6,8 +6,10 @@ from xml.dom.minidom import parseString, Node
from django.conf import settings
from django.core import mail
from django.core.management import call_command
from django.core.signals import request_started
from django.core.urlresolvers import clear_url_caches
from django.db import transaction, connection, connections, DEFAULT_DB_ALIAS
from django.db import (transaction, connection, connections, DEFAULT_DB_ALIAS,
reset_queries)
from django.http import QueryDict
from django.test import _doctest as doctest
from django.test.client import Client
@ -220,10 +222,12 @@ class _AssertNumQueriesContext(object):
self.old_debug_cursor = self.connection.use_debug_cursor
self.connection.use_debug_cursor = True
self.starting_queries = len(self.connection.queries)
request_started.disconnect(reset_queries)
return self
def __exit__(self, exc_type, exc_value, traceback):
self.connection.use_debug_cursor = self.old_debug_cursor
request_started.connect(reset_queries)
if exc_type is not None:
return

View File

@ -2,20 +2,13 @@ import sys
from django.test import TestCase, skipUnlessDBFeature, skipIfDBFeature
from models import Person
if sys.version_info >= (2, 5):
from tests_25 import AssertNumQueriesTests
from tests_25 import AssertNumQueriesContextManagerTests
class SkippingTestCase(TestCase):
def test_assert_num_queries(self):
def test_func():
raise ValueError
self.assertRaises(ValueError,
self.assertNumQueries, 2, test_func
)
def test_skip_unless_db_feature(self):
"A test that might be skipped is actually called."
# Total hack, but it works, just want an attribute that's always true.
@ -26,8 +19,37 @@ class SkippingTestCase(TestCase):
self.assertRaises(ValueError, test_func)
class SaveRestoreWarningState(TestCase):
class AssertNumQueriesTests(TestCase):
def test_assert_num_queries(self):
def test_func():
raise ValueError
self.assertRaises(ValueError,
self.assertNumQueries, 2, test_func
)
def test_assert_num_queries_with_client(self):
person = Person.objects.create(name='test')
self.assertNumQueries(
1,
self.client.get,
"/test_utils/get_person/%s/" % person.pk
)
self.assertNumQueries(
1,
self.client.get,
"/test_utils/get_person/%s/" % person.pk
)
def test_func():
self.client.get("/test_utils/get_person/%s/" % person.pk)
self.client.get("/test_utils/get_person/%s/" % person.pk)
self.assertNumQueries(2, test_func)
class SaveRestoreWarningState(TestCase):
def test_save_restore_warnings_state(self):
"""
Ensure save_warnings_state/restore_warnings_state work correctly.

View File

@ -5,7 +5,7 @@ from django.test import TestCase
from models import Person
class AssertNumQueriesTests(TestCase):
class AssertNumQueriesContextManagerTests(TestCase):
def test_simple(self):
with self.assertNumQueries(0):
pass
@ -26,3 +26,16 @@ class AssertNumQueriesTests(TestCase):
with self.assertRaises(TypeError):
with self.assertNumQueries(4000):
raise TypeError
def test_with_client(self):
person = Person.objects.create(name="test")
with self.assertNumQueries(1):
self.client.get("/test_utils/get_person/%s/" % person.pk)
with self.assertNumQueries(1):
self.client.get("/test_utils/get_person/%s/" % person.pk)
with self.assertNumQueries(2):
self.client.get("/test_utils/get_person/%s/" % person.pk)
self.client.get("/test_utils/get_person/%s/" % person.pk)

View File

@ -0,0 +1,8 @@
from django.conf.urls.defaults import patterns
import views
urlpatterns = patterns('',
(r'^get_person/(\d+)/$', views.get_person),
)

View File

@ -0,0 +1,7 @@
from django.http import HttpResponse
from django.shortcuts import get_object_or_404
from models import Person
def get_person(request, pk):
person = get_object_or_404(Person, pk=pk)
return HttpResponse(person.name)

View File

@ -1,5 +1,6 @@
from django.conf.urls.defaults import *
urlpatterns = patterns('',
# test_client modeltest urls
(r'^test_client/', include('modeltests.test_client.urls')),
@ -41,4 +42,7 @@ urlpatterns = patterns('',
# special headers views
(r'special_headers/', include('regressiontests.special_headers.urls')),
# test util views
(r'test_utils/', include('regressiontests.test_utils.urls')),
)