From a3e7d73ed7d90d31de46c60d40424267f62e411c Mon Sep 17 00:00:00 2001 From: Curtis Maloney Date: Tue, 16 Jul 2013 21:11:32 +1000 Subject: [PATCH] Allowed Context.push to behave as a context mananger. Thanks Loic Bistuer for the review. --- django/template/context.py | 23 ++++- django/template/defaulttags.py | 136 +++++++++++++-------------- django/template/loader.py | 5 +- django/template/loader_tags.py | 38 ++++---- docs/ref/templates/api.txt | 25 +++++ docs/releases/1.7.txt | 7 ++ tests/template_tests/test_context.py | 9 ++ 7 files changed, 145 insertions(+), 98 deletions(-) diff --git a/django/template/context.py b/django/template/context.py index 1ef7e889bc..3830c1c660 100644 --- a/django/template/context.py +++ b/django/template/context.py @@ -12,6 +12,21 @@ class ContextPopException(Exception): "pop() has been called more times than push()" pass + +class ContextDict(dict): + def __init__(self, context, *args, **kwargs): + super(ContextDict, self).__init__(*args, **kwargs) + + context.dicts.append(self) + self.context = context + + def __enter__(self): + return self + + def __exit__(self, *args, **kwargs): + self.context.pop() + + class BaseContext(object): def __init__(self, dict_=None): self._reset_dicts(dict_) @@ -34,10 +49,8 @@ class BaseContext(object): for d in reversed(self.dicts): yield d - def push(self): - d = {} - self.dicts.append(d) - return d + def push(self, *args, **kwargs): + return ContextDict(self, *args, **kwargs) def pop(self): if len(self.dicts) == 1: @@ -83,6 +96,7 @@ class BaseContext(object): new_context._reset_dicts(values) return new_context + class Context(BaseContext): "A stack container for variable context" def __init__(self, dict_=None, autoescape=True, current_app=None, @@ -106,6 +120,7 @@ class Context(BaseContext): self.dicts.append(other_dict) return other_dict + class RenderContext(BaseContext): """ A stack container for storing Template state. diff --git a/django/template/defaulttags.py b/django/template/defaulttags.py index 5b2a1b9501..5c9490f749 100644 --- a/django/template/defaulttags.py +++ b/django/template/defaulttags.py @@ -95,10 +95,9 @@ class FilterNode(Node): def render(self, context): output = self.nodelist.render(context) # Apply filters. - context.update({'var': output}) - filtered = self.filter_expr.resolve(context) - context.pop() - return filtered + with context.push(var=output): + return self.filter_expr.resolve(context) + class FirstOfNode(Node): def __init__(self, variables, escape=False): @@ -143,71 +142,69 @@ class ForNode(Node): parentloop = context['forloop'] else: parentloop = {} - context.push() - try: - values = self.sequence.resolve(context, True) - except VariableDoesNotExist: - values = [] - if values is None: - values = [] - if not hasattr(values, '__len__'): - values = list(values) - len_values = len(values) - if len_values < 1: - context.pop() - return self.nodelist_empty.render(context) - nodelist = NodeList() - if self.is_reversed: - values = reversed(values) - unpack = len(self.loopvars) > 1 - # Create a forloop value in the context. We'll update counters on each - # iteration just below. - loop_dict = context['forloop'] = {'parentloop': parentloop} - for i, item in enumerate(values): - # Shortcuts for current loop iteration number. - loop_dict['counter0'] = i - loop_dict['counter'] = i+1 - # Reverse counter iteration numbers. - loop_dict['revcounter'] = len_values - i - loop_dict['revcounter0'] = len_values - i - 1 - # Boolean values designating first and last times through loop. - loop_dict['first'] = (i == 0) - loop_dict['last'] = (i == len_values - 1) + with context.push(): + try: + values = self.sequence.resolve(context, True) + except VariableDoesNotExist: + values = [] + if values is None: + values = [] + if not hasattr(values, '__len__'): + values = list(values) + len_values = len(values) + if len_values < 1: + return self.nodelist_empty.render(context) + nodelist = NodeList() + if self.is_reversed: + values = reversed(values) + unpack = len(self.loopvars) > 1 + # Create a forloop value in the context. We'll update counters on each + # iteration just below. + loop_dict = context['forloop'] = {'parentloop': parentloop} + for i, item in enumerate(values): + # Shortcuts for current loop iteration number. + loop_dict['counter0'] = i + loop_dict['counter'] = i+1 + # Reverse counter iteration numbers. + loop_dict['revcounter'] = len_values - i + loop_dict['revcounter0'] = len_values - i - 1 + # Boolean values designating first and last times through loop. + loop_dict['first'] = (i == 0) + loop_dict['last'] = (i == len_values - 1) - pop_context = False - if unpack: - # If there are multiple loop variables, unpack the item into - # them. - try: - unpacked_vars = dict(zip(self.loopvars, item)) - except TypeError: - pass - else: - pop_context = True - context.update(unpacked_vars) - else: - context[self.loopvars[0]] = item - # In TEMPLATE_DEBUG mode provide source of the node which - # actually raised the exception - if settings.TEMPLATE_DEBUG: - for node in self.nodelist_loop: + pop_context = False + if unpack: + # If there are multiple loop variables, unpack the item into + # them. try: + unpacked_vars = dict(zip(self.loopvars, item)) + except TypeError: + pass + else: + pop_context = True + context.update(unpacked_vars) + else: + context[self.loopvars[0]] = item + # In TEMPLATE_DEBUG mode provide source of the node which + # actually raised the exception + if settings.TEMPLATE_DEBUG: + for node in self.nodelist_loop: + try: + nodelist.append(node.render(context)) + except Exception as e: + if not hasattr(e, 'django_template_source'): + e.django_template_source = node.source + raise + else: + for node in self.nodelist_loop: nodelist.append(node.render(context)) - except Exception as e: - if not hasattr(e, 'django_template_source'): - e.django_template_source = node.source - raise - else: - for node in self.nodelist_loop: - nodelist.append(node.render(context)) - if pop_context: - # The loop variables were pushed on to the context so pop them - # off again. This is necessary because the tag lets the length - # of loopvars differ to the length of each set of items and we - # don't want to leave any vars from the previous loop on the - # context. - context.pop() - context.pop() + if pop_context: + # The loop variables were pushed on to the context so pop them + # off again. This is necessary because the tag lets the length + # of loopvars differ to the length of each set of items and we + # don't want to leave any vars from the previous loop on the + # context. + context.pop() return nodelist.render(context) class IfChangedNode(Node): @@ -500,10 +497,9 @@ class WithNode(Node): def render(self, context): values = dict([(key, val.resolve(context)) for key, val in six.iteritems(self.extra_context)]) - context.update(values) - output = self.nodelist.render(context) - context.pop() - return output + with context.push(**values): + return self.nodelist.render(context) + @register.tag def autoescape(parser, token): diff --git a/django/template/loader.py b/django/template/loader.py index 6df4e43c4f..44b8f600fb 100644 --- a/django/template/loader.py +++ b/django/template/loader.py @@ -164,11 +164,8 @@ def render_to_string(template_name, dictionary=None, context_instance=None): return t.render(Context(dictionary)) # Add the dictionary to the context stack, ensuring it gets removed again # to keep the context_instance in the same state it started in. - context_instance.update(dictionary) - try: + with context_instance.push(dictionary): return t.render(context_instance) - finally: - context_instance.pop() def select_template(template_name_list): "Given a list of template names, returns the first that can be loaded." diff --git a/django/template/loader_tags.py b/django/template/loader_tags.py index 767f0e5ff8..406775da9d 100644 --- a/django/template/loader_tags.py +++ b/django/template/loader_tags.py @@ -47,22 +47,21 @@ class BlockNode(Node): def render(self, context): block_context = context.render_context.get(BLOCK_CONTEXT_KEY) - context.push() - if block_context is None: - context['block'] = self - result = self.nodelist.render(context) - else: - push = block = block_context.pop(self.name) - if block is None: - block = self - # Create new block so we can store context without thread-safety issues. - block = BlockNode(block.name, block.nodelist) - block.context = context - context['block'] = block - result = block.nodelist.render(context) - if push is not None: - block_context.push(self.name, push) - context.pop() + with context.push(): + if block_context is None: + context['block'] = self + result = self.nodelist.render(context) + else: + push = block = block_context.pop(self.name) + if block is None: + block = self + # Create new block so we can store context without thread-safety issues. + block = BlockNode(block.name, block.nodelist) + block.context = context + context['block'] = block + result = block.nodelist.render(context) + if push is not None: + block_context.push(self.name, push) return result def super(self): @@ -133,10 +132,9 @@ class BaseIncludeNode(Node): in six.iteritems(self.extra_context)]) if self.isolated_context: return template.render(context.new(values)) - context.update(values) - output = template.render(context) - context.pop() - return output + with context.push(**values): + return template.render(context) + class ConstantIncludeNode(BaseIncludeNode): def __init__(self, template_path, *args, **kwargs): diff --git a/docs/ref/templates/api.txt b/docs/ref/templates/api.txt index 6a9efc0811..f7dd0121d1 100644 --- a/docs/ref/templates/api.txt +++ b/docs/ref/templates/api.txt @@ -325,6 +325,31 @@ If you ``pop()`` too much, it'll raise ... django.template.ContextPopException +.. versionadded:: 1.7 + +You can also use ``push()`` as a context manager to ensure a matching ``pop()`` +is called. + + >>> c = Context() + >>> c['foo'] = 'first level' + >>> with c.push(): + >>> c['foo'] = 'second level' + >>> c['foo'] + 'second level' + >>> c['foo'] + 'first level' + +All arguments passed to ``push()`` will be passed to the ``dict`` constructor +used to build the new context level. + + >>> c = Context() + >>> c['foo'] = 'first level' + >>> with c.push(foo='second level'): + >>> c['foo'] + 'second level' + >>> c['foo'] + 'first level' + .. method:: update(other_dict) In addition to ``push()`` and ``pop()``, the ``Context`` diff --git a/docs/releases/1.7.txt b/docs/releases/1.7.txt index ae2a80d7a3..f551828455 100644 --- a/docs/releases/1.7.txt +++ b/docs/releases/1.7.txt @@ -60,6 +60,13 @@ Minor features * :attr:`~django.db.models.Options.app_label` is no longer required for models that are defined in a ``models`` package within an app. +* The :meth:`Context.push() ` method now returns + a context manager which automatically calls :meth:`pop() + ` upon exiting the ``with`` statement. + Additionally, :meth:`push() ` now accepts + parameters that are passed to the ``dict`` constructor used to build the new + context level. + Backwards incompatible changes in 1.7 ===================================== diff --git a/tests/template_tests/test_context.py b/tests/template_tests/test_context.py index 224b94d060..ca167a73f3 100644 --- a/tests/template_tests/test_context.py +++ b/tests/template_tests/test_context.py @@ -16,3 +16,12 @@ class ContextTests(TestCase): self.assertEqual(c.pop(), {"a": 2}) self.assertEqual(c["a"], 1) self.assertEqual(c.get("foo", 42), 42) + + with c.push(): + c['a'] = 2 + self.assertEqual(c['a'], 2) + self.assertEqual(c['a'], 1) + + with c.push(a=3): + self.assertEqual(c['a'], 3) + self.assertEqual(c['a'], 1)