[1.8.x] Fixed #26071 -- Fixed crash with __in lookup in a Case expression.

Partial backport of afe0bb7b13 from master.
This commit is contained in:
Alberto Avila 2016-01-08 14:20:15 -06:00 committed by Tim Graham
parent e625859f08
commit 5b3c66d8b6
4 changed files with 36 additions and 4 deletions

View File

@ -92,6 +92,10 @@ class Transform(RegisterLookupMixin):
bilateral_transforms.append((self.__class__, self.init_lookups))
return bilateral_transforms
@cached_property
def contains_aggregate(self):
return self.lhs.contains_aggregate
class Lookup(RegisterLookupMixin):
lookup_name = None
@ -194,6 +198,10 @@ class Lookup(RegisterLookupMixin):
def as_sql(self, compiler, connection):
raise NotImplementedError
@cached_property
def contains_aggregate(self):
return self.lhs.contains_aggregate or getattr(self.rhs, 'contains_aggregate', False)
class BuiltinLookup(Lookup):
def process_lhs(self, compiler, connection, lhs=None):

View File

@ -315,9 +315,9 @@ class WhereNode(tree.Node):
@classmethod
def _contains_aggregate(cls, obj):
if not isinstance(obj, tree.Node):
return getattr(obj.lhs, 'contains_aggregate', False) or getattr(obj.rhs, 'contains_aggregate', False)
if isinstance(obj, tree.Node):
return any(cls._contains_aggregate(c) for c in obj.children)
return obj.contains_aggregate
@cached_property
def contains_aggregate(self):
@ -336,6 +336,7 @@ class EverythingNode(object):
"""
A node that matches everything.
"""
contains_aggregate = False
def as_sql(self, compiler=None, connection=None):
return '', []
@ -345,11 +346,16 @@ class NothingNode(object):
"""
A node that matches nothing.
"""
contains_aggregate = False
def as_sql(self, compiler=None, connection=None):
raise EmptyResultSet
class ExtraWhere(object):
# The contents are a black box - assume no aggregates are used.
contains_aggregate = False
def __init__(self, sqls, params):
self.sqls = sqls
self.params = params
@ -410,6 +416,10 @@ class Constraint(object):
class SubqueryConstraint(object):
# Even if aggregates would be used in a subquery, the outer query isn't
# interested about those.
contains_aggregate = False
def __init__(self, alias, columns, targets, query_object):
self.alias = alias
self.columns = columns

View File

@ -23,3 +23,6 @@ Bugfixes
``db_index=True`` or ``unique=True`` to a ``CharField`` or ``TextField`` that
already had the other specified, or when removing one of them from a field
that had both (:ticket:`26034`).
* Fixed a crash when using an ``__in`` lookup inside a ``Case`` expression
(:ticket:`26071`).

View File

@ -8,7 +8,7 @@ from uuid import UUID
from django.core.exceptions import FieldError
from django.db import connection, models
from django.db.models import F, Q, Max, Min, Value
from django.db.models import F, Q, Max, Min, Sum, Value
from django.db.models.expressions import Case, When
from django.test import TestCase
from django.utils import six
@ -119,6 +119,17 @@ class CaseExpressionTests(TestCase):
transform=attrgetter('integer', 'join_test')
)
def test_annotate_with_in_clause(self):
fk_rels = FKCaseTestModel.objects.filter(integer__in=[5])
self.assertQuerysetEqual(
CaseTestModel.objects.only('pk', 'integer').annotate(in_test=Sum(Case(
When(fk_rel__in=fk_rels, then=F('fk_rel__integer')),
default=Value(0),
))).order_by('pk'),
[(1, 0), (2, 0), (3, 0), (2, 0), (3, 0), (3, 0), (4, 5)],
transform=attrgetter('integer', 'in_test')
)
def test_annotate_with_join_in_condition(self):
self.assertQuerysetEqual(
CaseTestModel.objects.annotate(join_test=Case(