diff --git a/django/db/models/lookups.py b/django/db/models/lookups.py index d0b30ace1d..8428d38d65 100644 --- a/django/db/models/lookups.py +++ b/django/db/models/lookups.py @@ -92,6 +92,10 @@ class Transform(RegisterLookupMixin): bilateral_transforms.append((self.__class__, self.init_lookups)) return bilateral_transforms + @cached_property + def contains_aggregate(self): + return self.lhs.contains_aggregate + class Lookup(RegisterLookupMixin): lookup_name = None @@ -194,6 +198,10 @@ class Lookup(RegisterLookupMixin): def as_sql(self, compiler, connection): raise NotImplementedError + @cached_property + def contains_aggregate(self): + return self.lhs.contains_aggregate or getattr(self.rhs, 'contains_aggregate', False) + class BuiltinLookup(Lookup): def process_lhs(self, compiler, connection, lhs=None): diff --git a/django/db/models/sql/where.py b/django/db/models/sql/where.py index 2ba6ceac33..8cad3df9a5 100644 --- a/django/db/models/sql/where.py +++ b/django/db/models/sql/where.py @@ -315,9 +315,9 @@ class WhereNode(tree.Node): @classmethod def _contains_aggregate(cls, obj): - if not isinstance(obj, tree.Node): - return getattr(obj.lhs, 'contains_aggregate', False) or getattr(obj.rhs, 'contains_aggregate', False) - return any(cls._contains_aggregate(c) for c in obj.children) + if isinstance(obj, tree.Node): + return any(cls._contains_aggregate(c) for c in obj.children) + return obj.contains_aggregate @cached_property def contains_aggregate(self): @@ -336,6 +336,7 @@ class EverythingNode(object): """ A node that matches everything. """ + contains_aggregate = False def as_sql(self, compiler=None, connection=None): return '', [] @@ -345,11 +346,16 @@ class NothingNode(object): """ A node that matches nothing. """ + contains_aggregate = False + def as_sql(self, compiler=None, connection=None): raise EmptyResultSet class ExtraWhere(object): + # The contents are a black box - assume no aggregates are used. + contains_aggregate = False + def __init__(self, sqls, params): self.sqls = sqls self.params = params @@ -410,6 +416,10 @@ class Constraint(object): class SubqueryConstraint(object): + # Even if aggregates would be used in a subquery, the outer query isn't + # interested about those. + contains_aggregate = False + def __init__(self, alias, columns, targets, query_object): self.alias = alias self.columns = columns diff --git a/docs/releases/1.8.9.txt b/docs/releases/1.8.9.txt index 080d6506e8..8c5445b549 100644 --- a/docs/releases/1.8.9.txt +++ b/docs/releases/1.8.9.txt @@ -23,3 +23,6 @@ Bugfixes ``db_index=True`` or ``unique=True`` to a ``CharField`` or ``TextField`` that already had the other specified, or when removing one of them from a field that had both (:ticket:`26034`). + +* Fixed a crash when using an ``__in`` lookup inside a ``Case`` expression + (:ticket:`26071`). diff --git a/tests/expressions_case/tests.py b/tests/expressions_case/tests.py index 27aef931c0..de1859b368 100644 --- a/tests/expressions_case/tests.py +++ b/tests/expressions_case/tests.py @@ -8,7 +8,7 @@ from uuid import UUID from django.core.exceptions import FieldError from django.db import connection, models -from django.db.models import F, Q, Max, Min, Value +from django.db.models import F, Q, Max, Min, Sum, Value from django.db.models.expressions import Case, When from django.test import TestCase from django.utils import six @@ -119,6 +119,17 @@ class CaseExpressionTests(TestCase): transform=attrgetter('integer', 'join_test') ) + def test_annotate_with_in_clause(self): + fk_rels = FKCaseTestModel.objects.filter(integer__in=[5]) + self.assertQuerysetEqual( + CaseTestModel.objects.only('pk', 'integer').annotate(in_test=Sum(Case( + When(fk_rel__in=fk_rels, then=F('fk_rel__integer')), + default=Value(0), + ))).order_by('pk'), + [(1, 0), (2, 0), (3, 0), (2, 0), (3, 0), (3, 0), (4, 5)], + transform=attrgetter('integer', 'in_test') + ) + def test_annotate_with_join_in_condition(self): self.assertQuerysetEqual( CaseTestModel.objects.annotate(join_test=Case(