From 1410616e0e2637e4b0821620202bf62edc903b88 Mon Sep 17 00:00:00 2001 From: Johannes Dollinger Date: Fri, 5 Aug 2016 08:58:36 -0400 Subject: [PATCH] Fixed #26433 -- Fixed Case expressions with empty When. --- django/db/models/expressions.py | 10 ++++++++-- tests/expressions_case/tests.py | 10 ++++++++++ 2 files changed, 18 insertions(+), 2 deletions(-) diff --git a/django/db/models/expressions.py b/django/db/models/expressions.py index f79865d153..8e84be68cf 100644 --- a/django/db/models/expressions.py +++ b/django/db/models/expressions.py @@ -836,6 +836,7 @@ class Case(Expression): return c def as_sql(self, compiler, connection, template=None, case_joiner=None, **extra_context): + from django.db.models.sql.datastructures import EmptyResultSet connection.ops.check_expression_support(self) if not self.cases: return compiler.compile(self.default) @@ -844,12 +845,17 @@ class Case(Expression): case_parts = [] sql_params = [] for case in self.cases: - case_sql, case_params = compiler.compile(case) + try: + case_sql, case_params = compiler.compile(case) + except EmptyResultSet: + continue case_parts.append(case_sql) sql_params.extend(case_params) + default_sql, default_params = compiler.compile(self.default) + if not case_parts: + return default_sql, default_params case_joiner = case_joiner or self.case_joiner template_params['cases'] = case_joiner.join(case_parts) - default_sql, default_params = compiler.compile(self.default) template_params['default'] = default_sql sql_params.extend(default_params) template = template or template_params.get('template', self.template) diff --git a/tests/expressions_case/tests.py b/tests/expressions_case/tests.py index 30e8a704ae..c3b754fcae 100644 --- a/tests/expressions_case/tests.py +++ b/tests/expressions_case/tests.py @@ -275,6 +275,16 @@ class CaseExpressionTests(TestCase): [1, 4, 3, 3, 3, 2, 2] ) + def test_annotate_with_empty_when(self): + objects = CaseTestModel.objects.annotate( + selected=Case( + When(pk__in=[], then=Value('selected')), + default=Value('not selected'), output_field=models.CharField() + ) + ) + self.assertEqual(len(objects), CaseTestModel.objects.count()) + self.assertTrue(all(obj.selected == 'not selected' for obj in objects)) + def test_combined_expression(self): self.assertQuerysetEqual( CaseTestModel.objects.annotate(