diff --git a/django/test/utils.py b/django/test/utils.py index 2672594916..bf1dc4f165 100644 --- a/django/test/utils.py +++ b/django/test/utils.py @@ -6,12 +6,15 @@ from django.conf import settings from django.core import mail from django.core.mail.backends import locmem from django.test import signals -from django.template import Template +from django.template import Template, loader, TemplateDoesNotExist +from django.template.loaders import cached from django.utils.translation import deactivate __all__ = ('Approximate', 'ContextList', 'setup_test_environment', 'teardown_test_environment', 'get_runner') +RESTORE_LOADERS_ATTR = '_original_template_source_loaders' + class Approximate(object): def __init__(self, val, places=7): @@ -119,3 +122,41 @@ def get_runner(settings): test_module = __import__(test_module_name, {}, {}, test_path[-1]) test_runner = getattr(test_module, test_path[-1]) return test_runner + + +def setup_test_template_loader(templates_dict, use_cached_loader=False): + """ + Changes Django to only find templates from within a dictionary (where each + key is the template name and each value is the corresponding template + content to return). + + Use meth:`restore_template_loaders` to restore the original loaders. + """ + if hasattr(loader, RESTORE_LOADERS_ATTR): + raise Exception("loader.%s already exists" % RESTORE_LOADERS_ATTR) + + def test_template_loader(template_name, template_dirs=None): + "A custom template loader that loads templates from a dictionary." + try: + return (templates_dict[template_name], "test:%s" % template_name) + except KeyError: + raise TemplateDoesNotExist(template_name) + + if use_cached_loader: + template_loader = cached.Loader(('test_template_loader',)) + template_loader._cached_loaders = (test_template_loader,) + else: + template_loader = test_template_loader + + setattr(loader, RESTORE_LOADERS_ATTR, loader.template_source_loaders) + loader.template_source_loaders = (template_loader,) + return template_loader + + +def restore_template_loaders(): + """ + Restores the original template loaders after + :meth:`setup_test_template_loader` has been run. + """ + loader.template_source_loaders = getattr(loader, RESTORE_LOADERS_ATTR) + delattr(loader, RESTORE_LOADERS_ATTR) diff --git a/tests/regressiontests/templates/tests.py b/tests/regressiontests/templates/tests.py index e8881fc83b..a051d7335e 100644 --- a/tests/regressiontests/templates/tests.py +++ b/tests/regressiontests/templates/tests.py @@ -18,7 +18,8 @@ from django.template import base as template_base from django.core import urlresolvers from django.template import loader from django.template.loaders import app_directories, filesystem, cached -from django.test.utils import get_warnings_state, restore_warnings_state +from django.test.utils import get_warnings_state, restore_warnings_state,\ + setup_test_template_loader, restore_template_loaders from django.utils import unittest from django.utils.translation import activate, deactivate, ugettext as _ from django.utils.safestring import mark_safe @@ -390,19 +391,10 @@ class Templates(unittest.TestCase): template_tests.update(filter_tests) - # Register our custom template loader. - def test_template_loader(template_name, template_dirs=None): - "A custom template loader that loads the unit-test templates." - try: - return (template_tests[template_name][0] , "test:%s" % template_name) - except KeyError: - raise template.TemplateDoesNotExist(template_name) - - cache_loader = cached.Loader(('test_template_loader',)) - cache_loader._cached_loaders = (test_template_loader,) - - old_template_loaders = loader.template_source_loaders - loader.template_source_loaders = [cache_loader] + cache_loader = setup_test_template_loader( + dict([(name, t[0]) for name, t in template_tests.iteritems()]), + use_cached_loader=True, + ) failures = [] tests = template_tests.items() @@ -490,7 +482,7 @@ class Templates(unittest.TestCase): expected_invalid_str = 'INVALID' template_base.invalid_var_format_string = False - loader.template_source_loaders = old_template_loaders + restore_template_loaders() deactivate() settings.TEMPLATE_DEBUG = old_td settings.TEMPLATE_STRING_IF_INVALID = old_invalid