[1.8.x] Fixed #26071 -- Fixed crash with __in lookup in a Case expression.

Partial backport of afe0bb7b13 from master.
This commit is contained in:
Alberto Avila 2016-01-08 14:20:15 -06:00 committed by Tim Graham
parent e625859f08
commit 5b3c66d8b6
4 changed files with 36 additions and 4 deletions

View File

@ -92,6 +92,10 @@ class Transform(RegisterLookupMixin):
bilateral_transforms.append((self.__class__, self.init_lookups)) bilateral_transforms.append((self.__class__, self.init_lookups))
return bilateral_transforms return bilateral_transforms
@cached_property
def contains_aggregate(self):
return self.lhs.contains_aggregate
class Lookup(RegisterLookupMixin): class Lookup(RegisterLookupMixin):
lookup_name = None lookup_name = None
@ -194,6 +198,10 @@ class Lookup(RegisterLookupMixin):
def as_sql(self, compiler, connection): def as_sql(self, compiler, connection):
raise NotImplementedError raise NotImplementedError
@cached_property
def contains_aggregate(self):
return self.lhs.contains_aggregate or getattr(self.rhs, 'contains_aggregate', False)
class BuiltinLookup(Lookup): class BuiltinLookup(Lookup):
def process_lhs(self, compiler, connection, lhs=None): def process_lhs(self, compiler, connection, lhs=None):

View File

@ -315,9 +315,9 @@ class WhereNode(tree.Node):
@classmethod @classmethod
def _contains_aggregate(cls, obj): def _contains_aggregate(cls, obj):
if not isinstance(obj, tree.Node): if isinstance(obj, tree.Node):
return getattr(obj.lhs, 'contains_aggregate', False) or getattr(obj.rhs, 'contains_aggregate', False) return any(cls._contains_aggregate(c) for c in obj.children)
return any(cls._contains_aggregate(c) for c in obj.children) return obj.contains_aggregate
@cached_property @cached_property
def contains_aggregate(self): def contains_aggregate(self):
@ -336,6 +336,7 @@ class EverythingNode(object):
""" """
A node that matches everything. A node that matches everything.
""" """
contains_aggregate = False
def as_sql(self, compiler=None, connection=None): def as_sql(self, compiler=None, connection=None):
return '', [] return '', []
@ -345,11 +346,16 @@ class NothingNode(object):
""" """
A node that matches nothing. A node that matches nothing.
""" """
contains_aggregate = False
def as_sql(self, compiler=None, connection=None): def as_sql(self, compiler=None, connection=None):
raise EmptyResultSet raise EmptyResultSet
class ExtraWhere(object): class ExtraWhere(object):
# The contents are a black box - assume no aggregates are used.
contains_aggregate = False
def __init__(self, sqls, params): def __init__(self, sqls, params):
self.sqls = sqls self.sqls = sqls
self.params = params self.params = params
@ -410,6 +416,10 @@ class Constraint(object):
class SubqueryConstraint(object): class SubqueryConstraint(object):
# Even if aggregates would be used in a subquery, the outer query isn't
# interested about those.
contains_aggregate = False
def __init__(self, alias, columns, targets, query_object): def __init__(self, alias, columns, targets, query_object):
self.alias = alias self.alias = alias
self.columns = columns self.columns = columns

View File

@ -23,3 +23,6 @@ Bugfixes
``db_index=True`` or ``unique=True`` to a ``CharField`` or ``TextField`` that ``db_index=True`` or ``unique=True`` to a ``CharField`` or ``TextField`` that
already had the other specified, or when removing one of them from a field already had the other specified, or when removing one of them from a field
that had both (:ticket:`26034`). that had both (:ticket:`26034`).
* Fixed a crash when using an ``__in`` lookup inside a ``Case`` expression
(:ticket:`26071`).

View File

@ -8,7 +8,7 @@ from uuid import UUID
from django.core.exceptions import FieldError from django.core.exceptions import FieldError
from django.db import connection, models from django.db import connection, models
from django.db.models import F, Q, Max, Min, Value from django.db.models import F, Q, Max, Min, Sum, Value
from django.db.models.expressions import Case, When from django.db.models.expressions import Case, When
from django.test import TestCase from django.test import TestCase
from django.utils import six from django.utils import six
@ -119,6 +119,17 @@ class CaseExpressionTests(TestCase):
transform=attrgetter('integer', 'join_test') transform=attrgetter('integer', 'join_test')
) )
def test_annotate_with_in_clause(self):
fk_rels = FKCaseTestModel.objects.filter(integer__in=[5])
self.assertQuerysetEqual(
CaseTestModel.objects.only('pk', 'integer').annotate(in_test=Sum(Case(
When(fk_rel__in=fk_rels, then=F('fk_rel__integer')),
default=Value(0),
))).order_by('pk'),
[(1, 0), (2, 0), (3, 0), (2, 0), (3, 0), (3, 0), (4, 5)],
transform=attrgetter('integer', 'in_test')
)
def test_annotate_with_join_in_condition(self): def test_annotate_with_join_in_condition(self):
self.assertQuerysetEqual( self.assertQuerysetEqual(
CaseTestModel.objects.annotate(join_test=Case( CaseTestModel.objects.annotate(join_test=Case(