From f556df90be995a83b979cf875705d98521ab4dc7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Anssi=20K=C3=A4=C3=A4ri=C3=A4inen?= Date: Sat, 2 Feb 2013 14:48:55 +0200 Subject: [PATCH] Fixed #19645 -- Added tests for TransactionMiddleware --- tests/regressiontests/middleware/models.py | 14 +++++- tests/regressiontests/middleware/tests.py | 52 +++++++++++++++++++++- 2 files changed, 63 insertions(+), 3 deletions(-) diff --git a/tests/regressiontests/middleware/models.py b/tests/regressiontests/middleware/models.py index 71abcc51987..7088bfc2f39 100644 --- a/tests/regressiontests/middleware/models.py +++ b/tests/regressiontests/middleware/models.py @@ -1 +1,13 @@ -# models.py file for tests to run. +from django.db import models +from django.utils.encoding import python_2_unicode_compatible + + +@python_2_unicode_compatible +class Band(models.Model): + name = models.CharField(max_length=100) + + class Meta: + ordering = ('name',) + + def __str__(self): + return self.name diff --git a/tests/regressiontests/middleware/tests.py b/tests/regressiontests/middleware/tests.py index c6d42a6964b..a9a45c99ba3 100644 --- a/tests/regressiontests/middleware/tests.py +++ b/tests/regressiontests/middleware/tests.py @@ -1,4 +1,5 @@ # -*- coding: utf-8 -*- +from __future__ import absolute_import, unicode_literals import gzip from io import BytesIO @@ -8,17 +9,22 @@ import warnings from django.conf import settings from django.core import mail +from django.db import transaction from django.http import HttpRequest from django.http import HttpResponse, StreamingHttpResponse from django.middleware.clickjacking import XFrameOptionsMiddleware from django.middleware.common import CommonMiddleware, BrokenLinkEmailsMiddleware from django.middleware.http import ConditionalGetMiddleware from django.middleware.gzip import GZipMiddleware -from django.test import TestCase, RequestFactory +from django.middleware.transaction import TransactionMiddleware +from django.test import TransactionTestCase, TestCase, RequestFactory from django.test.utils import override_settings from django.utils import six +from django.utils.encoding import force_str from django.utils.six.moves import xrange +from .models import Band + class CommonMiddlewareTest(TestCase): @@ -273,7 +279,7 @@ class CommonMiddlewareTest(TestCase): def test_non_ascii_query_string_does_not_crash(self): """Regression test for #15152""" request = self._get_request('slash') - request.META['QUERY_STRING'] = 'drink=café' + request.META['QUERY_STRING'] = force_str('drink=café') response = CommonMiddleware().process_request(request) self.assertEqual(response.status_code, 301) @@ -662,3 +668,45 @@ class ETagGZipMiddlewareTest(TestCase): nogzip_etag = response.get('ETag') self.assertNotEqual(gzip_etag, nogzip_etag) + +class TransactionMiddlewareTest(TransactionTestCase): + """ + Test the transaction middleware. + """ + def setUp(self): + self.request = HttpRequest() + self.request.META = { + 'SERVER_NAME': 'testserver', + 'SERVER_PORT': 80, + } + self.request.path = self.request.path_info = "/" + self.response = HttpResponse() + self.response.status_code = 200 + + def test_request(self): + TransactionMiddleware().process_request(self.request) + self.assertTrue(transaction.is_managed()) + + def test_managed_response(self): + transaction.enter_transaction_management() + transaction.managed(True) + Band.objects.create(name='The Beatles') + self.assertTrue(transaction.is_dirty()) + TransactionMiddleware().process_response(self.request, self.response) + self.assertFalse(transaction.is_dirty()) + self.assertEqual(Band.objects.count(), 1) + + def test_unmanaged_response(self): + transaction.managed(False) + TransactionMiddleware().process_response(self.request, self.response) + self.assertFalse(transaction.is_managed()) + self.assertFalse(transaction.is_dirty()) + + def test_exception(self): + transaction.enter_transaction_management() + transaction.managed(True) + Band.objects.create(name='The Beatles') + self.assertTrue(transaction.is_dirty()) + TransactionMiddleware().process_exception(self.request, None) + self.assertEqual(Band.objects.count(), 0) + self.assertFalse(transaction.is_dirty())