Fixed #31606 -- Allowed using condition with lookups in When() expression.

This commit is contained in:
Ryan Heard 2020-05-19 00:47:56 -05:00 committed by Mariusz Felisiak
parent 2aac176e86
commit 587b179d41
4 changed files with 26 additions and 3 deletions

View File

@ -876,8 +876,11 @@ class When(Expression):
conditional = False
def __init__(self, condition=None, then=None, **lookups):
if lookups and condition is None:
condition, lookups = Q(**lookups), None
if lookups:
if condition is None:
condition, lookups = Q(**lookups), None
elif getattr(condition, 'conditional', False):
condition, lookups = Q(condition, **lookups), None
if condition is None or not getattr(condition, 'conditional', False) or lookups:
raise TypeError(
'When() supports a Q object, a boolean expression, or lookups '

View File

@ -81,6 +81,10 @@ Keep in mind that each of these values can be an expression.
>>> When(then__exact=0, then=1)
>>> When(Q(then=0), then=1)
.. versionchanged:: 3.2
Support for using the ``condition`` argument with ``lookups`` was added.
``Case``
--------

View File

@ -178,6 +178,9 @@ Models
supported on PostgreSQL, allows acquiring weaker locks that don't block the
creation of rows that reference locked rows through a foreign key.
* :class:`When() <django.db.models.expressions.When>` expression now allows
using the ``condition`` argument with ``lookups``.
Requests and Responses
~~~~~~~~~~~~~~~~~~~~~~

View File

@ -6,7 +6,7 @@ from uuid import UUID
from django.core.exceptions import FieldError
from django.db.models import (
BinaryField, Case, CharField, Count, DurationField, F,
BinaryField, BooleanField, Case, CharField, Count, DurationField, F,
GenericIPAddressField, IntegerField, Max, Min, Q, Sum, TextField,
TimeField, UUIDField, Value, When,
)
@ -312,6 +312,17 @@ class CaseExpressionTests(TestCase):
transform=attrgetter('integer', 'integer2')
)
def test_condition_with_lookups(self):
qs = CaseTestModel.objects.annotate(
test=Case(
When(Q(integer2=1), string='2', then=Value(False)),
When(Q(integer2=1), string='1', then=Value(True)),
default=Value(False),
output_field=BooleanField(),
),
)
self.assertIs(qs.get(integer=1).test, True)
def test_case_reuse(self):
SOME_CASE = Case(
When(pk=0, then=Value('0')),
@ -1350,6 +1361,8 @@ class CaseWhenTests(SimpleTestCase):
When(condition=object())
with self.assertRaisesMessage(TypeError, msg):
When(condition=Value(1, output_field=IntegerField()))
with self.assertRaisesMessage(TypeError, msg):
When(Value(1, output_field=IntegerField()), string='1')
with self.assertRaisesMessage(TypeError, msg):
When()