mirror of https://github.com/django/django.git
Fixed #5416 -- Added TestCase.assertNumQueries, which tests that a given function executes the correct number of queries.
git-svn-id: http://code.djangoproject.com/svn/django/trunk@14183 bcc190cf-cafb-0310-a4f2-bffc1f526a37
This commit is contained in:
parent
ceef628c19
commit
5506653b77
|
@ -21,6 +21,7 @@ class BaseDatabaseWrapper(local):
|
||||||
self.settings_dict = settings_dict
|
self.settings_dict = settings_dict
|
||||||
self.alias = alias
|
self.alias = alias
|
||||||
self.vendor = 'unknown'
|
self.vendor = 'unknown'
|
||||||
|
self.use_debug_cursor = None
|
||||||
|
|
||||||
def __eq__(self, other):
|
def __eq__(self, other):
|
||||||
return self.settings_dict == other.settings_dict
|
return self.settings_dict == other.settings_dict
|
||||||
|
@ -74,7 +75,8 @@ class BaseDatabaseWrapper(local):
|
||||||
def cursor(self):
|
def cursor(self):
|
||||||
from django.conf import settings
|
from django.conf import settings
|
||||||
cursor = self._cursor()
|
cursor = self._cursor()
|
||||||
if settings.DEBUG:
|
if (self.use_debug_cursor or
|
||||||
|
(self.use_debug_cursor is None and settings.DEBUG)):
|
||||||
return self.make_debug_cursor(cursor)
|
return self.make_debug_cursor(cursor)
|
||||||
return cursor
|
return cursor
|
||||||
|
|
||||||
|
|
|
@ -1,4 +1,5 @@
|
||||||
import re
|
import re
|
||||||
|
import sys
|
||||||
from urlparse import urlsplit, urlunsplit
|
from urlparse import urlsplit, urlunsplit
|
||||||
from xml.dom.minidom import parseString, Node
|
from xml.dom.minidom import parseString, Node
|
||||||
|
|
||||||
|
@ -205,6 +206,33 @@ class DocTestRunner(doctest.DocTestRunner):
|
||||||
for conn in connections:
|
for conn in connections:
|
||||||
transaction.rollback_unless_managed(using=conn)
|
transaction.rollback_unless_managed(using=conn)
|
||||||
|
|
||||||
|
class _AssertNumQueriesContext(object):
|
||||||
|
def __init__(self, test_case, num, connection):
|
||||||
|
self.test_case = test_case
|
||||||
|
self.num = num
|
||||||
|
self.connection = connection
|
||||||
|
|
||||||
|
def __enter__(self):
|
||||||
|
self.old_debug_cursor = self.connection.use_debug_cursor
|
||||||
|
self.connection.use_debug_cursor = True
|
||||||
|
self.starting_queries = len(self.connection.queries)
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __exit__(self, exc_type, exc_value, traceback):
|
||||||
|
if exc_type is not None:
|
||||||
|
return
|
||||||
|
|
||||||
|
self.connection.use_debug_cursor = self.old_debug_cursor
|
||||||
|
final_queries = len(self.connection.queries)
|
||||||
|
executed = final_queries - self.starting_queries
|
||||||
|
|
||||||
|
self.test_case.assertEqual(
|
||||||
|
executed, self.num, "%d queries executed, %d expected" % (
|
||||||
|
executed, self.num
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class TransactionTestCase(unittest.TestCase):
|
class TransactionTestCase(unittest.TestCase):
|
||||||
# The class we'll use for the test client self.client.
|
# The class we'll use for the test client self.client.
|
||||||
# Can be overridden in derived classes.
|
# Can be overridden in derived classes.
|
||||||
|
@ -469,6 +497,22 @@ class TransactionTestCase(unittest.TestCase):
|
||||||
def assertQuerysetEqual(self, qs, values, transform=repr):
|
def assertQuerysetEqual(self, qs, values, transform=repr):
|
||||||
return self.assertEqual(map(transform, qs), values)
|
return self.assertEqual(map(transform, qs), values)
|
||||||
|
|
||||||
|
def assertNumQueries(self, num, func=None, *args, **kwargs):
|
||||||
|
using = kwargs.pop("using", DEFAULT_DB_ALIAS)
|
||||||
|
connection = connections[using]
|
||||||
|
|
||||||
|
context = _AssertNumQueriesContext(self, num, connection)
|
||||||
|
if func is None:
|
||||||
|
return context
|
||||||
|
|
||||||
|
# Basically emulate the `with` statement here.
|
||||||
|
|
||||||
|
context.__enter__()
|
||||||
|
try:
|
||||||
|
func(*args, **kwargs)
|
||||||
|
finally:
|
||||||
|
context.__exit__(*sys.exc_info())
|
||||||
|
|
||||||
def connections_support_transactions():
|
def connections_support_transactions():
|
||||||
"""
|
"""
|
||||||
Returns True if all connections support transactions. This is messy
|
Returns True if all connections support transactions. This is messy
|
||||||
|
|
|
@ -1372,6 +1372,32 @@ cause of an failure in your test suite.
|
||||||
implicit ordering, you will need to apply a ``order_by()`` clause to your
|
implicit ordering, you will need to apply a ``order_by()`` clause to your
|
||||||
queryset to ensure that the test will pass reliably.
|
queryset to ensure that the test will pass reliably.
|
||||||
|
|
||||||
|
.. method:: TestCase.assertNumQueries(num, func, *args, **kwargs):
|
||||||
|
|
||||||
|
.. versionadded:: 1.3
|
||||||
|
|
||||||
|
Asserts that when ``func`` is called with ``*args`` and ``**kwargs`` that
|
||||||
|
``num`` database queries are executed.
|
||||||
|
|
||||||
|
If a ``"using"`` key is present in ``kwargs`` it is used as the database
|
||||||
|
alias for which to check the number of queries. If you wish to call a
|
||||||
|
function with a ``using`` parameter you can do it by wrapping the call with
|
||||||
|
a ``lambda`` to add an extra parameter::
|
||||||
|
|
||||||
|
self.assertNumQueries(7, lambda: my_function(using=7))
|
||||||
|
|
||||||
|
If you're using Python 2.5 or greater you can also use this as a context
|
||||||
|
manager::
|
||||||
|
|
||||||
|
# This is necessary in Python 2.5 to enable the with statement, in 2.6
|
||||||
|
# and up it is no longer necessary.
|
||||||
|
from __future__ import with_statement
|
||||||
|
|
||||||
|
with self.assertNumQueries(2):
|
||||||
|
Person.objects.create(name="Aaron")
|
||||||
|
Person.objects.create(name="Daniel")
|
||||||
|
|
||||||
|
|
||||||
.. _topics-testing-email:
|
.. _topics-testing-email:
|
||||||
|
|
||||||
E-mail services
|
E-mail services
|
||||||
|
|
|
@ -1,6 +1,4 @@
|
||||||
from django.test import TestCase
|
from django.test import TestCase
|
||||||
from django.conf import settings
|
|
||||||
from django import db
|
|
||||||
|
|
||||||
from models import Domain, Kingdom, Phylum, Klass, Order, Family, Genus, Species
|
from models import Domain, Kingdom, Phylum, Klass, Order, Family, Genus, Species
|
||||||
|
|
||||||
|
@ -36,73 +34,73 @@ class SelectRelatedTests(TestCase):
|
||||||
# queries so we'll set it to True here and reset it at the end of the
|
# queries so we'll set it to True here and reset it at the end of the
|
||||||
# test case.
|
# test case.
|
||||||
self.create_base_data()
|
self.create_base_data()
|
||||||
settings.DEBUG = True
|
|
||||||
db.reset_queries()
|
|
||||||
|
|
||||||
def tearDown(self):
|
|
||||||
settings.DEBUG = False
|
|
||||||
|
|
||||||
def test_access_fks_without_select_related(self):
|
def test_access_fks_without_select_related(self):
|
||||||
"""
|
"""
|
||||||
Normally, accessing FKs doesn't fill in related objects
|
Normally, accessing FKs doesn't fill in related objects
|
||||||
"""
|
"""
|
||||||
fly = Species.objects.get(name="melanogaster")
|
def test():
|
||||||
domain = fly.genus.family.order.klass.phylum.kingdom.domain
|
fly = Species.objects.get(name="melanogaster")
|
||||||
self.assertEqual(domain.name, 'Eukaryota')
|
domain = fly.genus.family.order.klass.phylum.kingdom.domain
|
||||||
self.assertEqual(len(db.connection.queries), 8)
|
self.assertEqual(domain.name, 'Eukaryota')
|
||||||
|
self.assertNumQueries(8, test)
|
||||||
|
|
||||||
def test_access_fks_with_select_related(self):
|
def test_access_fks_with_select_related(self):
|
||||||
"""
|
"""
|
||||||
A select_related() call will fill in those related objects without any
|
A select_related() call will fill in those related objects without any
|
||||||
extra queries
|
extra queries
|
||||||
"""
|
"""
|
||||||
person = Species.objects.select_related(depth=10).get(name="sapiens")
|
def test():
|
||||||
domain = person.genus.family.order.klass.phylum.kingdom.domain
|
person = Species.objects.select_related(depth=10).get(name="sapiens")
|
||||||
self.assertEqual(domain.name, 'Eukaryota')
|
domain = person.genus.family.order.klass.phylum.kingdom.domain
|
||||||
self.assertEqual(len(db.connection.queries), 1)
|
self.assertEqual(domain.name, 'Eukaryota')
|
||||||
|
self.assertNumQueries(1, test)
|
||||||
|
|
||||||
def test_list_without_select_related(self):
|
def test_list_without_select_related(self):
|
||||||
"""
|
"""
|
||||||
select_related() also of course applies to entire lists, not just
|
select_related() also of course applies to entire lists, not just
|
||||||
items. This test verifies the expected behavior without select_related.
|
items. This test verifies the expected behavior without select_related.
|
||||||
"""
|
"""
|
||||||
world = Species.objects.all()
|
def test():
|
||||||
families = [o.genus.family.name for o in world]
|
world = Species.objects.all()
|
||||||
self.assertEqual(families, [
|
families = [o.genus.family.name for o in world]
|
||||||
'Drosophilidae',
|
self.assertEqual(families, [
|
||||||
'Hominidae',
|
'Drosophilidae',
|
||||||
'Fabaceae',
|
'Hominidae',
|
||||||
'Amanitacae',
|
'Fabaceae',
|
||||||
])
|
'Amanitacae',
|
||||||
self.assertEqual(len(db.connection.queries), 9)
|
])
|
||||||
|
self.assertNumQueries(9, test)
|
||||||
|
|
||||||
def test_list_with_select_related(self):
|
def test_list_with_select_related(self):
|
||||||
"""
|
"""
|
||||||
select_related() also of course applies to entire lists, not just
|
select_related() also of course applies to entire lists, not just
|
||||||
items. This test verifies the expected behavior with select_related.
|
items. This test verifies the expected behavior with select_related.
|
||||||
"""
|
"""
|
||||||
world = Species.objects.all().select_related()
|
def test():
|
||||||
families = [o.genus.family.name for o in world]
|
world = Species.objects.all().select_related()
|
||||||
self.assertEqual(families, [
|
families = [o.genus.family.name for o in world]
|
||||||
'Drosophilidae',
|
self.assertEqual(families, [
|
||||||
'Hominidae',
|
'Drosophilidae',
|
||||||
'Fabaceae',
|
'Hominidae',
|
||||||
'Amanitacae',
|
'Fabaceae',
|
||||||
])
|
'Amanitacae',
|
||||||
self.assertEqual(len(db.connection.queries), 1)
|
])
|
||||||
|
self.assertNumQueries(1, test)
|
||||||
|
|
||||||
def test_depth(self, depth=1, expected=7):
|
def test_depth(self, depth=1, expected=7):
|
||||||
"""
|
"""
|
||||||
The "depth" argument to select_related() will stop the descent at a
|
The "depth" argument to select_related() will stop the descent at a
|
||||||
particular level.
|
particular level.
|
||||||
"""
|
"""
|
||||||
pea = Species.objects.select_related(depth=depth).get(name="sativum")
|
def test():
|
||||||
self.assertEqual(
|
pea = Species.objects.select_related(depth=depth).get(name="sativum")
|
||||||
pea.genus.family.order.klass.phylum.kingdom.domain.name,
|
self.assertEqual(
|
||||||
'Eukaryota'
|
pea.genus.family.order.klass.phylum.kingdom.domain.name,
|
||||||
)
|
'Eukaryota'
|
||||||
|
)
|
||||||
# Notice: one fewer queries than above because of depth=1
|
# Notice: one fewer queries than above because of depth=1
|
||||||
self.assertEqual(len(db.connection.queries), expected)
|
self.assertNumQueries(expected, test)
|
||||||
|
|
||||||
def test_larger_depth(self):
|
def test_larger_depth(self):
|
||||||
"""
|
"""
|
||||||
|
@ -116,11 +114,12 @@ class SelectRelatedTests(TestCase):
|
||||||
The "depth" argument to select_related() will stop the descent at a
|
The "depth" argument to select_related() will stop the descent at a
|
||||||
particular level. This can be used on lists as well.
|
particular level. This can be used on lists as well.
|
||||||
"""
|
"""
|
||||||
world = Species.objects.all().select_related(depth=2)
|
def test():
|
||||||
orders = [o.genus.family.order.name for o in world]
|
world = Species.objects.all().select_related(depth=2)
|
||||||
self.assertEqual(orders,
|
orders = [o.genus.family.order.name for o in world]
|
||||||
['Diptera', 'Primates', 'Fabales', 'Agaricales'])
|
self.assertEqual(orders,
|
||||||
self.assertEqual(len(db.connection.queries), 5)
|
['Diptera', 'Primates', 'Fabales', 'Agaricales'])
|
||||||
|
self.assertNumQueries(5, test)
|
||||||
|
|
||||||
def test_select_related_with_extra(self):
|
def test_select_related_with_extra(self):
|
||||||
s = Species.objects.all().select_related(depth=1)\
|
s = Species.objects.all().select_related(depth=1)\
|
||||||
|
@ -136,28 +135,31 @@ class SelectRelatedTests(TestCase):
|
||||||
In this case, we explicitly say to select the 'genus' and
|
In this case, we explicitly say to select the 'genus' and
|
||||||
'genus.family' models, leading to the same number of queries as before.
|
'genus.family' models, leading to the same number of queries as before.
|
||||||
"""
|
"""
|
||||||
world = Species.objects.select_related('genus__family')
|
def test():
|
||||||
families = [o.genus.family.name for o in world]
|
world = Species.objects.select_related('genus__family')
|
||||||
self.assertEqual(families,
|
families = [o.genus.family.name for o in world]
|
||||||
['Drosophilidae', 'Hominidae', 'Fabaceae', 'Amanitacae'])
|
self.assertEqual(families,
|
||||||
self.assertEqual(len(db.connection.queries), 1)
|
['Drosophilidae', 'Hominidae', 'Fabaceae', 'Amanitacae'])
|
||||||
|
self.assertNumQueries(1, test)
|
||||||
|
|
||||||
def test_more_certain_fields(self):
|
def test_more_certain_fields(self):
|
||||||
"""
|
"""
|
||||||
In this case, we explicitly say to select the 'genus' and
|
In this case, we explicitly say to select the 'genus' and
|
||||||
'genus.family' models, leading to the same number of queries as before.
|
'genus.family' models, leading to the same number of queries as before.
|
||||||
"""
|
"""
|
||||||
world = Species.objects.filter(genus__name='Amanita')\
|
def test():
|
||||||
.select_related('genus__family')
|
world = Species.objects.filter(genus__name='Amanita')\
|
||||||
orders = [o.genus.family.order.name for o in world]
|
.select_related('genus__family')
|
||||||
self.assertEqual(orders, [u'Agaricales'])
|
orders = [o.genus.family.order.name for o in world]
|
||||||
self.assertEqual(len(db.connection.queries), 2)
|
self.assertEqual(orders, [u'Agaricales'])
|
||||||
|
self.assertNumQueries(2, test)
|
||||||
|
|
||||||
def test_field_traversal(self):
|
def test_field_traversal(self):
|
||||||
s = Species.objects.all().select_related('genus__family__order'
|
def test():
|
||||||
).order_by('id')[0:1].get().genus.family.order.name
|
s = Species.objects.all().select_related('genus__family__order'
|
||||||
self.assertEqual(s, u'Diptera')
|
).order_by('id')[0:1].get().genus.family.order.name
|
||||||
self.assertEqual(len(db.connection.queries), 1)
|
self.assertEqual(s, u'Diptera')
|
||||||
|
self.assertNumQueries(1, test)
|
||||||
|
|
||||||
def test_depth_fields_fails(self):
|
def test_depth_fields_fails(self):
|
||||||
self.assertRaises(TypeError,
|
self.assertRaises(TypeError,
|
||||||
|
|
|
@ -2,9 +2,11 @@ import datetime
|
||||||
|
|
||||||
from django.conf import settings
|
from django.conf import settings
|
||||||
from django.db import connection
|
from django.db import connection
|
||||||
|
from django.test import TestCase
|
||||||
from django.utils import unittest
|
from django.utils import unittest
|
||||||
|
|
||||||
from models import CustomPKModel, UniqueTogetherModel, UniqueFieldsModel, UniqueForDateModel, ModelToValidate
|
from models import (CustomPKModel, UniqueTogetherModel, UniqueFieldsModel,
|
||||||
|
UniqueForDateModel, ModelToValidate)
|
||||||
|
|
||||||
|
|
||||||
class GetUniqueCheckTests(unittest.TestCase):
|
class GetUniqueCheckTests(unittest.TestCase):
|
||||||
|
@ -51,37 +53,26 @@ class GetUniqueCheckTests(unittest.TestCase):
|
||||||
), m._get_unique_checks(exclude='start_date')
|
), m._get_unique_checks(exclude='start_date')
|
||||||
)
|
)
|
||||||
|
|
||||||
class PerformUniqueChecksTest(unittest.TestCase):
|
class PerformUniqueChecksTest(TestCase):
|
||||||
def setUp(self):
|
|
||||||
# Set debug to True to gain access to connection.queries.
|
|
||||||
self._old_debug, settings.DEBUG = settings.DEBUG, True
|
|
||||||
super(PerformUniqueChecksTest, self).setUp()
|
|
||||||
|
|
||||||
def tearDown(self):
|
|
||||||
# Restore old debug value.
|
|
||||||
settings.DEBUG = self._old_debug
|
|
||||||
super(PerformUniqueChecksTest, self).tearDown()
|
|
||||||
|
|
||||||
def test_primary_key_unique_check_not_performed_when_adding_and_pk_not_specified(self):
|
def test_primary_key_unique_check_not_performed_when_adding_and_pk_not_specified(self):
|
||||||
# Regression test for #12560
|
# Regression test for #12560
|
||||||
query_count = len(connection.queries)
|
def test():
|
||||||
mtv = ModelToValidate(number=10, name='Some Name')
|
mtv = ModelToValidate(number=10, name='Some Name')
|
||||||
setattr(mtv, '_adding', True)
|
setattr(mtv, '_adding', True)
|
||||||
mtv.full_clean()
|
mtv.full_clean()
|
||||||
self.assertEqual(query_count, len(connection.queries))
|
self.assertNumQueries(0, test)
|
||||||
|
|
||||||
def test_primary_key_unique_check_performed_when_adding_and_pk_specified(self):
|
def test_primary_key_unique_check_performed_when_adding_and_pk_specified(self):
|
||||||
# Regression test for #12560
|
# Regression test for #12560
|
||||||
query_count = len(connection.queries)
|
def test():
|
||||||
mtv = ModelToValidate(number=10, name='Some Name', id=123)
|
mtv = ModelToValidate(number=10, name='Some Name', id=123)
|
||||||
setattr(mtv, '_adding', True)
|
setattr(mtv, '_adding', True)
|
||||||
mtv.full_clean()
|
mtv.full_clean()
|
||||||
self.assertEqual(query_count + 1, len(connection.queries))
|
self.assertNumQueries(1, test)
|
||||||
|
|
||||||
def test_primary_key_unique_check_not_performed_when_not_adding(self):
|
def test_primary_key_unique_check_not_performed_when_not_adding(self):
|
||||||
# Regression test for #12132
|
# Regression test for #12132
|
||||||
query_count= len(connection.queries)
|
def test():
|
||||||
mtv = ModelToValidate(number=10, name='Some Name')
|
mtv = ModelToValidate(number=10, name='Some Name')
|
||||||
mtv.full_clean()
|
mtv.full_clean()
|
||||||
self.assertEqual(query_count, len(connection.queries))
|
self.assertNumQueries(0, test)
|
||||||
|
|
||||||
|
|
|
@ -6,7 +6,8 @@ from modeltests.validation.models import Author, Article, ModelToValidate
|
||||||
|
|
||||||
# Import other tests for this package.
|
# Import other tests for this package.
|
||||||
from modeltests.validation.validators import TestModelsWithValidators
|
from modeltests.validation.validators import TestModelsWithValidators
|
||||||
from modeltests.validation.test_unique import GetUniqueCheckTests, PerformUniqueChecksTest
|
from modeltests.validation.test_unique import (GetUniqueCheckTests,
|
||||||
|
PerformUniqueChecksTest)
|
||||||
from modeltests.validation.test_custom_messages import CustomMessagesTest
|
from modeltests.validation.test_custom_messages import CustomMessagesTest
|
||||||
|
|
||||||
|
|
||||||
|
@ -111,4 +112,3 @@ class ModelFormsTests(TestCase):
|
||||||
article = Article(author_id=self.author.id)
|
article = Article(author_id=self.author.id)
|
||||||
form = ArticleForm(data, instance=article)
|
form = ArticleForm(data, instance=article)
|
||||||
self.assertEqual(form.errors.keys(), ['pub_date'])
|
self.assertEqual(form.errors.keys(), ['pub_date'])
|
||||||
|
|
||||||
|
|
|
@ -11,17 +11,6 @@ from models import ResolveThis, Item, RelatedItem, Child, Leaf
|
||||||
|
|
||||||
|
|
||||||
class DeferRegressionTest(TestCase):
|
class DeferRegressionTest(TestCase):
|
||||||
def assert_num_queries(self, n, func, *args, **kwargs):
|
|
||||||
old_DEBUG = settings.DEBUG
|
|
||||||
settings.DEBUG = True
|
|
||||||
starting_queries = len(connection.queries)
|
|
||||||
try:
|
|
||||||
func(*args, **kwargs)
|
|
||||||
finally:
|
|
||||||
settings.DEBUG = old_DEBUG
|
|
||||||
self.assertEqual(starting_queries + n, len(connection.queries))
|
|
||||||
|
|
||||||
|
|
||||||
def test_basic(self):
|
def test_basic(self):
|
||||||
# Deferred fields should really be deferred and not accidentally use
|
# Deferred fields should really be deferred and not accidentally use
|
||||||
# the field's default value just because they aren't passed to __init__
|
# the field's default value just because they aren't passed to __init__
|
||||||
|
@ -33,19 +22,19 @@ class DeferRegressionTest(TestCase):
|
||||||
def test():
|
def test():
|
||||||
self.assertEqual(obj.name, "first")
|
self.assertEqual(obj.name, "first")
|
||||||
self.assertEqual(obj.other_value, 0)
|
self.assertEqual(obj.other_value, 0)
|
||||||
self.assert_num_queries(0, test)
|
self.assertNumQueries(0, test)
|
||||||
|
|
||||||
def test():
|
def test():
|
||||||
self.assertEqual(obj.value, 42)
|
self.assertEqual(obj.value, 42)
|
||||||
self.assert_num_queries(1, test)
|
self.assertNumQueries(1, test)
|
||||||
|
|
||||||
def test():
|
def test():
|
||||||
self.assertEqual(obj.text, "xyzzy")
|
self.assertEqual(obj.text, "xyzzy")
|
||||||
self.assert_num_queries(1, test)
|
self.assertNumQueries(1, test)
|
||||||
|
|
||||||
def test():
|
def test():
|
||||||
self.assertEqual(obj.text, "xyzzy")
|
self.assertEqual(obj.text, "xyzzy")
|
||||||
self.assert_num_queries(0, test)
|
self.assertNumQueries(0, test)
|
||||||
|
|
||||||
# Regression test for #10695. Make sure different instances don't
|
# Regression test for #10695. Make sure different instances don't
|
||||||
# inadvertently share data in the deferred descriptor objects.
|
# inadvertently share data in the deferred descriptor objects.
|
||||||
|
|
|
@ -1,10 +1,9 @@
|
||||||
# -*- coding: utf-8 -*-
|
# -*- coding: utf-8 -*-
|
||||||
import datetime
|
import datetime
|
||||||
import tempfile
|
|
||||||
import shutil
|
import shutil
|
||||||
|
import tempfile
|
||||||
|
|
||||||
from django.db import models, connection
|
from django.db import models
|
||||||
from django.conf import settings
|
|
||||||
# Can't import as "forms" due to implementation details in the test suite (the
|
# Can't import as "forms" due to implementation details in the test suite (the
|
||||||
# current file is called "forms" and is already imported).
|
# current file is called "forms" and is already imported).
|
||||||
from django import forms as django_forms
|
from django import forms as django_forms
|
||||||
|
@ -77,19 +76,13 @@ class TestTicket12510(TestCase):
|
||||||
''' It is not necessary to generate choices for ModelChoiceField (regression test for #12510). '''
|
''' It is not necessary to generate choices for ModelChoiceField (regression test for #12510). '''
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.groups = [Group.objects.create(name=name) for name in 'abc']
|
self.groups = [Group.objects.create(name=name) for name in 'abc']
|
||||||
self.old_debug = settings.DEBUG
|
|
||||||
# turn debug on to get access to connection.queries
|
|
||||||
settings.DEBUG = True
|
|
||||||
|
|
||||||
def tearDown(self):
|
|
||||||
settings.DEBUG = self.old_debug
|
|
||||||
|
|
||||||
def test_choices_not_fetched_when_not_rendering(self):
|
def test_choices_not_fetched_when_not_rendering(self):
|
||||||
initial_queries = len(connection.queries)
|
def test():
|
||||||
field = django_forms.ModelChoiceField(Group.objects.order_by('-name'))
|
field = django_forms.ModelChoiceField(Group.objects.order_by('-name'))
|
||||||
self.assertEqual('a', field.clean(self.groups[0].pk).name)
|
self.assertEqual('a', field.clean(self.groups[0].pk).name)
|
||||||
# only one query is required to pull the model from DB
|
# only one query is required to pull the model from DB
|
||||||
self.assertEqual(initial_queries+1, len(connection.queries))
|
self.assertNumQueries(1, test)
|
||||||
|
|
||||||
class ModelFormCallableModelDefault(TestCase):
|
class ModelFormCallableModelDefault(TestCase):
|
||||||
def test_no_empty_option(self):
|
def test_no_empty_option(self):
|
||||||
|
|
|
@ -1,10 +1,8 @@
|
||||||
import unittest
|
import unittest
|
||||||
from datetime import date
|
from datetime import date
|
||||||
|
|
||||||
from django import db
|
|
||||||
from django import forms
|
from django import forms
|
||||||
from django.forms.models import modelform_factory, ModelChoiceField
|
from django.forms.models import modelform_factory, ModelChoiceField
|
||||||
from django.conf import settings
|
|
||||||
from django.test import TestCase
|
from django.test import TestCase
|
||||||
from django.core.exceptions import FieldError, ValidationError
|
from django.core.exceptions import FieldError, ValidationError
|
||||||
from django.core.files.uploadedfile import SimpleUploadedFile
|
from django.core.files.uploadedfile import SimpleUploadedFile
|
||||||
|
@ -14,14 +12,6 @@ from models import Person, RealPerson, Triple, FilePathModel, Article, \
|
||||||
|
|
||||||
|
|
||||||
class ModelMultipleChoiceFieldTests(TestCase):
|
class ModelMultipleChoiceFieldTests(TestCase):
|
||||||
|
|
||||||
def setUp(self):
|
|
||||||
self.old_debug = settings.DEBUG
|
|
||||||
settings.DEBUG = True
|
|
||||||
|
|
||||||
def tearDown(self):
|
|
||||||
settings.DEBUG = self.old_debug
|
|
||||||
|
|
||||||
def test_model_multiple_choice_number_of_queries(self):
|
def test_model_multiple_choice_number_of_queries(self):
|
||||||
"""
|
"""
|
||||||
Test that ModelMultipleChoiceField does O(1) queries instead of
|
Test that ModelMultipleChoiceField does O(1) queries instead of
|
||||||
|
@ -30,10 +20,8 @@ class ModelMultipleChoiceFieldTests(TestCase):
|
||||||
for i in range(30):
|
for i in range(30):
|
||||||
Person.objects.create(name="Person %s" % i)
|
Person.objects.create(name="Person %s" % i)
|
||||||
|
|
||||||
db.reset_queries()
|
|
||||||
f = forms.ModelMultipleChoiceField(queryset=Person.objects.all())
|
f = forms.ModelMultipleChoiceField(queryset=Person.objects.all())
|
||||||
selected = f.clean([1, 3, 5, 7, 9])
|
self.assertNumQueries(1, f.clean, [1, 3, 5, 7, 9])
|
||||||
self.assertEquals(len(db.connection.queries), 1)
|
|
||||||
|
|
||||||
class TripleForm(forms.ModelForm):
|
class TripleForm(forms.ModelForm):
|
||||||
class Meta:
|
class Meta:
|
||||||
|
|
|
@ -7,11 +7,6 @@ from models import (User, UserProfile, UserStat, UserStatResult, StatDetails,
|
||||||
|
|
||||||
class ReverseSelectRelatedTestCase(TestCase):
|
class ReverseSelectRelatedTestCase(TestCase):
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
# Explicitly enable debug for these tests - we need to count
|
|
||||||
# the queries that have been issued.
|
|
||||||
self.old_debug = settings.DEBUG
|
|
||||||
settings.DEBUG = True
|
|
||||||
|
|
||||||
user = User.objects.create(username="test")
|
user = User.objects.create(username="test")
|
||||||
userprofile = UserProfile.objects.create(user=user, state="KS",
|
userprofile = UserProfile.objects.create(user=user, state="KS",
|
||||||
city="Lawrence")
|
city="Lawrence")
|
||||||
|
@ -26,65 +21,66 @@ class ReverseSelectRelatedTestCase(TestCase):
|
||||||
results=results2)
|
results=results2)
|
||||||
StatDetails.objects.create(base_stats=advstat, comments=250)
|
StatDetails.objects.create(base_stats=advstat, comments=250)
|
||||||
|
|
||||||
db.reset_queries()
|
|
||||||
|
|
||||||
def assertQueries(self, queries):
|
|
||||||
self.assertEqual(len(db.connection.queries), queries)
|
|
||||||
|
|
||||||
def tearDown(self):
|
|
||||||
settings.DEBUG = self.old_debug
|
|
||||||
|
|
||||||
def test_basic(self):
|
def test_basic(self):
|
||||||
u = User.objects.select_related("userprofile").get(username="test")
|
def test():
|
||||||
self.assertEqual(u.userprofile.state, "KS")
|
u = User.objects.select_related("userprofile").get(username="test")
|
||||||
self.assertQueries(1)
|
self.assertEqual(u.userprofile.state, "KS")
|
||||||
|
self.assertNumQueries(1, test)
|
||||||
|
|
||||||
def test_follow_next_level(self):
|
def test_follow_next_level(self):
|
||||||
u = User.objects.select_related("userstat__results").get(username="test")
|
def test():
|
||||||
self.assertEqual(u.userstat.posts, 150)
|
u = User.objects.select_related("userstat__results").get(username="test")
|
||||||
self.assertEqual(u.userstat.results.results, 'first results')
|
self.assertEqual(u.userstat.posts, 150)
|
||||||
self.assertQueries(1)
|
self.assertEqual(u.userstat.results.results, 'first results')
|
||||||
|
self.assertNumQueries(1, test)
|
||||||
|
|
||||||
def test_follow_two(self):
|
def test_follow_two(self):
|
||||||
u = User.objects.select_related("userprofile", "userstat").get(username="test")
|
def test():
|
||||||
self.assertEqual(u.userprofile.state, "KS")
|
u = User.objects.select_related("userprofile", "userstat").get(username="test")
|
||||||
self.assertEqual(u.userstat.posts, 150)
|
self.assertEqual(u.userprofile.state, "KS")
|
||||||
self.assertQueries(1)
|
self.assertEqual(u.userstat.posts, 150)
|
||||||
|
self.assertNumQueries(1, test)
|
||||||
|
|
||||||
def test_follow_two_next_level(self):
|
def test_follow_two_next_level(self):
|
||||||
u = User.objects.select_related("userstat__results", "userstat__statdetails").get(username="test")
|
def test():
|
||||||
self.assertEqual(u.userstat.results.results, 'first results')
|
u = User.objects.select_related("userstat__results", "userstat__statdetails").get(username="test")
|
||||||
self.assertEqual(u.userstat.statdetails.comments, 259)
|
self.assertEqual(u.userstat.results.results, 'first results')
|
||||||
self.assertQueries(1)
|
self.assertEqual(u.userstat.statdetails.comments, 259)
|
||||||
|
self.assertNumQueries(1, test)
|
||||||
|
|
||||||
def test_forward_and_back(self):
|
def test_forward_and_back(self):
|
||||||
stat = UserStat.objects.select_related("user__userprofile").get(user__username="test")
|
def test():
|
||||||
self.assertEqual(stat.user.userprofile.state, 'KS')
|
stat = UserStat.objects.select_related("user__userprofile").get(user__username="test")
|
||||||
self.assertEqual(stat.user.userstat.posts, 150)
|
self.assertEqual(stat.user.userprofile.state, 'KS')
|
||||||
self.assertQueries(1)
|
self.assertEqual(stat.user.userstat.posts, 150)
|
||||||
|
self.assertNumQueries(1, test)
|
||||||
|
|
||||||
def test_back_and_forward(self):
|
def test_back_and_forward(self):
|
||||||
u = User.objects.select_related("userstat").get(username="test")
|
def test():
|
||||||
self.assertEqual(u.userstat.user.username, 'test')
|
u = User.objects.select_related("userstat").get(username="test")
|
||||||
self.assertQueries(1)
|
self.assertEqual(u.userstat.user.username, 'test')
|
||||||
|
self.assertNumQueries(1, test)
|
||||||
|
|
||||||
def test_not_followed_by_default(self):
|
def test_not_followed_by_default(self):
|
||||||
u = User.objects.select_related().get(username="test")
|
def test():
|
||||||
self.assertEqual(u.userstat.posts, 150)
|
u = User.objects.select_related().get(username="test")
|
||||||
self.assertQueries(2)
|
self.assertEqual(u.userstat.posts, 150)
|
||||||
|
self.assertNumQueries(2, test)
|
||||||
|
|
||||||
def test_follow_from_child_class(self):
|
def test_follow_from_child_class(self):
|
||||||
stat = AdvancedUserStat.objects.select_related('user', 'statdetails').get(posts=200)
|
def test():
|
||||||
self.assertEqual(stat.statdetails.comments, 250)
|
stat = AdvancedUserStat.objects.select_related('user', 'statdetails').get(posts=200)
|
||||||
self.assertEqual(stat.user.username, 'bob')
|
self.assertEqual(stat.statdetails.comments, 250)
|
||||||
self.assertQueries(1)
|
self.assertEqual(stat.user.username, 'bob')
|
||||||
|
self.assertNumQueries(1, test)
|
||||||
|
|
||||||
def test_follow_inheritance(self):
|
def test_follow_inheritance(self):
|
||||||
stat = UserStat.objects.select_related('user', 'advanceduserstat').get(posts=200)
|
def test():
|
||||||
self.assertEqual(stat.advanceduserstat.posts, 200)
|
stat = UserStat.objects.select_related('user', 'advanceduserstat').get(posts=200)
|
||||||
self.assertEqual(stat.user.username, 'bob')
|
self.assertEqual(stat.advanceduserstat.posts, 200)
|
||||||
self.assertEqual(stat.advanceduserstat.user.username, 'bob')
|
self.assertEqual(stat.user.username, 'bob')
|
||||||
self.assertQueries(1)
|
self.assertEqual(stat.advanceduserstat.user.username, 'bob')
|
||||||
|
self.assertNumQueries(1, test)
|
||||||
|
|
||||||
def test_nullable_relation(self):
|
def test_nullable_relation(self):
|
||||||
im = Image.objects.create(name="imag1")
|
im = Image.objects.create(name="imag1")
|
||||||
|
|
|
@ -0,0 +1,5 @@
|
||||||
|
from django.db import models
|
||||||
|
|
||||||
|
|
||||||
|
class Person(models.Model):
|
||||||
|
name = models.CharField(max_length=100)
|
|
@ -0,0 +1,30 @@
|
||||||
|
from __future__ import with_statement
|
||||||
|
|
||||||
|
from django.test import TestCase
|
||||||
|
|
||||||
|
from models import Person
|
||||||
|
|
||||||
|
|
||||||
|
class AssertNumQueriesTests(TestCase):
|
||||||
|
def test_simple(self):
|
||||||
|
with self.assertNumQueries(0):
|
||||||
|
pass
|
||||||
|
|
||||||
|
with self.assertNumQueries(1):
|
||||||
|
# Guy who wrote Linux
|
||||||
|
Person.objects.create(name="Linus Torvalds")
|
||||||
|
|
||||||
|
with self.assertNumQueries(2):
|
||||||
|
# Guy who owns the bagel place I like
|
||||||
|
Person.objects.create(name="Uncle Ricky")
|
||||||
|
self.assertEqual(Person.objects.count(), 2)
|
||||||
|
|
||||||
|
def test_failure(self):
|
||||||
|
with self.assertRaises(AssertionError) as exc_info:
|
||||||
|
with self.assertNumQueries(2):
|
||||||
|
Person.objects.count()
|
||||||
|
self.assertEqual(str(exc_info.exception), "1 != 2 : 1 queries executed, 2 expected")
|
||||||
|
|
||||||
|
with self.assertRaises(TypeError):
|
||||||
|
with self.assertNumQueries(4000):
|
||||||
|
raise TypeError
|
|
@ -1,4 +1,10 @@
|
||||||
r"""
|
import sys
|
||||||
|
|
||||||
|
if sys.version_info >= (2, 5):
|
||||||
|
from python_25 import AssertNumQueriesTests
|
||||||
|
|
||||||
|
|
||||||
|
__test__ = {"API_TEST": r"""
|
||||||
# Some checks of the doctest output normalizer.
|
# Some checks of the doctest output normalizer.
|
||||||
# Standard doctests do fairly
|
# Standard doctests do fairly
|
||||||
>>> from django.utils import simplejson
|
>>> from django.utils import simplejson
|
||||||
|
@ -69,4 +75,4 @@ r"""
|
||||||
>>> produce_xml_fragment()
|
>>> produce_xml_fragment()
|
||||||
'<foo bbb="2.0" aaa="1.0">Hello</foo><bar ddd="4.0" ccc="3.0"></bar>'
|
'<foo bbb="2.0" aaa="1.0">Hello</foo><bar ddd="4.0" ccc="3.0"></bar>'
|
||||||
|
|
||||||
"""
|
"""}
|
||||||
|
|
Loading…
Reference in New Issue