Fixed #26433 -- Fixed Case expressions with empty When.

This commit is contained in:
Johannes Dollinger 2016-08-05 08:58:36 -04:00 committed by Tim Graham
parent e7fb724cd2
commit 1410616e0e
2 changed files with 18 additions and 2 deletions

View File

@ -836,6 +836,7 @@ class Case(Expression):
return c return c
def as_sql(self, compiler, connection, template=None, case_joiner=None, **extra_context): def as_sql(self, compiler, connection, template=None, case_joiner=None, **extra_context):
from django.db.models.sql.datastructures import EmptyResultSet
connection.ops.check_expression_support(self) connection.ops.check_expression_support(self)
if not self.cases: if not self.cases:
return compiler.compile(self.default) return compiler.compile(self.default)
@ -844,12 +845,17 @@ class Case(Expression):
case_parts = [] case_parts = []
sql_params = [] sql_params = []
for case in self.cases: for case in self.cases:
try:
case_sql, case_params = compiler.compile(case) case_sql, case_params = compiler.compile(case)
except EmptyResultSet:
continue
case_parts.append(case_sql) case_parts.append(case_sql)
sql_params.extend(case_params) sql_params.extend(case_params)
default_sql, default_params = compiler.compile(self.default)
if not case_parts:
return default_sql, default_params
case_joiner = case_joiner or self.case_joiner case_joiner = case_joiner or self.case_joiner
template_params['cases'] = case_joiner.join(case_parts) template_params['cases'] = case_joiner.join(case_parts)
default_sql, default_params = compiler.compile(self.default)
template_params['default'] = default_sql template_params['default'] = default_sql
sql_params.extend(default_params) sql_params.extend(default_params)
template = template or template_params.get('template', self.template) template = template or template_params.get('template', self.template)

View File

@ -275,6 +275,16 @@ class CaseExpressionTests(TestCase):
[1, 4, 3, 3, 3, 2, 2] [1, 4, 3, 3, 3, 2, 2]
) )
def test_annotate_with_empty_when(self):
objects = CaseTestModel.objects.annotate(
selected=Case(
When(pk__in=[], then=Value('selected')),
default=Value('not selected'), output_field=models.CharField()
)
)
self.assertEqual(len(objects), CaseTestModel.objects.count())
self.assertTrue(all(obj.selected == 'not selected' for obj in objects))
def test_combined_expression(self): def test_combined_expression(self):
self.assertQuerysetEqual( self.assertQuerysetEqual(
CaseTestModel.objects.annotate( CaseTestModel.objects.annotate(