mirror of https://github.com/django/django.git
Added a context manager to capture queries while testing.
Also made some import cleanups while I was there. Refs #10399.
This commit is contained in:
parent
203c17c21e
commit
952ba5237e
|
@ -24,31 +24,30 @@ from django.core.exceptions import ValidationError, ImproperlyConfigured
|
||||||
from django.core.handlers.wsgi import WSGIHandler
|
from django.core.handlers.wsgi import WSGIHandler
|
||||||
from django.core.management import call_command
|
from django.core.management import call_command
|
||||||
from django.core.management.color import no_style
|
from django.core.management.color import no_style
|
||||||
from django.core.signals import request_started
|
|
||||||
from django.core.servers.basehttp import (WSGIRequestHandler, WSGIServer,
|
from django.core.servers.basehttp import (WSGIRequestHandler, WSGIServer,
|
||||||
WSGIServerException)
|
WSGIServerException)
|
||||||
from django.core.urlresolvers import clear_url_caches
|
from django.core.urlresolvers import clear_url_caches
|
||||||
from django.core.validators import EMPTY_VALUES
|
from django.core.validators import EMPTY_VALUES
|
||||||
from django.db import (transaction, connection, connections, DEFAULT_DB_ALIAS,
|
from django.db import connection, connections, DEFAULT_DB_ALIAS, transaction
|
||||||
reset_queries)
|
|
||||||
from django.forms.fields import CharField
|
from django.forms.fields import CharField
|
||||||
from django.http import QueryDict
|
from django.http import QueryDict
|
||||||
from django.test import _doctest as doctest
|
from django.test import _doctest as doctest
|
||||||
from django.test.client import Client
|
from django.test.client import Client
|
||||||
from django.test.html import HTMLParseError, parse_html
|
from django.test.html import HTMLParseError, parse_html
|
||||||
from django.test.signals import template_rendered
|
from django.test.signals import template_rendered
|
||||||
from django.test.utils import (override_settings, compare_xml, strip_quotes)
|
from django.test.utils import (CaptureQueriesContext, ContextList,
|
||||||
from django.test.utils import ContextList
|
override_settings, compare_xml, strip_quotes)
|
||||||
from django.utils import unittest as ut2
|
from django.utils import six, unittest as ut2
|
||||||
from django.utils.encoding import force_text
|
from django.utils.encoding import force_text
|
||||||
from django.utils import six
|
from django.utils.unittest import skipIf # Imported here for backward compatibility
|
||||||
from django.utils.unittest.util import safe_repr
|
from django.utils.unittest.util import safe_repr
|
||||||
from django.utils.unittest import skipIf
|
|
||||||
from django.views.static import serve
|
from django.views.static import serve
|
||||||
|
|
||||||
|
|
||||||
__all__ = ('DocTestRunner', 'OutputChecker', 'TestCase', 'TransactionTestCase',
|
__all__ = ('DocTestRunner', 'OutputChecker', 'TestCase', 'TransactionTestCase',
|
||||||
'SimpleTestCase', 'skipIfDBFeature', 'skipUnlessDBFeature')
|
'SimpleTestCase', 'skipIfDBFeature', 'skipUnlessDBFeature')
|
||||||
|
|
||||||
|
|
||||||
normalize_long_ints = lambda s: re.sub(r'(?<![\w])(\d+)L(?![\w])', '\\1', s)
|
normalize_long_ints = lambda s: re.sub(r'(?<![\w])(\d+)L(?![\w])', '\\1', s)
|
||||||
normalize_decimals = lambda s: re.sub(r"Decimal\('(\d+(\.\d*)?)'\)",
|
normalize_decimals = lambda s: re.sub(r"Decimal\('(\d+(\.\d*)?)'\)",
|
||||||
lambda m: "Decimal(\"%s\")" % m.groups()[0], s)
|
lambda m: "Decimal(\"%s\")" % m.groups()[0], s)
|
||||||
|
@ -168,28 +167,17 @@ class DocTestRunner(doctest.DocTestRunner):
|
||||||
transaction.rollback_unless_managed(using=conn)
|
transaction.rollback_unless_managed(using=conn)
|
||||||
|
|
||||||
|
|
||||||
class _AssertNumQueriesContext(object):
|
class _AssertNumQueriesContext(CaptureQueriesContext):
|
||||||
def __init__(self, test_case, num, connection):
|
def __init__(self, test_case, num, connection):
|
||||||
self.test_case = test_case
|
self.test_case = test_case
|
||||||
self.num = num
|
self.num = num
|
||||||
self.connection = connection
|
super(_AssertNumQueriesContext, self).__init__(connection)
|
||||||
|
|
||||||
def __enter__(self):
|
|
||||||
self.old_debug_cursor = self.connection.use_debug_cursor
|
|
||||||
self.connection.use_debug_cursor = True
|
|
||||||
self.starting_queries = len(self.connection.queries)
|
|
||||||
request_started.disconnect(reset_queries)
|
|
||||||
return self
|
|
||||||
|
|
||||||
def __exit__(self, exc_type, exc_value, traceback):
|
def __exit__(self, exc_type, exc_value, traceback):
|
||||||
self.connection.use_debug_cursor = self.old_debug_cursor
|
|
||||||
request_started.connect(reset_queries)
|
|
||||||
if exc_type is not None:
|
if exc_type is not None:
|
||||||
return
|
return
|
||||||
|
super(_AssertNumQueriesContext, self).__exit__(exc_type, exc_value, traceback)
|
||||||
final_queries = len(self.connection.queries)
|
executed = len(self)
|
||||||
executed = final_queries - self.starting_queries
|
|
||||||
|
|
||||||
self.test_case.assertEqual(
|
self.test_case.assertEqual(
|
||||||
executed, self.num, "%d queries executed, %d expected" % (
|
executed, self.num, "%d queries executed, %d expected" % (
|
||||||
executed, self.num
|
executed, self.num
|
||||||
|
@ -1051,7 +1039,6 @@ class LiveServerThread(threading.Thread):
|
||||||
http requests.
|
http requests.
|
||||||
"""
|
"""
|
||||||
if self.connections_override:
|
if self.connections_override:
|
||||||
from django.db import connections
|
|
||||||
# Override this thread's database connections with the ones
|
# Override this thread's database connections with the ones
|
||||||
# provided by the main thread.
|
# provided by the main thread.
|
||||||
for alias, conn in self.connections_override.items():
|
for alias, conn in self.connections_override.items():
|
||||||
|
|
|
@ -4,6 +4,8 @@ from xml.dom.minidom import parseString, Node
|
||||||
|
|
||||||
from django.conf import settings, UserSettingsHolder
|
from django.conf import settings, UserSettingsHolder
|
||||||
from django.core import mail
|
from django.core import mail
|
||||||
|
from django.core.signals import request_started
|
||||||
|
from django.db import reset_queries
|
||||||
from django.template import Template, loader, TemplateDoesNotExist
|
from django.template import Template, loader, TemplateDoesNotExist
|
||||||
from django.template.loaders import cached
|
from django.template.loaders import cached
|
||||||
from django.test.signals import template_rendered, setting_changed
|
from django.test.signals import template_rendered, setting_changed
|
||||||
|
@ -339,5 +341,42 @@ def strip_quotes(want, got):
|
||||||
got = got.strip()[2:-1]
|
got = got.strip()[2:-1]
|
||||||
return want, got
|
return want, got
|
||||||
|
|
||||||
|
|
||||||
def str_prefix(s):
|
def str_prefix(s):
|
||||||
return s % {'_': '' if six.PY3 else 'u'}
|
return s % {'_': '' if six.PY3 else 'u'}
|
||||||
|
|
||||||
|
|
||||||
|
class CaptureQueriesContext(object):
|
||||||
|
"""
|
||||||
|
Context manager that captures queries executed by the specified connection.
|
||||||
|
"""
|
||||||
|
def __init__(self, connection):
|
||||||
|
self.connection = connection
|
||||||
|
|
||||||
|
def __iter__(self):
|
||||||
|
return iter(self.captured_queries)
|
||||||
|
|
||||||
|
def __getitem__(self, index):
|
||||||
|
return self.captured_queries[index]
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self.captured_queries)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def captured_queries(self):
|
||||||
|
return self.connection.queries[self.initial_queries:self.final_queries]
|
||||||
|
|
||||||
|
def __enter__(self):
|
||||||
|
self.use_debug_cursor = self.connection.use_debug_cursor
|
||||||
|
self.connection.use_debug_cursor = True
|
||||||
|
self.initial_queries = len(self.connection.queries)
|
||||||
|
self.final_queries = None
|
||||||
|
request_started.disconnect(reset_queries)
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __exit__(self, exc_type, exc_value, traceback):
|
||||||
|
self.connection.use_debug_cursor = self.use_debug_cursor
|
||||||
|
request_started.connect(reset_queries)
|
||||||
|
if exc_type is not None:
|
||||||
|
return
|
||||||
|
self.final_queries = len(self.connection.queries)
|
||||||
|
|
|
@ -1,10 +1,14 @@
|
||||||
# -*- coding: utf-8 -*-
|
# -*- coding: utf-8 -*-
|
||||||
from __future__ import absolute_import, unicode_literals
|
from __future__ import absolute_import, unicode_literals
|
||||||
|
import warnings
|
||||||
|
|
||||||
|
from django.db import connection
|
||||||
from django.forms import EmailField, IntegerField
|
from django.forms import EmailField, IntegerField
|
||||||
from django.http import HttpResponse
|
from django.http import HttpResponse
|
||||||
from django.template.loader import render_to_string
|
from django.template.loader import render_to_string
|
||||||
from django.test import SimpleTestCase, TestCase, skipUnlessDBFeature
|
from django.test import SimpleTestCase, TestCase, skipUnlessDBFeature
|
||||||
|
from django.test.html import HTMLParseError, parse_html
|
||||||
|
from django.test.utils import CaptureQueriesContext
|
||||||
from django.utils import six
|
from django.utils import six
|
||||||
from django.utils.unittest import skip
|
from django.utils.unittest import skip
|
||||||
|
|
||||||
|
@ -94,6 +98,60 @@ class AssertQuerysetEqualTests(TestCase):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class CaptureQueriesContextManagerTests(TestCase):
|
||||||
|
urls = 'test_utils.urls'
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
self.person_pk = six.text_type(Person.objects.create(name='test').pk)
|
||||||
|
|
||||||
|
def test_simple(self):
|
||||||
|
with CaptureQueriesContext(connection) as captured_queries:
|
||||||
|
Person.objects.get(pk=self.person_pk)
|
||||||
|
self.assertEqual(len(captured_queries), 1)
|
||||||
|
self.assertIn(self.person_pk, captured_queries[0]['sql'])
|
||||||
|
|
||||||
|
with CaptureQueriesContext(connection) as captured_queries:
|
||||||
|
pass
|
||||||
|
self.assertEqual(0, len(captured_queries))
|
||||||
|
|
||||||
|
def test_within(self):
|
||||||
|
with CaptureQueriesContext(connection) as captured_queries:
|
||||||
|
Person.objects.get(pk=self.person_pk)
|
||||||
|
self.assertEqual(len(captured_queries), 1)
|
||||||
|
self.assertIn(self.person_pk, captured_queries[0]['sql'])
|
||||||
|
|
||||||
|
def test_nested(self):
|
||||||
|
with CaptureQueriesContext(connection) as captured_queries:
|
||||||
|
Person.objects.count()
|
||||||
|
with CaptureQueriesContext(connection) as nested_captured_queries:
|
||||||
|
Person.objects.count()
|
||||||
|
self.assertEqual(1, len(nested_captured_queries))
|
||||||
|
self.assertEqual(2, len(captured_queries))
|
||||||
|
|
||||||
|
def test_failure(self):
|
||||||
|
with self.assertRaises(TypeError):
|
||||||
|
with CaptureQueriesContext(connection):
|
||||||
|
raise TypeError
|
||||||
|
|
||||||
|
def test_with_client(self):
|
||||||
|
with CaptureQueriesContext(connection) as captured_queries:
|
||||||
|
self.client.get("/test_utils/get_person/%s/" % self.person_pk)
|
||||||
|
self.assertEqual(len(captured_queries), 1)
|
||||||
|
self.assertIn(self.person_pk, captured_queries[0]['sql'])
|
||||||
|
|
||||||
|
with CaptureQueriesContext(connection) as captured_queries:
|
||||||
|
self.client.get("/test_utils/get_person/%s/" % self.person_pk)
|
||||||
|
self.assertEqual(len(captured_queries), 1)
|
||||||
|
self.assertIn(self.person_pk, captured_queries[0]['sql'])
|
||||||
|
|
||||||
|
with CaptureQueriesContext(connection) as captured_queries:
|
||||||
|
self.client.get("/test_utils/get_person/%s/" % self.person_pk)
|
||||||
|
self.client.get("/test_utils/get_person/%s/" % self.person_pk)
|
||||||
|
self.assertEqual(len(captured_queries), 2)
|
||||||
|
self.assertIn(self.person_pk, captured_queries[0]['sql'])
|
||||||
|
self.assertIn(self.person_pk, captured_queries[1]['sql'])
|
||||||
|
|
||||||
|
|
||||||
class AssertNumQueriesContextManagerTests(TestCase):
|
class AssertNumQueriesContextManagerTests(TestCase):
|
||||||
urls = 'test_utils.urls'
|
urls = 'test_utils.urls'
|
||||||
|
|
||||||
|
@ -219,7 +277,6 @@ class SaveRestoreWarningState(TestCase):
|
||||||
# In reality this test could be satisfied by many broken implementations
|
# In reality this test could be satisfied by many broken implementations
|
||||||
# of save_warnings_state/restore_warnings_state (e.g. just
|
# of save_warnings_state/restore_warnings_state (e.g. just
|
||||||
# warnings.resetwarnings()) , but it is difficult to test more.
|
# warnings.resetwarnings()) , but it is difficult to test more.
|
||||||
import warnings
|
|
||||||
with warnings.catch_warnings():
|
with warnings.catch_warnings():
|
||||||
warnings.simplefilter("ignore", DeprecationWarning)
|
warnings.simplefilter("ignore", DeprecationWarning)
|
||||||
|
|
||||||
|
@ -245,7 +302,6 @@ class SaveRestoreWarningState(TestCase):
|
||||||
|
|
||||||
class HTMLEqualTests(TestCase):
|
class HTMLEqualTests(TestCase):
|
||||||
def test_html_parser(self):
|
def test_html_parser(self):
|
||||||
from django.test.html import parse_html
|
|
||||||
element = parse_html('<div><p>Hello</p></div>')
|
element = parse_html('<div><p>Hello</p></div>')
|
||||||
self.assertEqual(len(element.children), 1)
|
self.assertEqual(len(element.children), 1)
|
||||||
self.assertEqual(element.children[0].name, 'p')
|
self.assertEqual(element.children[0].name, 'p')
|
||||||
|
@ -259,7 +315,6 @@ class HTMLEqualTests(TestCase):
|
||||||
self.assertEqual(dom[0], 'foo')
|
self.assertEqual(dom[0], 'foo')
|
||||||
|
|
||||||
def test_parse_html_in_script(self):
|
def test_parse_html_in_script(self):
|
||||||
from django.test.html import parse_html
|
|
||||||
parse_html('<script>var a = "<p" + ">";</script>');
|
parse_html('<script>var a = "<p" + ">";</script>');
|
||||||
parse_html('''
|
parse_html('''
|
||||||
<script>
|
<script>
|
||||||
|
@ -275,8 +330,6 @@ class HTMLEqualTests(TestCase):
|
||||||
self.assertEqual(dom.children[0], "<p>foo</p> '</scr'+'ipt>' <span>bar</span>")
|
self.assertEqual(dom.children[0], "<p>foo</p> '</scr'+'ipt>' <span>bar</span>")
|
||||||
|
|
||||||
def test_self_closing_tags(self):
|
def test_self_closing_tags(self):
|
||||||
from django.test.html import parse_html
|
|
||||||
|
|
||||||
self_closing_tags = ('br' , 'hr', 'input', 'img', 'meta', 'spacer',
|
self_closing_tags = ('br' , 'hr', 'input', 'img', 'meta', 'spacer',
|
||||||
'link', 'frame', 'base', 'col')
|
'link', 'frame', 'base', 'col')
|
||||||
for tag in self_closing_tags:
|
for tag in self_closing_tags:
|
||||||
|
@ -400,7 +453,6 @@ class HTMLEqualTests(TestCase):
|
||||||
</html>""")
|
</html>""")
|
||||||
|
|
||||||
def test_html_contain(self):
|
def test_html_contain(self):
|
||||||
from django.test.html import parse_html
|
|
||||||
# equal html contains each other
|
# equal html contains each other
|
||||||
dom1 = parse_html('<p>foo')
|
dom1 = parse_html('<p>foo')
|
||||||
dom2 = parse_html('<p>foo</p>')
|
dom2 = parse_html('<p>foo</p>')
|
||||||
|
@ -424,7 +476,6 @@ class HTMLEqualTests(TestCase):
|
||||||
self.assertTrue(dom1 in dom2)
|
self.assertTrue(dom1 in dom2)
|
||||||
|
|
||||||
def test_count(self):
|
def test_count(self):
|
||||||
from django.test.html import parse_html
|
|
||||||
# equal html contains each other one time
|
# equal html contains each other one time
|
||||||
dom1 = parse_html('<p>foo')
|
dom1 = parse_html('<p>foo')
|
||||||
dom2 = parse_html('<p>foo</p>')
|
dom2 = parse_html('<p>foo</p>')
|
||||||
|
@ -459,7 +510,6 @@ class HTMLEqualTests(TestCase):
|
||||||
self.assertEqual(dom2.count(dom1), 0)
|
self.assertEqual(dom2.count(dom1), 0)
|
||||||
|
|
||||||
def test_parsing_errors(self):
|
def test_parsing_errors(self):
|
||||||
from django.test.html import HTMLParseError, parse_html
|
|
||||||
with self.assertRaises(AssertionError):
|
with self.assertRaises(AssertionError):
|
||||||
self.assertHTMLEqual('<p>', '')
|
self.assertHTMLEqual('<p>', '')
|
||||||
with self.assertRaises(AssertionError):
|
with self.assertRaises(AssertionError):
|
||||||
|
@ -488,7 +538,6 @@ class HTMLEqualTests(TestCase):
|
||||||
self.assertContains(response, '<p "whats" that>')
|
self.assertContains(response, '<p "whats" that>')
|
||||||
|
|
||||||
def test_unicode_handling(self):
|
def test_unicode_handling(self):
|
||||||
from django.http import HttpResponse
|
|
||||||
response = HttpResponse('<p class="help">Some help text for the title (with unicode ŠĐĆŽćžšđ)</p>')
|
response = HttpResponse('<p class="help">Some help text for the title (with unicode ŠĐĆŽćžšđ)</p>')
|
||||||
self.assertContains(response, '<p class="help">Some help text for the title (with unicode ŠĐĆŽćžšđ)</p>', html=True)
|
self.assertContains(response, '<p class="help">Some help text for the title (with unicode ŠĐĆŽćžšđ)</p>', html=True)
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue