diff --git a/django/template/__init__.py b/django/template/__init__.py index 4cf3304eb6..abd46955ad 100644 --- a/django/template/__init__.py +++ b/django/template/__init__.py @@ -60,6 +60,8 @@ from django.conf import settings from django.template.context import Context, RequestContext, ContextPopException from django.utils.functional import curry from django.utils.text import smart_split +from django.dispatch import dispatcher +from django.template import signals __all__ = ('Template', 'Context', 'RequestContext', 'compile_string') @@ -137,13 +139,14 @@ class StringOrigin(Origin): return self.source class Template(object): - def __init__(self, template_string, origin=None): + def __init__(self, template_string, origin=None, name=''): "Compilation stage" if settings.TEMPLATE_DEBUG and origin == None: origin = StringOrigin(template_string) # Could do some crazy stack-frame stuff to record where this string # came from... self.nodelist = compile_string(template_string, origin) + self.name = name def __iter__(self): for node in self.nodelist: @@ -152,6 +155,7 @@ class Template(object): def render(self, context): "Display stage -- can be called many times" + dispatcher.send(signal=signals.template_rendered, sender=self, template=self, context=context) return self.nodelist.render(context) def compile_string(template_string, origin): diff --git a/django/template/defaulttags.py b/django/template/defaulttags.py index 0a4fe33d82..d111ddca89 100644 --- a/django/template/defaulttags.py +++ b/django/template/defaulttags.py @@ -251,7 +251,7 @@ class SsiNode(Node): output = '' if self.parsed: try: - t = Template(output) + t = Template(output, name=self.filepath) return t.render(context) except TemplateSyntaxError, e: if settings.DEBUG: diff --git a/django/template/loader.py b/django/template/loader.py index 60f24554f1..03e6f8d49d 100644 --- a/django/template/loader.py +++ b/django/template/loader.py @@ -76,14 +76,16 @@ def get_template(template_name): Returns a compiled Template object for the given template name, handling template inheritance recursively. """ - return get_template_from_string(*find_template_source(template_name)) + source, origin = find_template_source(template_name) + template = get_template_from_string(source, origin, template_name) + return template -def get_template_from_string(source, origin=None): +def get_template_from_string(source, origin=None, name=None): """ Returns a compiled Template object for the given template code, handling template inheritance recursively. """ - return Template(source, origin) + return Template(source, origin, name) def render_to_string(template_name, dictionary=None, context_instance=None): """ diff --git a/django/template/loader_tags.py b/django/template/loader_tags.py index 7f22f207b6..b20c318f1f 100644 --- a/django/template/loader_tags.py +++ b/django/template/loader_tags.py @@ -57,7 +57,7 @@ class ExtendsNode(Node): except TemplateDoesNotExist: raise TemplateSyntaxError, "Template %r cannot be extended, because it doesn't exist" % parent else: - return get_template_from_string(source, origin) + return get_template_from_string(source, origin, parent) def render(self, context): compiled_parent = self.get_parent(context) diff --git a/django/test/client.py b/django/test/client.py index 871f6cfb9b..3dfe764a38 100644 --- a/django/test/client.py +++ b/django/test/client.py @@ -1,10 +1,9 @@ from cStringIO import StringIO -from django.contrib.admin.views.decorators import LOGIN_FORM_KEY, _encode_post_data from django.core.handlers.base import BaseHandler from django.core.handlers.wsgi import WSGIRequest from django.dispatch import dispatcher from django.http import urlencode, SimpleCookie -from django.template import signals +from django.test import signals from django.utils.functional import curry class ClientHandler(BaseHandler): @@ -96,7 +95,7 @@ class Client: HTML rendered to the end-user. """ def __init__(self, **defaults): - self.handler = TestHandler() + self.handler = ClientHandler() self.defaults = defaults self.cookie = SimpleCookie() @@ -126,7 +125,7 @@ class Client: data = {} on_template_render = curry(store_rendered_templates, data) dispatcher.connect(on_template_render, signal=signals.template_rendered) - + response = self.handler(environ) # Add any rendered template detail to the response @@ -180,29 +179,38 @@ class Client: def login(self, path, username, password, **extra): """ A specialized sequence of GET and POST to log into a view that - is protected by @login_required or a similar access decorator. + is protected by a @login_required access decorator. - path should be the URL of the login page, or of any page that - is login protected. + path should be the URL of the page that is login protected. - Returns True if login was successful; False if otherwise. + Returns the response from GETting the requested URL after + login is complete. Returns False if login process failed. """ - # First, GET the login page. - # This is required to establish the session. + # First, GET the page that is login protected. + # This page will redirect to the login page. response = self.get(path) + if response.status_code != 302: + return False + + login_path, data = response['Location'].split('?') + next = data.split('=')[1] + + # Second, GET the login page; required to set up cookies + response = self.get(login_path, **extra) if response.status_code != 200: return False - - # Set up the block of form data required by the login page. + + # Last, POST the login data. form_data = { 'username': username, 'password': password, - 'this_is_the_login_form': 1, - 'post_data': _encode_post_data({LOGIN_FORM_KEY: 1}) + 'next' : next, } - response = self.post(path, data=form_data, **extra) - - # login page should give response 200 (if you requested the login - # page specifically), or 302 (if you requested a login - # protected page, to which the login can redirect). - return response.status_code in (200,302) + response = self.post(login_path, data=form_data, **extra) + + # Login page should 302 redirect to the originally requested page + if response.status_code != 302 or response['Location'] != path: + return False + + # Since we are logged in, request the actual page again + return self.get(path) diff --git a/django/test/signals.py b/django/test/signals.py new file mode 100644 index 0000000000..40748ff4fe --- /dev/null +++ b/django/test/signals.py @@ -0,0 +1 @@ +template_rendered = object() \ No newline at end of file diff --git a/django/test/simple.py b/django/test/simple.py index 2469f80b3a..043787414e 100644 --- a/django/test/simple.py +++ b/django/test/simple.py @@ -1,6 +1,7 @@ import unittest, doctest from django.conf import settings from django.core import management +from django.test.utils import setup_test_environment, teardown_test_environment from django.test.utils import create_test_db, destroy_test_db from django.test.testcases import OutputChecker, DocTestRunner @@ -51,6 +52,7 @@ def run_tests(module_list, verbosity=1, extra_tests=[]): the module. A list of 'extra' tests may also be provided; these tests will be added to the test suite. """ + setup_test_environment() settings.DEBUG = False suite = unittest.TestSuite() @@ -66,3 +68,5 @@ def run_tests(module_list, verbosity=1, extra_tests=[]): management.syncdb(verbosity, interactive=False) unittest.TextTestRunner(verbosity=verbosity).run(suite) destroy_test_db(old_name, verbosity) + + teardown_test_environment() diff --git a/django/test/utils.py b/django/test/utils.py index 6d6827918d..1bb448f41b 100644 --- a/django/test/utils.py +++ b/django/test/utils.py @@ -1,11 +1,40 @@ import sys, time from django.conf import settings from django.db import connection, transaction, backend +from django.dispatch import dispatcher +from django.test import signals +from django.template import Template # The prefix to put on the default database name when creating # the test database. TEST_DATABASE_PREFIX = 'test_' +def instrumented_test_render(self, context): + """An instrumented Template render method, providing a signal + that can be intercepted by the test system Client + + """ + dispatcher.send(signal=signals.template_rendered, sender=self, template=self, context=context) + return self.nodelist.render(context) + +def setup_test_environment(): + """Perform any global pre-test setup. This involves: + + - Installing the instrumented test renderer + + """ + Template.original_render = Template.render + Template.render = instrumented_test_render + +def teardown_test_environment(): + """Perform any global post-test teardown. This involves: + + - Restoring the original test renderer + + """ + Template.render = Template.original_render + del Template.original_render + def _set_autocommit(connection): "Make sure a connection is in autocommit mode." if hasattr(connection.connection, "autocommit"): @@ -75,4 +104,3 @@ def destroy_test_db(old_database_name, verbosity=1): time.sleep(1) # To avoid "database is being accessed by other users" errors. cursor.execute("DROP DATABASE %s" % backend.quote_name(TEST_DATABASE_NAME)) connection.close() - diff --git a/django/views/debug.py b/django/views/debug.py index 6934360afd..f15eb8ff01 100644 --- a/django/views/debug.py +++ b/django/views/debug.py @@ -115,7 +115,7 @@ def technical_500_response(request, exc_type, exc_value, tb): 'function': '?', 'lineno': '?', }] - t = Template(TECHNICAL_500_TEMPLATE) + t = Template(TECHNICAL_500_TEMPLATE, name='Technical 500 Template') c = Context({ 'exception_type': exc_type.__name__, 'exception_value': exc_value, @@ -141,7 +141,7 @@ def technical_404_response(request, exception): # tried exists but is an empty list. The URLconf must've been empty. return empty_urlconf(request) - t = Template(TECHNICAL_404_TEMPLATE) + t = Template(TECHNICAL_404_TEMPLATE, name='Technical 404 Template') c = Context({ 'root_urlconf': settings.ROOT_URLCONF, 'urlpatterns': tried, @@ -154,7 +154,7 @@ def technical_404_response(request, exception): def empty_urlconf(request): "Create an empty URLconf 404 error response." - t = Template(EMPTY_URLCONF_TEMPLATE) + t = Template(EMPTY_URLCONF_TEMPLATE, name='Empty URLConf Template') c = Context({ 'project_name': settings.SETTINGS_MODULE.split('.')[0] }) diff --git a/django/views/static.py b/django/views/static.py index ac323944d0..a8c8328014 100644 --- a/django/views/static.py +++ b/django/views/static.py @@ -81,7 +81,7 @@ def directory_index(path, fullpath): try: t = loader.get_template('static/directory_index') except TemplateDoesNotExist: - t = Template(DEFAULT_DIRECTORY_INDEX_TEMPLATE) + t = Template(DEFAULT_DIRECTORY_INDEX_TEMPLATE, name='Default Directory Index Template') files = [] for f in os.listdir(fullpath): if not f.startswith('.'):