diff --git a/AUTHORS b/AUTHORS index 75df381c8d..4dd4e6d0d9 100644 --- a/AUTHORS +++ b/AUTHORS @@ -334,6 +334,7 @@ answer newbie questions, and generally made Django that much better: Massimiliano Ravelli Brian Ray remco@diji.biz + Marc Remolt David Reynolds rhettg@gmail.com ricardojbarrios@gmail.com diff --git a/django/db/backends/creation.py b/django/db/backends/creation.py index aba41fb3f9..b53e9f1007 100644 --- a/django/db/backends/creation.py +++ b/django/db/backends/creation.py @@ -311,7 +311,8 @@ class BaseDatabaseCreation(object): self.connection.close() settings.DATABASE_NAME = test_database_name - + settings.DATABASE_SUPPORTS_TRANSACTIONS = self._rollback_works() + call_command('syncdb', verbosity=verbosity, interactive=False) if settings.CACHE_BACKEND.startswith('db://'): @@ -362,7 +363,19 @@ class BaseDatabaseCreation(object): sys.exit(1) return test_database_name - + + def _rollback_works(self): + cursor = self.connection.cursor() + cursor.execute('CREATE TABLE ROLLBACK_TEST (X INT)') + self.connection._commit() + cursor.execute('INSERT INTO ROLLBACK_TEST (X) VALUES (8)') + self.connection._rollback() + cursor.execute('SELECT COUNT(X) FROM ROLLBACK_TEST') + count, = cursor.fetchone() + cursor.execute('DROP TABLE ROLLBACK_TEST') + self.connection._commit() + return count == 0 + def destroy_test_db(self, old_database_name, verbosity=1): """ Destroy a test database, prompting the user for confirmation if the diff --git a/django/test/__init__.py b/django/test/__init__.py index 554e72bad3..957b293e12 100644 --- a/django/test/__init__.py +++ b/django/test/__init__.py @@ -3,4 +3,4 @@ Django Unit Test and Doctest framework. """ from django.test.client import Client -from django.test.testcases import TestCase +from django.test.testcases import TestCase, TransactionTestCase diff --git a/django/test/client.py b/django/test/client.py index a4dc212e8d..d89b625a68 100644 --- a/django/test/client.py +++ b/django/test/client.py @@ -19,6 +19,7 @@ from django.utils.functional import curry from django.utils.encoding import smart_str from django.utils.http import urlencode from django.utils.itercompat import is_iterable +from django.db import transaction, close_connection BOUNDARY = 'BoUnDaRyStRiNg' MULTIPART_CONTENT = 'multipart/form-data; boundary=%s' % BOUNDARY @@ -69,7 +70,9 @@ class ClientHandler(BaseHandler): response = middleware_method(request, response) response = self.apply_response_fixes(request, response) finally: + signals.request_finished.disconnect(close_connection) signals.request_finished.send(sender=self.__class__) + signals.request_finished.connect(close_connection) return response diff --git a/django/test/simple.py b/django/test/simple.py index ce9f59e90e..18ba063c58 100644 --- a/django/test/simple.py +++ b/django/test/simple.py @@ -3,7 +3,7 @@ from django.conf import settings from django.db.models import get_app, get_apps from django.test import _doctest as doctest from django.test.utils import setup_test_environment, teardown_test_environment -from django.test.testcases import OutputChecker, DocTestRunner +from django.test.testcases import OutputChecker, DocTestRunner, TestCase # The module name for tests outside models.py TEST_MODULE = 'tests' @@ -99,6 +99,43 @@ def build_test(label): else: # label is app.TestClass.test_method return TestClass(parts[2]) +def partition_suite(suite, classes, bins): + """ + Partitions a test suite by test type. + + classes is a sequence of types + bins is a sequence of TestSuites, one more than classes + + Tests of type classes[i] are added to bins[i], + tests with no match found in classes are place in bins[-1] + """ + for test in suite: + if isinstance(test, unittest.TestSuite): + partition_suite(test, classes, bins) + else: + for i in range(len(classes)): + if isinstance(test, classes[i]): + bins[i].addTest(test) + break + else: + bins[-1].addTest(test) + +def reorder_suite(suite, classes): + """ + Reorders a test suite by test type. + + classes is a sequence of types + + All tests of type clases[0] are placed first, then tests of type classes[1], etc. + Tests with no match in classes are placed last. + """ + class_count = len(classes) + bins = [unittest.TestSuite() for i in range(class_count+1)] + partition_suite(suite, classes, bins) + for i in range(class_count): + bins[0].addTests(bins[i+1]) + return bins[0] + def run_tests(test_labels, verbosity=1, interactive=True, extra_tests=[]): """ Run the unit tests for all the test labels in the provided list. @@ -137,6 +174,8 @@ def run_tests(test_labels, verbosity=1, interactive=True, extra_tests=[]): for test in extra_tests: suite.addTest(test) + suite = reorder_suite(suite, (TestCase,)) + old_name = settings.DATABASE_NAME from django.db import connection connection.creation.create_test_db(verbosity, autoclobber=not interactive) diff --git a/django/test/testcases.py b/django/test/testcases.py index 81e14a0a14..eed252b8cf 100644 --- a/django/test/testcases.py +++ b/django/test/testcases.py @@ -7,7 +7,7 @@ from django.conf import settings from django.core import mail from django.core.management import call_command from django.core.urlresolvers import clear_url_caches -from django.db import transaction +from django.db import transaction, connection from django.http import QueryDict from django.test import _doctest as doctest from django.test.client import Client @@ -27,6 +27,31 @@ def to_list(value): value = [value] return value +real_commit = transaction.commit +real_rollback = transaction.rollback +real_enter_transaction_management = transaction.enter_transaction_management +real_leave_transaction_management = transaction.leave_transaction_management +real_savepoint_commit = transaction.savepoint_commit +real_savepoint_rollback = transaction.savepoint_rollback + +def nop(x=None): + return + +def disable_transaction_methods(): + transaction.commit = nop + transaction.rollback = nop + transaction.savepoint_commit = nop + transaction.savepoint_rollback = nop + transaction.enter_transaction_management = nop + transaction.leave_transaction_management = nop + +def restore_transaction_methods(): + transaction.commit = real_commit + transaction.rollback = real_rollback + transaction.savepoint_commit = real_savepoint_commit + transaction.savepoint_rollback = real_savepoint_rollback + transaction.enter_transaction_management = real_enter_transaction_management + transaction.leave_transaction_management = real_leave_transaction_management class OutputChecker(doctest.OutputChecker): def check_output(self, want, got, optionflags): @@ -173,8 +198,8 @@ class DocTestRunner(doctest.DocTestRunner): # Rollback, in case of database errors. Otherwise they'd have # side effects on other tests. transaction.rollback_unless_managed() - -class TestCase(unittest.TestCase): + +class TransactionTestCase(unittest.TestCase): def _pre_setup(self): """Performs any pre-test setup. This includes: @@ -185,16 +210,22 @@ class TestCase(unittest.TestCase): ROOT_URLCONF with it. * Clearing the mail test outbox. """ + self._fixture_setup() + self._urlconf_setup() + mail.outbox = [] + + def _fixture_setup(self): call_command('flush', verbosity=0, interactive=False) if hasattr(self, 'fixtures'): # We have to use this slightly awkward syntax due to the fact # that we're using *args and **kwargs together. call_command('loaddata', *self.fixtures, **{'verbosity': 0}) + + def _urlconf_setup(self): if hasattr(self, 'urls'): self._old_root_urlconf = settings.ROOT_URLCONF settings.ROOT_URLCONF = self.urls clear_url_caches() - mail.outbox = [] def __call__(self, result=None): """ @@ -211,7 +242,7 @@ class TestCase(unittest.TestCase): import sys result.addError(self, sys.exc_info()) return - super(TestCase, self).__call__(result) + super(TransactionTestCase, self).__call__(result) try: self._post_teardown() except (KeyboardInterrupt, SystemExit): @@ -226,6 +257,13 @@ class TestCase(unittest.TestCase): * Putting back the original ROOT_URLCONF if it was changed. """ + self._fixture_teardown() + self._urlconf_teardown() + + def _fixture_teardown(self): + pass + + def _urlconf_teardown(self): if hasattr(self, '_old_root_urlconf'): settings.ROOT_URLCONF = self._old_root_urlconf clear_url_caches() @@ -359,3 +397,37 @@ class TestCase(unittest.TestCase): self.failIf(template_name in template_names, (u"Template '%s' was used unexpectedly in rendering the" u" response") % template_name) + +class TestCase(TransactionTestCase): + """ + Does basically the same as TransactionTestCase, but surrounds every test + with a transaction, monkey-patches the real transaction management routines to + do nothing, and rollsback the test transaction at the end of the test. You have + to use TransactionTestCase, if you need transaction management inside a test. + """ + + def _fixture_setup(self): + if not settings.DATABASE_SUPPORTS_TRANSACTIONS: + return super(TestCase, self)._fixture_setup() + + transaction.enter_transaction_management() + transaction.managed(True) + disable_transaction_methods() + + from django.contrib.sites.models import Site + Site.objects.clear_cache() + + if hasattr(self, 'fixtures'): + call_command('loaddata', *self.fixtures, **{ + 'verbosity': 0, + 'commit': False + }) + + def _fixture_teardown(self): + if not settings.DATABASE_SUPPORTS_TRANSACTIONS: + return super(TestCase, self)._fixture_teardown() + + restore_transaction_methods() + transaction.rollback() + transaction.leave_transaction_management() + connection.close() \ No newline at end of file diff --git a/docs/topics/testing.txt b/docs/topics/testing.txt index 23ac2481e7..bd68f6ba7a 100644 --- a/docs/topics/testing.txt +++ b/docs/topics/testing.txt @@ -785,6 +785,52 @@ just change the base class of your test from ``unittest.TestCase`` to will continue to be available, but it will be augmented with some useful additions. +.. versionadded:: 1.1 + +.. class:: TransactionTestCase() + +Django ``TestCase`` classes make use of database transaction facilities, if +available, to speed up the process of resetting the database to a known state +at the beginning of each test. A consequence of this, however, is that the +effects of transaction commit and rollback cannot be tested by a Django +``TestCase`` class. If your test requires testing of such transactional +behavior, you should use a Django ``TransactionTestCase``. + +``TransactionTestCase`` and ``TestCase`` are identical except for the manner +in which the database is reset to a known state and the ability for test code +to test the effects of commit and rollback. A ``TranscationTestCase`` resets +the database before the test runs by truncating all tables and reloading +initial data. A ``TransactionTestCase`` may call commit and rollback and +observe the effects of these calls on the database. + +A ``TestCase``, on the other hand, does not truncate tables and reload initial +data at the beginning of a test. Instead, it encloses the test code in a +database transaction that is rolled back at the end of the test. It also +prevents the code under test from issuing any commit or rollback operations +on the database, to ensure that the rollback at the end of the test restores +the database to its initial state. In order to guarantee that all ``TestCase`` +code starts with a clean database, the Django test runner runs all ``TestCase`` +tests first, before any other tests (e.g. doctests) that may alter the +database without restoring it to its original state. + +When running on a database that does not support rollback (e.g. MySQL with the +MyISAM storage engine), ``TestCase`` falls back to initializing the database +by truncating tables and reloading initial data. + + +.. note:: + The ``TestCase`` use of rollback to un-do the effects of the test code + may reveal previously-undetected errors in test code. For example, + test code that assumes primary keys values will be assigned starting at + one may find that assumption no longer holds true when rollbacks instead + of table truncation are being used to reset the database. Similarly, + the reordering of tests so that all ``TestCase`` classes run first may + reveal unexpected dependencies on test case ordering. In such cases a + quick fix is to switch the ``TestCase`` to a ``TransactionTestCase``. + A better long-term fix, that allows the test to take advantage of the + speed benefit of ``TestCase``, is to fix the underlying test problem. + + Default test client ~~~~~~~~~~~~~~~~~~~ diff --git a/tests/regressiontests/comment_tests/tests/comment_view_tests.py b/tests/regressiontests/comment_tests/tests/comment_view_tests.py index 0c975116ef..312fab633f 100644 --- a/tests/regressiontests/comment_tests/tests/comment_view_tests.py +++ b/tests/regressiontests/comment_tests/tests/comment_view_tests.py @@ -1,3 +1,4 @@ +import re from django.conf import settings from django.contrib.auth.models import User from django.contrib.comments import signals @@ -5,6 +6,8 @@ from django.contrib.comments.models import Comment from regressiontests.comment_tests.models import Article from regressiontests.comment_tests.tests import CommentTestCase +post_redirect_re = re.compile(r'^http://testserver/posted/\?c=(?P\d+$)') + class CommentViewTests(CommentTestCase): def testPostCommentHTTPMethods(self): @@ -181,18 +184,26 @@ class CommentViewTests(CommentTestCase): a = Article.objects.get(pk=1) data = self.getValidData(a) response = self.client.post("/post/", data) - self.assertEqual(response["Location"], "http://testserver/posted/?c=1") - + location = response["Location"] + match = post_redirect_re.match(location) + self.failUnless(match != None, "Unexpected redirect location: %s" % location) + data["next"] = "/somewhere/else/" data["comment"] = "This is another comment" response = self.client.post("/post/", data) - self.assertEqual(response["Location"], "http://testserver/somewhere/else/?c=2") + location = response["Location"] + match = re.search(r"^http://testserver/somewhere/else/\?c=\d+$", location) + self.failUnless(match != None, "Unexpected redirect location: %s" % location) def testCommentDoneView(self): a = Article.objects.get(pk=1) data = self.getValidData(a) response = self.client.post("/post/", data) - response = self.client.get("/posted/", {'c':1}) + location = response["Location"] + match = post_redirect_re.match(location) + self.failUnless(match != None, "Unexpected redirect location: %s" % location) + pk = int(match.group('pk')) + response = self.client.get(location) self.assertTemplateUsed(response, "comments/posted.html") - self.assertEqual(response.context[0]["comment"], Comment.objects.get(pk=1)) + self.assertEqual(response.context[0]["comment"], Comment.objects.get(pk=pk)) diff --git a/tests/regressiontests/comment_tests/tests/moderation_view_tests.py b/tests/regressiontests/comment_tests/tests/moderation_view_tests.py index 2f6b51d709..b9eadd78b4 100644 --- a/tests/regressiontests/comment_tests/tests/moderation_view_tests.py +++ b/tests/regressiontests/comment_tests/tests/moderation_view_tests.py @@ -8,39 +8,43 @@ class FlagViewTests(CommentTestCase): def testFlagGet(self): """GET the flag view: render a confirmation page.""" - self.createSomeComments() + comments = self.createSomeComments() + pk = comments[0].pk self.client.login(username="normaluser", password="normaluser") - response = self.client.get("/flag/1/") + response = self.client.get("/flag/%d/" % pk) self.assertTemplateUsed(response, "comments/flag.html") def testFlagPost(self): """POST the flag view: actually flag the view (nice for XHR)""" - self.createSomeComments() + comments = self.createSomeComments() + pk = comments[0].pk self.client.login(username="normaluser", password="normaluser") - response = self.client.post("/flag/1/") - self.assertEqual(response["Location"], "http://testserver/flagged/?c=1") - c = Comment.objects.get(pk=1) + response = self.client.post("/flag/%d/" % pk) + self.assertEqual(response["Location"], "http://testserver/flagged/?c=%d" % pk) + c = Comment.objects.get(pk=pk) self.assertEqual(c.flags.filter(flag=CommentFlag.SUGGEST_REMOVAL).count(), 1) return c def testFlagPostTwice(self): """Users don't get to flag comments more than once.""" c = self.testFlagPost() - self.client.post("/flag/1/") - self.client.post("/flag/1/") + self.client.post("/flag/%d/" % c.pk) + self.client.post("/flag/%d/" % c.pk) self.assertEqual(c.flags.filter(flag=CommentFlag.SUGGEST_REMOVAL).count(), 1) def testFlagAnon(self): """GET/POST the flag view while not logged in: redirect to log in.""" - self.createSomeComments() - response = self.client.get("/flag/1/") - self.assertEqual(response["Location"], "http://testserver/accounts/login/?next=/flag/1/") - response = self.client.post("/flag/1/") - self.assertEqual(response["Location"], "http://testserver/accounts/login/?next=/flag/1/") + comments = self.createSomeComments() + pk = comments[0].pk + response = self.client.get("/flag/%d/" % pk) + self.assertEqual(response["Location"], "http://testserver/accounts/login/?next=/flag/%d/" % pk) + response = self.client.post("/flag/%d/" % pk) + self.assertEqual(response["Location"], "http://testserver/accounts/login/?next=/flag/%d/" % pk) def testFlaggedView(self): - self.createSomeComments() - response = self.client.get("/flagged/", data={"c":1}) + comments = self.createSomeComments() + pk = comments[0].pk + response = self.client.get("/flagged/", data={"c":pk}) self.assertTemplateUsed(response, "comments/flagged.html") def testFlagSignals(self): @@ -70,23 +74,25 @@ class DeleteViewTests(CommentTestCase): def testDeletePermissions(self): """The delete view should only be accessible to 'moderators'""" - self.createSomeComments() + comments = self.createSomeComments() + pk = comments[0].pk self.client.login(username="normaluser", password="normaluser") - response = self.client.get("/delete/1/") - self.assertEqual(response["Location"], "http://testserver/accounts/login/?next=/delete/1/") + response = self.client.get("/delete/%d/" % pk) + self.assertEqual(response["Location"], "http://testserver/accounts/login/?next=/delete/%d/" % pk) makeModerator("normaluser") - response = self.client.get("/delete/1/") + response = self.client.get("/delete/%d/" % pk) self.assertEqual(response.status_code, 200) def testDeletePost(self): """POSTing the delete view should mark the comment as removed""" - self.createSomeComments() + comments = self.createSomeComments() + pk = comments[0].pk makeModerator("normaluser") self.client.login(username="normaluser", password="normaluser") - response = self.client.post("/delete/1/") - self.assertEqual(response["Location"], "http://testserver/deleted/?c=1") - c = Comment.objects.get(pk=1) + response = self.client.post("/delete/%d/" % pk) + self.assertEqual(response["Location"], "http://testserver/deleted/?c=%d" % pk) + c = Comment.objects.get(pk=pk) self.failUnless(c.is_removed) self.assertEqual(c.flags.filter(flag=CommentFlag.MODERATOR_DELETION, user__username="normaluser").count(), 1) @@ -103,21 +109,23 @@ class DeleteViewTests(CommentTestCase): self.assertEqual(received_signals, [signals.comment_was_flagged]) def testDeletedView(self): - self.createSomeComments() - response = self.client.get("/deleted/", data={"c":1}) + comments = self.createSomeComments() + pk = comments[0].pk + response = self.client.get("/deleted/", data={"c":pk}) self.assertTemplateUsed(response, "comments/deleted.html") class ApproveViewTests(CommentTestCase): def testApprovePermissions(self): """The delete view should only be accessible to 'moderators'""" - self.createSomeComments() + comments = self.createSomeComments() + pk = comments[0].pk self.client.login(username="normaluser", password="normaluser") - response = self.client.get("/approve/1/") - self.assertEqual(response["Location"], "http://testserver/accounts/login/?next=/approve/1/") + response = self.client.get("/approve/%d/" % pk) + self.assertEqual(response["Location"], "http://testserver/accounts/login/?next=/approve/%d/" % pk) makeModerator("normaluser") - response = self.client.get("/approve/1/") + response = self.client.get("/approve/%d/" % pk) self.assertEqual(response.status_code, 200) def testApprovePost(self): @@ -127,9 +135,9 @@ class ApproveViewTests(CommentTestCase): makeModerator("normaluser") self.client.login(username="normaluser", password="normaluser") - response = self.client.post("/approve/1/") - self.assertEqual(response["Location"], "http://testserver/approved/?c=1") - c = Comment.objects.get(pk=1) + response = self.client.post("/approve/%d/" % c1.pk) + self.assertEqual(response["Location"], "http://testserver/approved/?c=%d" % c1.pk) + c = Comment.objects.get(pk=c1.pk) self.failUnless(c.is_public) self.assertEqual(c.flags.filter(flag=CommentFlag.MODERATOR_APPROVAL, user__username="normaluser").count(), 1) @@ -146,8 +154,9 @@ class ApproveViewTests(CommentTestCase): self.assertEqual(received_signals, [signals.comment_was_flagged]) def testApprovedView(self): - self.createSomeComments() - response = self.client.get("/approved/", data={"c":1}) + comments = self.createSomeComments() + pk = comments[0].pk + response = self.client.get("/approved/", data={"c":pk}) self.assertTemplateUsed(response, "comments/approved.html") diff --git a/tests/regressiontests/file_uploads/tests.py b/tests/regressiontests/file_uploads/tests.py index 6fcd8a99aa..21f8ad4de2 100644 --- a/tests/regressiontests/file_uploads/tests.py +++ b/tests/regressiontests/file_uploads/tests.py @@ -238,6 +238,9 @@ class DirectoryCreationTests(unittest.TestCase): self.obj = FileModel() if not os.path.isdir(temp_storage.location): os.makedirs(temp_storage.location) + if os.path.isdir(UPLOAD_TO): + os.chmod(UPLOAD_TO, 0700) + shutil.rmtree(UPLOAD_TO) def tearDown(self): os.chmod(temp_storage.location, 0700) diff --git a/tests/regressiontests/generic_inline_admin/tests.py b/tests/regressiontests/generic_inline_admin/tests.py index e03cc1f2f4..a3ea5fc60b 100644 --- a/tests/regressiontests/generic_inline_admin/tests.py +++ b/tests/regressiontests/generic_inline_admin/tests.py @@ -21,8 +21,10 @@ class GenericAdminViewTest(TestCase): # relies on content type IDs, which will vary depending on what # other tests have been run), thus we do it here. e = Episode.objects.create(name='This Week in Django') + self.episode_pk = e.pk m = Media(content_object=e, url='http://example.com/podcast.mp3') m.save() + self.media_pk = m.pk def tearDown(self): self.client.logout() @@ -39,7 +41,7 @@ class GenericAdminViewTest(TestCase): """ A smoke test to ensure GET on the change_view works. """ - response = self.client.get('/generic_inline_admin/admin/generic_inline_admin/episode/1/') + response = self.client.get('/generic_inline_admin/admin/generic_inline_admin/episode/%d/' % self.episode_pk) self.failUnlessEqual(response.status_code, 200) def testBasicAddPost(self): @@ -64,10 +66,11 @@ class GenericAdminViewTest(TestCase): # inline data "generic_inline_admin-media-content_type-object_id-TOTAL_FORMS": u"2", "generic_inline_admin-media-content_type-object_id-INITIAL_FORMS": u"1", - "generic_inline_admin-media-content_type-object_id-0-id": u"1", + "generic_inline_admin-media-content_type-object_id-0-id": u"%d" % self.media_pk, "generic_inline_admin-media-content_type-object_id-0-url": u"http://example.com/podcast.mp3", "generic_inline_admin-media-content_type-object_id-1-id": u"", "generic_inline_admin-media-content_type-object_id-1-url": u"", } - response = self.client.post('/generic_inline_admin/admin/generic_inline_admin/episode/1/', post_data) + url = '/generic_inline_admin/admin/generic_inline_admin/episode/%d/' % self.episode_pk + response = self.client.post(url, post_data) self.failUnlessEqual(response.status_code, 302) # redirect somewhere