From 6e228d0b659c9779da7d392a75c7d2fbb42fb3cb Mon Sep 17 00:00:00 2001 From: Mariusz Felisiak Date: Thu, 29 Jun 2017 18:25:36 +0200 Subject: [PATCH] Fixed #28277 -- Added validation of QuerySet.annotate() and aggregate() args. Thanks Tim Graham and Nick Pope for reviews. --- django/db/models/query.py | 13 +++++++++++++ tests/aggregation/tests.py | 9 +++++++++ tests/annotations/tests.py | 9 +++++++++ 3 files changed, 31 insertions(+) diff --git a/django/db/models/query.py b/django/db/models/query.py index b71e9a87d5..38f69f22d1 100644 --- a/django/db/models/query.py +++ b/django/db/models/query.py @@ -323,6 +323,7 @@ class QuerySet: """ if self.query.distinct_fields: raise NotImplementedError("aggregate() + distinct(fields) not implemented.") + self._validate_values_are_expressions(args + tuple(kwargs.values()), method_name='aggregate') for arg in args: # The default_alias property raises TypeError if default_alias # can't be set automatically or AttributeError if it isn't an @@ -895,6 +896,7 @@ class QuerySet: Return a query set in which the returned objects have been annotated with extra data or aggregations. """ + self._validate_values_are_expressions(args + tuple(kwargs.values()), method_name='annotate') annotations = OrderedDict() # To preserve ordering of args for arg in args: # The default_alias property may raise a TypeError. @@ -1145,6 +1147,17 @@ class QuerySet: """ return self.query.has_filters() + @staticmethod + def _validate_values_are_expressions(values, method_name): + invalid_args = sorted(str(arg) for arg in values if not hasattr(arg, 'resolve_expression')) + if invalid_args: + raise TypeError( + 'QuerySet.%s() received non-expression(s): %s.' % ( + method_name, + ', '.join(invalid_args), + ) + ) + class InstanceCheckMeta(type): def __instancecheck__(self, instance): diff --git a/tests/aggregation/tests.py b/tests/aggregation/tests.py index 8cf0596347..c3f7998f60 100644 --- a/tests/aggregation/tests.py +++ b/tests/aggregation/tests.py @@ -1186,3 +1186,12 @@ class AggregateTestCase(TestCase): ).filter(rating_or_num_awards__gt=F('num_awards')).order_by('num_awards') self.assertQuerysetEqual( qs2, [1, 3], lambda v: v.num_awards) + + def test_arguments_must_be_expressions(self): + msg = 'QuerySet.aggregate() received non-expression(s): %s.' + with self.assertRaisesMessage(TypeError, msg % FloatField()): + Book.objects.aggregate(FloatField()) + with self.assertRaisesMessage(TypeError, msg % True): + Book.objects.aggregate(is_book=True) + with self.assertRaisesMessage(TypeError, msg % ', '.join([str(FloatField()), 'True'])): + Book.objects.aggregate(FloatField(), Avg('price'), is_book=True) diff --git a/tests/annotations/tests.py b/tests/annotations/tests.py index 981e73e43e..1714076e1d 100644 --- a/tests/annotations/tests.py +++ b/tests/annotations/tests.py @@ -508,3 +508,12 @@ class NonAggregateAnnotationTestCase(TestCase): self.assertIs(book.is_book, True) self.assertIs(book.is_pony, False) self.assertIsNone(book.is_none) + + def test_arguments_must_be_expressions(self): + msg = 'QuerySet.annotate() received non-expression(s): %s.' + with self.assertRaisesMessage(TypeError, msg % BooleanField()): + Book.objects.annotate(BooleanField()) + with self.assertRaisesMessage(TypeError, msg % True): + Book.objects.annotate(is_book=True) + with self.assertRaisesMessage(TypeError, msg % ', '.join([str(BooleanField()), 'True'])): + Book.objects.annotate(BooleanField(), Value(False), is_book=True)