Simplify FilterExpression.args_check

This commit is contained in:
Curtis Maloney 2013-09-04 20:53:55 +10:00 committed by Tim Graham
parent 6dca603abb
commit 7c6f2ddcd9
2 changed files with 47 additions and 25 deletions

View File

@ -622,34 +622,17 @@ class FilterExpression(object):
def args_check(name, func, provided): def args_check(name, func, provided):
provided = list(provided) provided = list(provided)
plen = len(provided) # First argument, filter input, is implied.
plen = len(provided) + 1
# Check to see if a decorator is providing the real function. # Check to see if a decorator is providing the real function.
func = getattr(func, '_decorated_function', func) func = getattr(func, '_decorated_function', func)
args, varargs, varkw, defaults = getargspec(func) args, varargs, varkw, defaults = getargspec(func)
# First argument is filter input. alen = len(args)
args.pop(0) dlen = len(defaults or [])
if defaults: # Not enough OR Too many
nondefs = args[:-len(defaults)] if plen < (alen - dlen) or plen > alen:
else:
nondefs = args
# Args without defaults must be provided.
try:
for arg in nondefs:
provided.pop(0)
except IndexError:
# Not enough
raise TemplateSyntaxError("%s requires %d arguments, %d provided" % raise TemplateSyntaxError("%s requires %d arguments, %d provided" %
(name, len(nondefs), plen)) (name, alen - dlen, plen))
# Defaults can be overridden.
defaults = list(defaults) if defaults else []
try:
for parg in provided:
defaults.pop(0)
except IndexError:
# Too many.
raise TemplateSyntaxError("%s requires %d arguments, %d provided" %
(name, len(nondefs), plen))
return True return True
args_check = staticmethod(args_check) args_check = staticmethod(args_check)

View File

@ -6,7 +6,7 @@ from __future__ import unicode_literals
from unittest import TestCase from unittest import TestCase
from django.template import (TokenParser, FilterExpression, Parser, Variable, from django.template import (TokenParser, FilterExpression, Parser, Variable,
Template, TemplateSyntaxError) Template, TemplateSyntaxError, Library)
from django.test.utils import override_settings from django.test.utils import override_settings
from django.utils import six from django.utils import six
@ -94,3 +94,42 @@ class ParserTests(TestCase):
with six.assertRaisesRegex(self, TemplateSyntaxError, msg) as cm: with six.assertRaisesRegex(self, TemplateSyntaxError, msg) as cm:
Template("{% if 1 %}{{ foo@bar }}{% endif %}") Template("{% if 1 %}{{ foo@bar }}{% endif %}")
self.assertEqual(cm.exception.django_template_source[1], (10, 23)) self.assertEqual(cm.exception.django_template_source[1], (10, 23))
def test_filter_args_count(self):
p = Parser("")
l = Library()
@l.filter
def no_arguments(value):
pass
@l.filter
def one_argument(value, arg):
pass
@l.filter
def one_opt_argument(value, arg=False):
pass
@l.filter
def two_arguments(value, arg, arg2):
pass
@l.filter
def two_one_opt_arg(value, arg, arg2=False):
pass
p.add_library(l)
for expr in (
'1|no_arguments:"1"',
'1|two_arguments',
'1|two_arguments:"1"',
'1|two_one_opt_arg',
):
with self.assertRaises(TemplateSyntaxError):
FilterExpression(expr, p)
for expr in (
# Correct number of arguments
'1|no_arguments',
'1|one_argument:"1"',
# One optional
'1|one_opt_argument',
'1|one_opt_argument:"1"',
# Not supplying all
'1|two_one_opt_arg:"1"',
):
FilterExpression(expr, p)