Replaced django.test.utils.patch_logger() with assertLogs().

Thanks Tim Graham for the review.
This commit is contained in:
Claude Paroz 2018-04-28 15:20:27 +02:00 committed by Tim Graham
parent 7d3fe36c62
commit 607970f31c
13 changed files with 97 additions and 157 deletions

View File

@ -634,26 +634,6 @@ class ignore_warnings(TestContextDecorator):
self.catch_warnings.__exit__(*sys.exc_info()) self.catch_warnings.__exit__(*sys.exc_info())
@contextmanager
def patch_logger(logger_name, log_level, log_kwargs=False):
"""
Context manager that takes a named logger and the logging level
and provides a simple mock-like list of messages received
"""
calls = []
def replacement(msg, *args, **kwargs):
call = msg % args
calls.append((call, kwargs) if log_kwargs else call)
logger = logging.getLogger(logger_name)
orig = getattr(logger, log_level)
setattr(logger, log_level, replacement)
try:
yield calls
finally:
setattr(logger, log_level, orig)
# On OSes that don't provide tzset (Windows), we can't set the timezone # On OSes that don't provide tzset (Windows), we can't set the timezone
# in which the program runs. As a consequence, we must skip tests that # in which the program runs. As a consequence, we must skip tests that
# don't enforce a specific timezone (with timezone.override or equivalent), # don't enforce a specific timezone (with timezone.override or equivalent),

View File

@ -25,7 +25,7 @@ from django.template.response import TemplateResponse
from django.test import ( from django.test import (
TestCase, modify_settings, override_settings, skipUnlessDBFeature, TestCase, modify_settings, override_settings, skipUnlessDBFeature,
) )
from django.test.utils import override_script_prefix, patch_logger from django.test.utils import override_script_prefix
from django.urls import NoReverseMatch, resolve, reverse from django.urls import NoReverseMatch, resolve, reverse
from django.utils import formats, translation from django.utils import formats, translation
from django.utils.cache import get_max_age from django.utils.cache import get_max_age
@ -747,12 +747,11 @@ class AdminViewBasicTest(AdminViewBasicTestCase):
self.assertContains(response, '%Y-%m-%d %H:%M:%S') self.assertContains(response, '%Y-%m-%d %H:%M:%S')
def test_disallowed_filtering(self): def test_disallowed_filtering(self):
with patch_logger('django.security.DisallowedModelAdminLookup', 'error') as calls: with self.assertLogs('django.security.DisallowedModelAdminLookup', 'ERROR'):
response = self.client.get( response = self.client.get(
"%s?owner__email__startswith=fuzzy" % reverse('admin:admin_views_album_changelist') "%s?owner__email__startswith=fuzzy" % reverse('admin:admin_views_album_changelist')
) )
self.assertEqual(response.status_code, 400) self.assertEqual(response.status_code, 400)
self.assertEqual(len(calls), 1)
# Filters are allowed if explicitly included in list_filter # Filters are allowed if explicitly included in list_filter
response = self.client.get("%s?color__value__startswith=red" % reverse('admin:admin_views_thing_changelist')) response = self.client.get("%s?color__value__startswith=red" % reverse('admin:admin_views_thing_changelist'))
@ -777,18 +776,16 @@ class AdminViewBasicTest(AdminViewBasicTestCase):
self.assertEqual(response.status_code, 200) self.assertEqual(response.status_code, 200)
def test_disallowed_to_field(self): def test_disallowed_to_field(self):
with patch_logger('django.security.DisallowedModelAdminToField', 'error') as calls: url = reverse('admin:admin_views_section_changelist')
url = reverse('admin:admin_views_section_changelist') with self.assertLogs('django.security.DisallowedModelAdminToField', 'ERROR'):
response = self.client.get(url, {TO_FIELD_VAR: 'missing_field'}) response = self.client.get(url, {TO_FIELD_VAR: 'missing_field'})
self.assertEqual(response.status_code, 400) self.assertEqual(response.status_code, 400)
self.assertEqual(len(calls), 1)
# Specifying a field that is not referred by any other model registered # Specifying a field that is not referred by any other model registered
# to this admin site should raise an exception. # to this admin site should raise an exception.
with patch_logger('django.security.DisallowedModelAdminToField', 'error') as calls: with self.assertLogs('django.security.DisallowedModelAdminToField', 'ERROR'):
response = self.client.get(reverse('admin:admin_views_section_changelist'), {TO_FIELD_VAR: 'name'}) response = self.client.get(reverse('admin:admin_views_section_changelist'), {TO_FIELD_VAR: 'name'})
self.assertEqual(response.status_code, 400) self.assertEqual(response.status_code, 400)
self.assertEqual(len(calls), 1)
# #23839 - Primary key should always be allowed, even if the referenced model isn't registered. # #23839 - Primary key should always be allowed, even if the referenced model isn't registered.
response = self.client.get(reverse('admin:admin_views_notreferenced_changelist'), {TO_FIELD_VAR: 'id'}) response = self.client.get(reverse('admin:admin_views_notreferenced_changelist'), {TO_FIELD_VAR: 'id'})
@ -815,30 +812,26 @@ class AdminViewBasicTest(AdminViewBasicTestCase):
# #25622 - Specifying a field of a model only referred by a generic # #25622 - Specifying a field of a model only referred by a generic
# relation should raise DisallowedModelAdminToField. # relation should raise DisallowedModelAdminToField.
url = reverse('admin:admin_views_referencedbygenrel_changelist') url = reverse('admin:admin_views_referencedbygenrel_changelist')
with patch_logger('django.security.DisallowedModelAdminToField', 'error') as calls: with self.assertLogs('django.security.DisallowedModelAdminToField', 'ERROR'):
response = self.client.get(url, {TO_FIELD_VAR: 'object_id'}) response = self.client.get(url, {TO_FIELD_VAR: 'object_id'})
self.assertEqual(response.status_code, 400) self.assertEqual(response.status_code, 400)
self.assertEqual(len(calls), 1)
# We also want to prevent the add, change, and delete views from # We also want to prevent the add, change, and delete views from
# leaking a disallowed field value. # leaking a disallowed field value.
with patch_logger('django.security.DisallowedModelAdminToField', 'error') as calls: with self.assertLogs('django.security.DisallowedModelAdminToField', 'ERROR'):
response = self.client.post(reverse('admin:admin_views_section_add'), {TO_FIELD_VAR: 'name'}) response = self.client.post(reverse('admin:admin_views_section_add'), {TO_FIELD_VAR: 'name'})
self.assertEqual(response.status_code, 400) self.assertEqual(response.status_code, 400)
self.assertEqual(len(calls), 1)
section = Section.objects.create() section = Section.objects.create()
with patch_logger('django.security.DisallowedModelAdminToField', 'error') as calls: url = reverse('admin:admin_views_section_change', args=(section.pk,))
url = reverse('admin:admin_views_section_change', args=(section.pk,)) with self.assertLogs('django.security.DisallowedModelAdminToField', 'ERROR'):
response = self.client.post(url, {TO_FIELD_VAR: 'name'}) response = self.client.post(url, {TO_FIELD_VAR: 'name'})
self.assertEqual(response.status_code, 400) self.assertEqual(response.status_code, 400)
self.assertEqual(len(calls), 1)
with patch_logger('django.security.DisallowedModelAdminToField', 'error') as calls: url = reverse('admin:admin_views_section_delete', args=(section.pk,))
url = reverse('admin:admin_views_section_delete', args=(section.pk,)) with self.assertLogs('django.security.DisallowedModelAdminToField', 'ERROR'):
response = self.client.post(url, {TO_FIELD_VAR: 'name'}) response = self.client.post(url, {TO_FIELD_VAR: 'name'})
self.assertEqual(response.status_code, 400) self.assertEqual(response.status_code, 400)
self.assertEqual(len(calls), 1)
def test_allowed_filtering_15103(self): def test_allowed_filtering_15103(self):
""" """

View File

@ -27,7 +27,6 @@ from django.http import HttpRequest, QueryDict
from django.middleware.csrf import CsrfViewMiddleware, get_token from django.middleware.csrf import CsrfViewMiddleware, get_token
from django.test import Client, TestCase, override_settings from django.test import Client, TestCase, override_settings
from django.test.client import RedirectCycleError from django.test.client import RedirectCycleError
from django.test.utils import patch_logger
from django.urls import NoReverseMatch, reverse, reverse_lazy from django.urls import NoReverseMatch, reverse, reverse_lazy
from django.utils.http import urlsafe_base64_encode from django.utils.http import urlsafe_base64_encode
from django.utils.translation import LANGUAGE_SESSION_KEY from django.utils.translation import LANGUAGE_SESSION_KEY
@ -186,29 +185,27 @@ class PasswordResetTest(AuthViewsTestCase):
# produce a meaningful reset URL, we need to be certain that the # produce a meaningful reset URL, we need to be certain that the
# HTTP_HOST header isn't poisoned. This is done as a check when get_host() # HTTP_HOST header isn't poisoned. This is done as a check when get_host()
# is invoked, but we check here as a practical consequence. # is invoked, but we check here as a practical consequence.
with patch_logger('django.security.DisallowedHost', 'error') as logger_calls: with self.assertLogs('django.security.DisallowedHost', 'ERROR'):
response = self.client.post( response = self.client.post(
'/password_reset/', '/password_reset/',
{'email': 'staffmember@example.com'}, {'email': 'staffmember@example.com'},
HTTP_HOST='www.example:dr.frankenstein@evil.tld' HTTP_HOST='www.example:dr.frankenstein@evil.tld'
) )
self.assertEqual(response.status_code, 400) self.assertEqual(response.status_code, 400)
self.assertEqual(len(mail.outbox), 0) self.assertEqual(len(mail.outbox), 0)
self.assertEqual(len(logger_calls), 1)
# Skip any 500 handler action (like sending more mail...) # Skip any 500 handler action (like sending more mail...)
@override_settings(DEBUG_PROPAGATE_EXCEPTIONS=True) @override_settings(DEBUG_PROPAGATE_EXCEPTIONS=True)
def test_poisoned_http_host_admin_site(self): def test_poisoned_http_host_admin_site(self):
"Poisoned HTTP_HOST headers can't be used for reset emails on admin views" "Poisoned HTTP_HOST headers can't be used for reset emails on admin views"
with patch_logger('django.security.DisallowedHost', 'error') as logger_calls: with self.assertLogs('django.security.DisallowedHost', 'ERROR'):
response = self.client.post( response = self.client.post(
'/admin_password_reset/', '/admin_password_reset/',
{'email': 'staffmember@example.com'}, {'email': 'staffmember@example.com'},
HTTP_HOST='www.example:dr.frankenstein@evil.tld' HTTP_HOST='www.example:dr.frankenstein@evil.tld'
) )
self.assertEqual(response.status_code, 400) self.assertEqual(response.status_code, 400)
self.assertEqual(len(mail.outbox), 0) self.assertEqual(len(mail.outbox), 0)
self.assertEqual(len(logger_calls), 1)
def _test_confirm_start(self): def _test_confirm_start(self):
# Start by creating the email # Start by creating the email
@ -1153,10 +1150,9 @@ class ChangelistTests(AuthViewsTestCase):
# repeated password__startswith queries. # repeated password__startswith queries.
def test_changelist_disallows_password_lookups(self): def test_changelist_disallows_password_lookups(self):
# A lookup that tries to filter on password isn't OK # A lookup that tries to filter on password isn't OK
with patch_logger('django.security.DisallowedModelAdminLookup', 'error') as logger_calls: with self.assertLogs('django.security.DisallowedModelAdminLookup', 'ERROR'):
response = self.client.get(reverse('auth_test_admin:auth_user_changelist') + '?password__startswith=sha1$') response = self.client.get(reverse('auth_test_admin:auth_user_changelist') + '?password__startswith=sha1$')
self.assertEqual(response.status_code, 400) self.assertEqual(response.status_code, 400)
self.assertEqual(len(logger_calls), 1)
def test_user_change_email(self): def test_user_change_email(self):
data = self.get_user_data(self.admin) data = self.get_user_data(self.admin)

View File

@ -1,4 +1,3 @@
import logging
import re import re
from django.conf import settings from django.conf import settings
@ -10,7 +9,6 @@ from django.middleware.csrf import (
_compare_salted_tokens as equivalent_tokens, get_token, _compare_salted_tokens as equivalent_tokens, get_token,
) )
from django.test import SimpleTestCase, override_settings from django.test import SimpleTestCase, override_settings
from django.test.utils import patch_logger
from django.views.decorators.csrf import csrf_exempt, requires_csrf_token from django.views.decorators.csrf import csrf_exempt, requires_csrf_token
from .views import ( from .views import (
@ -98,24 +96,24 @@ class CsrfViewMiddlewareTestMixin:
If no CSRF cookies is present, the middleware rejects the incoming If no CSRF cookies is present, the middleware rejects the incoming
request. This will stop login CSRF. request. This will stop login CSRF.
""" """
with patch_logger('django.security.csrf', 'warning') as logger_calls: req = self._get_POST_no_csrf_cookie_request()
req = self._get_POST_no_csrf_cookie_request() self.mw.process_request(req)
self.mw.process_request(req) with self.assertLogs('django.security.csrf', 'WARNING') as cm:
req2 = self.mw.process_view(req, post_form_view, (), {}) req2 = self.mw.process_view(req, post_form_view, (), {})
self.assertEqual(403, req2.status_code) self.assertEqual(403, req2.status_code)
self.assertEqual(logger_calls[0], 'Forbidden (%s): ' % REASON_NO_CSRF_COOKIE) self.assertEqual(cm.records[0].getMessage(), 'Forbidden (%s): ' % REASON_NO_CSRF_COOKIE)
def test_process_request_csrf_cookie_no_token(self): def test_process_request_csrf_cookie_no_token(self):
""" """
If a CSRF cookie is present but no token, the middleware rejects If a CSRF cookie is present but no token, the middleware rejects
the incoming request. the incoming request.
""" """
with patch_logger('django.security.csrf', 'warning') as logger_calls: req = self._get_POST_csrf_cookie_request()
req = self._get_POST_csrf_cookie_request() self.mw.process_request(req)
self.mw.process_request(req) with self.assertLogs('django.security.csrf', 'WARNING') as cm:
req2 = self.mw.process_view(req, post_form_view, (), {}) req2 = self.mw.process_view(req, post_form_view, (), {})
self.assertEqual(403, req2.status_code) self.assertEqual(403, req2.status_code)
self.assertEqual(logger_calls[0], 'Forbidden (%s): ' % REASON_BAD_TOKEN) self.assertEqual(cm.records[0].getMessage(), 'Forbidden (%s): ' % REASON_BAD_TOKEN)
def test_process_request_csrf_cookie_and_token(self): def test_process_request_csrf_cookie_and_token(self):
""" """
@ -163,17 +161,17 @@ class CsrfViewMiddlewareTestMixin:
""" """
req = TestingHttpRequest() req = TestingHttpRequest()
req.method = 'PUT' req.method = 'PUT'
with patch_logger('django.security.csrf', 'warning') as logger_calls: with self.assertLogs('django.security.csrf', 'WARNING') as cm:
req2 = self.mw.process_view(req, post_form_view, (), {}) req2 = self.mw.process_view(req, post_form_view, (), {})
self.assertEqual(403, req2.status_code) self.assertEqual(403, req2.status_code)
self.assertEqual(logger_calls[0], 'Forbidden (%s): ' % REASON_NO_CSRF_COOKIE) self.assertEqual(cm.records[0].getMessage(), 'Forbidden (%s): ' % REASON_NO_CSRF_COOKIE)
req = TestingHttpRequest() req = TestingHttpRequest()
req.method = 'DELETE' req.method = 'DELETE'
with patch_logger('django.security.csrf', 'warning') as logger_calls: with self.assertLogs('django.security.csrf', 'WARNING') as cm:
req2 = self.mw.process_view(req, post_form_view, (), {}) req2 = self.mw.process_view(req, post_form_view, (), {})
self.assertEqual(403, req2.status_code) self.assertEqual(403, req2.status_code)
self.assertEqual(logger_calls[0], 'Forbidden (%s): ' % REASON_NO_CSRF_COOKIE) self.assertEqual(cm.records[0].getMessage(), 'Forbidden (%s): ' % REASON_NO_CSRF_COOKIE)
def test_put_and_delete_allowed(self): def test_put_and_delete_allowed(self):
""" """
@ -436,22 +434,10 @@ class CsrfViewMiddlewareTestMixin:
""" """
ensure_csrf_cookie() doesn't log warnings (#19436). ensure_csrf_cookie() doesn't log warnings (#19436).
""" """
class TestHandler(logging.Handler): with self.assertRaisesMessage(AssertionError, 'no logs'):
def emit(self, record): with self.assertLogs('django.request', 'WARNING'):
raise Exception("This shouldn't have happened!") req = self._get_GET_no_csrf_cookie_request()
ensure_csrf_cookie_view(req)
logger = logging.getLogger('django.request')
test_handler = TestHandler()
old_log_level = logger.level
try:
logger.addHandler(test_handler)
logger.setLevel(logging.WARNING)
req = self._get_GET_no_csrf_cookie_request()
ensure_csrf_cookie_view(req)
finally:
logger.removeHandler(test_handler)
logger.setLevel(old_log_level)
def test_post_data_read_failure(self): def test_post_data_read_failure(self):
""" """
@ -498,11 +484,11 @@ class CsrfViewMiddlewareTestMixin:
self.assertIsNone(resp) self.assertIsNone(resp)
req = CsrfPostRequest(token, raise_error=True) req = CsrfPostRequest(token, raise_error=True)
with patch_logger('django.security.csrf', 'warning') as logger_calls: self.mw.process_request(req)
self.mw.process_request(req) with self.assertLogs('django.security.csrf', 'WARNING') as cm:
resp = self.mw.process_view(req, post_form_view, (), {}) resp = self.mw.process_view(req, post_form_view, (), {})
self.assertEqual(resp.status_code, 403) self.assertEqual(resp.status_code, 403)
self.assertEqual(logger_calls[0], 'Forbidden (%s): ' % REASON_BAD_TOKEN) self.assertEqual(cm.records[0].getMessage(), 'Forbidden (%s): ' % REASON_BAD_TOKEN)
class CsrfViewMiddlewareTests(CsrfViewMiddlewareTestMixin, SimpleTestCase): class CsrfViewMiddlewareTests(CsrfViewMiddlewareTestMixin, SimpleTestCase):

View File

@ -1,7 +1,6 @@
from django.contrib.gis import admin from django.contrib.gis import admin
from django.contrib.gis.geos import Point from django.contrib.gis.geos import Point
from django.test import TestCase, override_settings from django.test import TestCase, override_settings
from django.test.utils import patch_logger
from .admin import UnmodifiableAdmin from .admin import UnmodifiableAdmin
from .models import City, site from .models import City, site
@ -73,28 +72,28 @@ class GeoAdminTest(TestCase):
def test_olwidget_empty_string(self): def test_olwidget_empty_string(self):
geoadmin = site._registry[City] geoadmin = site._registry[City]
form = geoadmin.get_changelist_form(None)({'point': ''}) form = geoadmin.get_changelist_form(None)({'point': ''})
with patch_logger('django.contrib.gis', 'error') as logger_calls: with self.assertRaisesMessage(AssertionError, 'no logs'):
output = str(form['point']) with self.assertLogs('django.contrib.gis', 'ERROR'):
output = str(form['point'])
self.assertInHTML( self.assertInHTML(
'<textarea id="id_point" class="vWKTField required" cols="150"' '<textarea id="id_point" class="vWKTField required" cols="150"'
' rows="10" name="point"></textarea>', ' rows="10" name="point"></textarea>',
output output
) )
self.assertEqual(logger_calls, [])
def test_olwidget_invalid_string(self): def test_olwidget_invalid_string(self):
geoadmin = site._registry[City] geoadmin = site._registry[City]
form = geoadmin.get_changelist_form(None)({'point': 'INVALID()'}) form = geoadmin.get_changelist_form(None)({'point': 'INVALID()'})
with patch_logger('django.contrib.gis', 'error') as logger_calls: with self.assertLogs('django.contrib.gis', 'ERROR') as cm:
output = str(form['point']) output = str(form['point'])
self.assertInHTML( self.assertInHTML(
'<textarea id="id_point" class="vWKTField required" cols="150"' '<textarea id="id_point" class="vWKTField required" cols="150"'
' rows="10" name="point"></textarea>', ' rows="10" name="point"></textarea>',
output output
) )
self.assertEqual(len(logger_calls), 1) self.assertEqual(len(cm.records), 1)
self.assertEqual( self.assertEqual(
logger_calls[0], cm.records[0].getMessage(),
"Error creating geometry from value 'INVALID()' (String input " "Error creating geometry from value 'INVALID()' (String input "
"unrecognized as WKT EWKT, and HEXEWKB.)" "unrecognized as WKT EWKT, and HEXEWKB.)"
) )

View File

@ -5,7 +5,6 @@ from django.contrib.gis.forms import BaseGeometryWidget, OpenLayersWidget
from django.contrib.gis.geos import GEOSGeometry from django.contrib.gis.geos import GEOSGeometry
from django.forms import ValidationError from django.forms import ValidationError
from django.test import SimpleTestCase, override_settings from django.test import SimpleTestCase, override_settings
from django.test.utils import patch_logger
from django.utils.html import escape from django.utils.html import escape
@ -120,7 +119,7 @@ class GeometryFieldTest(SimpleTestCase):
'pt3': 'PNT(0)', # invalid 'pt3': 'PNT(0)', # invalid
}) })
with patch_logger('django.contrib.gis', 'error') as logger_calls: with self.assertLogs('django.contrib.gis', 'ERROR') as logger_calls:
output = str(form) output = str(form)
# The first point can't use assertInHTML() due to non-deterministic # The first point can't use assertInHTML() due to non-deterministic
@ -142,9 +141,9 @@ class GeometryFieldTest(SimpleTestCase):
) )
# Only the invalid PNT(0) triggers an error log entry. # Only the invalid PNT(0) triggers an error log entry.
# Deserialization is called in form clean and in widget rendering. # Deserialization is called in form clean and in widget rendering.
self.assertEqual(len(logger_calls), 2) self.assertEqual(len(logger_calls.records), 2)
self.assertEqual( self.assertEqual(
logger_calls[0], logger_calls.records[0].getMessage(),
"Error creating geometry from value 'PNT(0)' (String input " "Error creating geometry from value 'PNT(0)' (String input "
"unrecognized as WKT EWKT, and HEXEWKB.)" "unrecognized as WKT EWKT, and HEXEWKB.)"
) )

View File

@ -1,7 +1,6 @@
from django.conf import settings from django.conf import settings
from django.core.exceptions import MiddlewareNotUsed from django.core.exceptions import MiddlewareNotUsed
from django.test import RequestFactory, SimpleTestCase, override_settings from django.test import RequestFactory, SimpleTestCase, override_settings
from django.test.utils import patch_logger
from . import middleware as mw from . import middleware as mw
@ -138,26 +137,24 @@ class MiddlewareNotUsedTests(SimpleTestCase):
@override_settings(MIDDLEWARE=['middleware_exceptions.tests.MyMiddleware']) @override_settings(MIDDLEWARE=['middleware_exceptions.tests.MyMiddleware'])
def test_log(self): def test_log(self):
with patch_logger('django.request', 'debug') as calls: with self.assertLogs('django.request', 'DEBUG') as cm:
self.client.get('/middleware_exceptions/view/') self.client.get('/middleware_exceptions/view/')
self.assertEqual(len(calls), 1)
self.assertEqual( self.assertEqual(
calls[0], cm.records[0].getMessage(),
"MiddlewareNotUsed: 'middleware_exceptions.tests.MyMiddleware'" "MiddlewareNotUsed: 'middleware_exceptions.tests.MyMiddleware'"
) )
@override_settings(MIDDLEWARE=['middleware_exceptions.tests.MyMiddlewareWithExceptionMessage']) @override_settings(MIDDLEWARE=['middleware_exceptions.tests.MyMiddlewareWithExceptionMessage'])
def test_log_custom_message(self): def test_log_custom_message(self):
with patch_logger('django.request', 'debug') as calls: with self.assertLogs('django.request', 'DEBUG') as cm:
self.client.get('/middleware_exceptions/view/') self.client.get('/middleware_exceptions/view/')
self.assertEqual(len(calls), 1)
self.assertEqual( self.assertEqual(
calls[0], cm.records[0].getMessage(),
"MiddlewareNotUsed('middleware_exceptions.tests.MyMiddlewareWithExceptionMessage'): spam eggs" "MiddlewareNotUsed('middleware_exceptions.tests.MyMiddlewareWithExceptionMessage'): spam eggs"
) )
@override_settings(DEBUG=False) @override_settings(DEBUG=False)
def test_do_not_log_when_debug_is_false(self): def test_do_not_log_when_debug_is_false(self):
with patch_logger('django.request', 'debug') as calls: with self.assertRaisesMessage(AssertionError, 'no logs'):
self.client.get('/middleware_exceptions/view/') with self.assertLogs('django.request', 'DEBUG'):
self.assertEqual(len(calls), 0) self.client.get('/middleware_exceptions/view/')

View File

@ -1,6 +1,5 @@
from django.db import connection from django.db import connection
from django.test import TestCase from django.test import TestCase
from django.test.utils import patch_logger
class SchemaLoggerTests(TestCase): class SchemaLoggerTests(TestCase):
@ -9,12 +8,11 @@ class SchemaLoggerTests(TestCase):
editor = connection.schema_editor(collect_sql=True) editor = connection.schema_editor(collect_sql=True)
sql = 'SELECT * FROM foo WHERE id in (%s, %s)' sql = 'SELECT * FROM foo WHERE id in (%s, %s)'
params = [42, 1337] params = [42, 1337]
with patch_logger('django.db.backends.schema', 'debug', log_kwargs=True) as logger: with self.assertLogs('django.db.backends.schema', 'DEBUG') as cm:
editor.execute(sql, params) editor.execute(sql, params)
self.assertEqual(cm.records[0].sql, sql)
self.assertEqual(cm.records[0].params, params)
self.assertEqual( self.assertEqual(
logger, cm.records[0].getMessage(),
[( 'SELECT * FROM foo WHERE id in (%s, %s); (params [42, 1337])',
'SELECT * FROM foo WHERE id in (%s, %s); (params [42, 1337])',
{'extra': {'sql': sql, 'params': params}},
)]
) )

View File

@ -22,7 +22,7 @@ from django.db.transaction import TransactionManagementError, atomic
from django.test import ( from django.test import (
TransactionTestCase, skipIfDBFeature, skipUnlessDBFeature, TransactionTestCase, skipIfDBFeature, skipUnlessDBFeature,
) )
from django.test.utils import CaptureQueriesContext, isolate_apps, patch_logger from django.test.utils import CaptureQueriesContext, isolate_apps
from django.utils import timezone from django.utils import timezone
from .fields import ( from .fields import (
@ -1573,11 +1573,11 @@ class SchemaTests(TransactionTestCase):
new_field = CharField(max_length=255, unique=True) new_field = CharField(max_length=255, unique=True)
new_field.model = Author new_field.model = Author
new_field.set_attributes_from_name('name') new_field.set_attributes_from_name('name')
with patch_logger('django.db.backends.schema', 'debug') as logger_calls: with self.assertLogs('django.db.backends.schema', 'DEBUG') as cm:
with connection.schema_editor() as editor: with connection.schema_editor() as editor:
editor.alter_field(Author, Author._meta.get_field('name'), new_field) editor.alter_field(Author, Author._meta.get_field('name'), new_field)
# One SQL statement is executed to alter the field. # One SQL statement is executed to alter the field.
self.assertEqual(len(logger_calls), 1) self.assertEqual(len(cm.records), 1)
@isolate_apps('schema') @isolate_apps('schema')
@unittest.skipIf(connection.vendor == 'sqlite', 'SQLite remakes the table on field alteration.') @unittest.skipIf(connection.vendor == 'sqlite', 'SQLite remakes the table on field alteration.')
@ -1606,11 +1606,11 @@ class SchemaTests(TransactionTestCase):
new_field = SlugField(max_length=75, unique=True) new_field = SlugField(max_length=75, unique=True)
new_field.model = Tag new_field.model = Tag
new_field.set_attributes_from_name('slug') new_field.set_attributes_from_name('slug')
with patch_logger('django.db.backends.schema', 'debug') as logger_calls: with self.assertLogs('django.db.backends.schema', 'DEBUG') as cm:
with connection.schema_editor() as editor: with connection.schema_editor() as editor:
editor.alter_field(Tag, Tag._meta.get_field('slug'), new_field) editor.alter_field(Tag, Tag._meta.get_field('slug'), new_field)
# One SQL statement is executed to alter the field. # One SQL statement is executed to alter the field.
self.assertEqual(len(logger_calls), 1) self.assertEqual(len(cm.records), 1)
# Ensure that the field is still unique. # Ensure that the field is still unique.
Tag.objects.create(title='foo', slug='foo') Tag.objects.create(title='foo', slug='foo')
with self.assertRaises(IntegrityError): with self.assertRaises(IntegrityError):

View File

@ -5,7 +5,6 @@ from django.core.handlers.wsgi import WSGIRequest
from django.core.servers.basehttp import WSGIRequestHandler from django.core.servers.basehttp import WSGIRequestHandler
from django.test import SimpleTestCase from django.test import SimpleTestCase
from django.test.client import RequestFactory from django.test.client import RequestFactory
from django.test.utils import patch_logger
class Stub: class Stub:
@ -34,22 +33,19 @@ class WSGIRequestHandlerTestCase(SimpleTestCase):
'error': [500, 503], 'error': [500, 503],
} }
def _log_level_code(level, status_code):
with patch_logger('django.server', level) as messages:
handler.log_message('GET %s %s', 'A', str(status_code))
return messages
for level, status_codes in level_status_codes.items(): for level, status_codes in level_status_codes.items():
for status_code in status_codes: for status_code in status_codes:
# The correct level gets the message. # The correct level gets the message.
messages = _log_level_code(level, status_code) with self.assertLogs('django.server', level.upper()) as cm:
self.assertIn('GET A %d' % status_code, messages[0]) handler.log_message('GET %s %s', 'A', str(status_code))
self.assertIn('GET A %d' % status_code, cm.output[0])
# Incorrect levels shouldn't have any messages. # Incorrect levels shouldn't have any messages.
for wrong_level in level_status_codes: for wrong_level in level_status_codes:
if wrong_level != level: if wrong_level != level:
messages = _log_level_code(wrong_level, status_code) with self.assertRaisesMessage(AssertionError, 'no logs'):
self.assertEqual(len(messages), 0) with self.assertLogs('django.template', level.upper()):
handler.log_message('GET %s %s', 'A', str(status_code))
finally: finally:
logger.handlers = original_handlers logger.handlers = original_handlers
@ -59,12 +55,12 @@ class WSGIRequestHandlerTestCase(SimpleTestCase):
handler = WSGIRequestHandler(request, '192.168.0.2', None) handler = WSGIRequestHandler(request, '192.168.0.2', None)
with patch_logger('django.server', 'error') as messages: with self.assertLogs('django.server', 'ERROR') as cm:
handler.log_message("GET %s %s", '\x16\x03', "4") handler.log_message("GET %s %s", '\x16\x03', "4")
self.assertIn( self.assertIn(
"You're accessing the development server over HTTPS, " "You're accessing the development server over HTTPS, "
"but it only supports HTTP.", "but it only supports HTTP.",
messages[0] cm.records[0].getMessage()
) )
def test_strips_underscore_headers(self): def test_strips_underscore_headers(self):
@ -107,8 +103,8 @@ class WSGIRequestHandlerTestCase(SimpleTestCase):
request = Stub(makefile=makefile) request = Stub(makefile=makefile)
server = Stub(base_environ={}, get_app=lambda: test_app) server = Stub(base_environ={}, get_app=lambda: test_app)
# We don't need to check stderr, but we don't want it in test output # Prevent logging from appearing in test output.
with patch_logger('django.server', 'info'): with self.assertLogs('django.server', 'INFO'):
# instantiating a handler runs the request as side effect # instantiating a handler runs the request as side effect
WSGIRequestHandler(request, '192.168.0.2', server) WSGIRequestHandler(request, '192.168.0.2', server)

View File

@ -32,7 +32,6 @@ from django.http import HttpResponse
from django.test import ( from django.test import (
RequestFactory, TestCase, ignore_warnings, override_settings, RequestFactory, TestCase, ignore_warnings, override_settings,
) )
from django.test.utils import patch_logger
from django.utils import timezone from django.utils import timezone
from .models import SessionStore as CustomDatabaseSession from .models import SessionStore as CustomDatabaseSession
@ -313,11 +312,10 @@ class SessionTestsMixin:
def test_decode_failure_logged_to_security(self): def test_decode_failure_logged_to_security(self):
bad_encode = base64.b64encode(b'flaskdj:alkdjf') bad_encode = base64.b64encode(b'flaskdj:alkdjf')
with patch_logger('django.security.SuspiciousSession', 'warning') as calls: with self.assertLogs('django.security.SuspiciousSession', 'WARNING') as cm:
self.assertEqual({}, self.session.decode(bad_encode)) self.assertEqual({}, self.session.decode(bad_encode))
# check that the failed decode is logged # The failed decode is logged.
self.assertEqual(len(calls), 1) self.assertIn('corrupted', cm.output[0])
self.assertIn('corrupted', calls[0])
def test_actual_expiry(self): def test_actual_expiry(self):
# this doesn't work with JSONSerializer (serializing timedelta) # this doesn't work with JSONSerializer (serializing timedelta)

View File

@ -5,13 +5,13 @@ from unittest import mock
from django import __version__ from django import __version__
from django.core.management import CommandError, call_command from django.core.management import CommandError, call_command
from django.test import SimpleTestCase from django.test import SimpleTestCase
from django.test.utils import captured_stdin, captured_stdout, patch_logger from django.test.utils import captured_stdin, captured_stdout
class ShellCommandTestCase(SimpleTestCase): class ShellCommandTestCase(SimpleTestCase):
def test_command_option(self): def test_command_option(self):
with patch_logger('test', 'info') as logger: with self.assertLogs('test', 'INFO') as cm:
call_command( call_command(
'shell', 'shell',
command=( command=(
@ -19,8 +19,7 @@ class ShellCommandTestCase(SimpleTestCase):
'getLogger("test").info(django.__version__)' 'getLogger("test").info(django.__version__)'
), ),
) )
self.assertEqual(len(logger), 1) self.assertEqual(cm.records[0].getMessage(), __version__)
self.assertEqual(logger[0], __version__)
@unittest.skipIf(sys.platform == 'win32', "Windows select() doesn't support file descriptors.") @unittest.skipIf(sys.platform == 'win32', "Windows select() doesn't support file descriptors.")
@mock.patch('django.core.management.commands.shell.select') @mock.patch('django.core.management.commands.shell.select')

View File

@ -14,7 +14,7 @@ from django.db import DatabaseError, connection
from django.shortcuts import render from django.shortcuts import render
from django.template import TemplateDoesNotExist from django.template import TemplateDoesNotExist
from django.test import RequestFactory, SimpleTestCase, override_settings from django.test import RequestFactory, SimpleTestCase, override_settings
from django.test.utils import LoggingCaptureMixin, patch_logger from django.test.utils import LoggingCaptureMixin
from django.urls import reverse from django.urls import reverse
from django.utils.functional import SimpleLazyObject from django.utils.functional import SimpleLazyObject
from django.utils.safestring import mark_safe from django.utils.safestring import mark_safe
@ -253,9 +253,8 @@ class NonDjangoTemplatesDebugViewTests(SimpleTestCase):
def test_400(self): def test_400(self):
# When DEBUG=True, technical_500_template() is called. # When DEBUG=True, technical_500_template() is called.
with patch_logger('django.security.SuspiciousOperation', 'error'): response = self.client.get('/raises400/')
response = self.client.get('/raises400/') self.assertContains(response, '<div class="context" id="', status_code=400)
self.assertContains(response, '<div class="context" id="', status_code=400)
def test_403(self): def test_403(self):
response = self.client.get('/raises403/') response = self.client.get('/raises403/')