From 1a5023883bf4c8cccb34a830edcc3c82aa862455 Mon Sep 17 00:00:00 2001 From: Matthijs Kooijman Date: Thu, 4 Nov 2021 18:24:19 +0100 Subject: [PATCH] Fixed #33257 -- Fixed Case() and ExpressionWrapper() with decimal values on SQLite. --- django/db/models/expressions.py | 4 ++-- tests/expressions/tests.py | 7 +++++++ tests/expressions_case/tests.py | 9 +++++++++ 3 files changed, 18 insertions(+), 2 deletions(-) diff --git a/django/db/models/expressions.py b/django/db/models/expressions.py index c7e33b3698..cb97740a6c 100644 --- a/django/db/models/expressions.py +++ b/django/db/models/expressions.py @@ -933,7 +933,7 @@ class ExpressionList(Func): return self.as_sql(compiler, connection, **extra_context) -class ExpressionWrapper(Expression): +class ExpressionWrapper(SQLiteNumericMixin, Expression): """ An expression that can wrap another expression so that it can provide extra context to the inner expression, such as the output_field. @@ -1032,7 +1032,7 @@ class When(Expression): return cols -class Case(Expression): +class Case(SQLiteNumericMixin, Expression): """ An SQL searched CASE expression: diff --git a/tests/expressions/tests.py b/tests/expressions/tests.py index 4bb65c9031..dab5474ef4 100644 --- a/tests/expressions/tests.py +++ b/tests/expressions/tests.py @@ -1178,6 +1178,13 @@ class ExpressionsNumericTests(TestCase): ordered=False ) + def test_filter_decimal_expression(self): + obj = Number.objects.create(integer=0, float=1, decimal_value=Decimal('1')) + qs = Number.objects.annotate( + x=ExpressionWrapper(Value(1), output_field=DecimalField()), + ).filter(Q(x=1, integer=0) & Q(x=Decimal('1'))) + self.assertSequenceEqual(qs, [obj]) + def test_complex_expressions(self): """ Complex expressions of different connection types are possible. diff --git a/tests/expressions_case/tests.py b/tests/expressions_case/tests.py index 24443ab3a1..7818e59dcf 100644 --- a/tests/expressions_case/tests.py +++ b/tests/expressions_case/tests.py @@ -256,6 +256,15 @@ class CaseExpressionTests(TestCase): transform=attrgetter('integer', 'test') ) + def test_annotate_filter_decimal(self): + obj = CaseTestModel.objects.create(integer=0, decimal=Decimal('1')) + qs = CaseTestModel.objects.annotate( + x=Case(When(integer=0, then=F('decimal'))), + y=Case(When(integer=0, then=Value(Decimal('1')))), + ) + self.assertSequenceEqual(qs.filter(Q(x=1) & Q(x=Decimal('1'))), [obj]) + self.assertSequenceEqual(qs.filter(Q(y=1) & Q(y=Decimal('1'))), [obj]) + def test_annotate_values_not_in_order_by(self): self.assertEqual( list(CaseTestModel.objects.annotate(test=Case(