[1.8.x] Refs #14030 -- Improved expression support for python values

Backport of e2d6e14662 from master
This commit is contained in:
Josh Smeaton 2015-02-11 16:38:02 +11:00
parent 343c087533
commit a6ea62aeaf
7 changed files with 89 additions and 88 deletions

View File

@ -1,5 +1,4 @@
from django.db.models.aggregates import StdDev from django.db.models.aggregates import StdDev
from django.db.models.expressions import Value
from django.db.utils import ProgrammingError from django.db.utils import ProgrammingError
from django.utils.functional import cached_property from django.utils.functional import cached_property
@ -229,7 +228,7 @@ class BaseDatabaseFeatures(object):
def supports_stddev(self): def supports_stddev(self):
"""Confirm support for STDDEV and related stats functions.""" """Confirm support for STDDEV and related stats functions."""
try: try:
self.connection.ops.check_expression_support(StdDev(Value(1))) self.connection.ops.check_expression_support(StdDev(1))
return True return True
except NotImplementedError: except NotImplementedError:
return False return False

View File

@ -7,7 +7,7 @@ from django.db.backends import utils as backend_utils
from django.db.models import fields from django.db.models import fields
from django.db.models.constants import LOOKUP_SEP from django.db.models.constants import LOOKUP_SEP
from django.db.models.query_utils import Q, refs_aggregate from django.db.models.query_utils import Q, refs_aggregate
from django.utils import timezone from django.utils import six, timezone
from django.utils.functional import cached_property from django.utils.functional import cached_property
@ -138,6 +138,13 @@ class BaseExpression(object):
def set_source_expressions(self, exprs): def set_source_expressions(self, exprs):
assert len(exprs) == 0 assert len(exprs) == 0
def _parse_expressions(self, *expressions):
return [
arg if hasattr(arg, 'resolve_expression') else (
F(arg) if isinstance(arg, six.string_types) else Value(arg)
) for arg in expressions
]
def as_sql(self, compiler, connection): def as_sql(self, compiler, connection):
""" """
Responsible for returning a (sql, [params]) tuple to be included Responsible for returning a (sql, [params]) tuple to be included
@ -466,12 +473,6 @@ class Func(ExpressionNode):
def set_source_expressions(self, exprs): def set_source_expressions(self, exprs):
self.source_expressions = exprs self.source_expressions = exprs
def _parse_expressions(self, *expressions):
return [
arg if hasattr(arg, 'resolve_expression') else F(arg)
for arg in expressions
]
def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False): def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False):
c = self.copy() c = self.copy()
c.is_summary = summarize c.is_summary = summarize
@ -639,14 +640,14 @@ class Ref(ExpressionNode):
class When(ExpressionNode): class When(ExpressionNode):
template = 'WHEN %(condition)s THEN %(result)s' template = 'WHEN %(condition)s THEN %(result)s'
def __init__(self, condition=None, then=Value(None), **lookups): def __init__(self, condition=None, then=None, **lookups):
if lookups and condition is None: if lookups and condition is None:
condition, lookups = Q(**lookups), None condition, lookups = Q(**lookups), None
if condition is None or not isinstance(condition, Q) or lookups: if condition is None or not isinstance(condition, Q) or lookups:
raise TypeError("__init__() takes either a Q object or lookups as keyword arguments") raise TypeError("__init__() takes either a Q object or lookups as keyword arguments")
super(When, self).__init__(output_field=None) super(When, self).__init__(output_field=None)
self.condition = condition self.condition = condition
self.result = self._parse_expression(then) self.result = self._parse_expressions(then)[0]
def __str__(self): def __str__(self):
return "WHEN %r THEN %r" % (self.condition, self.result) return "WHEN %r THEN %r" % (self.condition, self.result)
@ -664,9 +665,6 @@ class When(ExpressionNode):
# We're only interested in the fields of the result expressions. # We're only interested in the fields of the result expressions.
return [self.result._output_field_or_none] return [self.result._output_field_or_none]
def _parse_expression(self, expression):
return expression if hasattr(expression, 'resolve_expression') else F(expression)
def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False): def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False):
c = self.copy() c = self.copy()
c.is_summary = summarize c.is_summary = summarize
@ -713,11 +711,11 @@ class Case(ExpressionNode):
def __init__(self, *cases, **extra): def __init__(self, *cases, **extra):
if not all(isinstance(case, When) for case in cases): if not all(isinstance(case, When) for case in cases):
raise TypeError("Positional arguments must all be When objects.") raise TypeError("Positional arguments must all be When objects.")
default = extra.pop('default', Value(None)) default = extra.pop('default', None)
output_field = extra.pop('output_field', None) output_field = extra.pop('output_field', None)
super(Case, self).__init__(output_field) super(Case, self).__init__(output_field)
self.cases = list(cases) self.cases = list(cases)
self.default = default if hasattr(default, 'resolve_expression') else F(default) self.default = self._parse_expressions(default)[0]
def __str__(self): def __str__(self):
return "CASE %s, ELSE %r" % (', '.join(str(c) for c in self.cases), self.default) return "CASE %s, ELSE %r" % (', '.join(str(c) for c in self.cases), self.default)

View File

@ -39,7 +39,7 @@ We'll be using the following model in the subsequent examples::
When When
---- ----
.. class:: When(condition=None, then=Value(None), **lookups) .. class:: When(condition=None, then=None, **lookups)
A ``When()`` object is used to encapsulate a condition and its result for use A ``When()`` object is used to encapsulate a condition and its result for use
in the conditional expression. Using a ``When()`` object is similar to using in the conditional expression. Using a ``When()`` object is similar to using
@ -73,8 +73,8 @@ Keep in mind that each of these values can be an expression.
resolved in two ways:: resolved in two ways::
>>> from django.db.models import Value >>> from django.db.models import Value
>>> When(then__exact=0, then=Value(1)) >>> When(then__exact=0, then=1)
>>> When(Q(then=0), then=Value(1)) >>> When(Q(then=0), then=1)
Case Case
---- ----
@ -197,15 +197,15 @@ What if we want to find out how many clients there are for each
>>> from django.db.models import IntegerField, Sum >>> from django.db.models import IntegerField, Sum
>>> Client.objects.aggregate( >>> Client.objects.aggregate(
... regular=Sum( ... regular=Sum(
... Case(When(account_type=Client.REGULAR, then=Value(1)), ... Case(When(account_type=Client.REGULAR, then=1),
... output_field=IntegerField()) ... output_field=IntegerField())
... ), ... ),
... gold=Sum( ... gold=Sum(
... Case(When(account_type=Client.GOLD, then=Value(1)), ... Case(When(account_type=Client.GOLD, then=1),
... output_field=IntegerField()) ... output_field=IntegerField())
... ), ... ),
... platinum=Sum( ... platinum=Sum(
... Case(When(account_type=Client.PLATINUM, then=Value(1)), ... Case(When(account_type=Client.PLATINUM, then=1),
... output_field=IntegerField()) ... output_field=IntegerField())
... ) ... )
... ) ... )

View File

@ -221,6 +221,10 @@ function will be applied to. The expressions will be converted to strings,
joined together with ``arg_joiner``, and then interpolated into the ``template`` joined together with ``arg_joiner``, and then interpolated into the ``template``
as the ``expressions`` placeholder. as the ``expressions`` placeholder.
Positional arguments can be expressions or Python values. Strings are
assumed to be column references and will be wrapped in ``F()`` expressions
while other values will be wrapped in ``Value()`` expressions.
The ``**extra`` kwargs are ``key=value`` pairs that can be interpolated The ``**extra`` kwargs are ``key=value`` pairs that can be interpolated
into the ``template`` attribute. Note that the keywords ``function`` and into the ``template`` attribute. Note that the keywords ``function`` and
``template`` can be used to replace the ``function`` and ``template`` ``template`` can be used to replace the ``function`` and ``template``

View File

@ -8,7 +8,7 @@ from django.core.exceptions import FieldError
from django.db import connection from django.db import connection
from django.db.models import ( from django.db.models import (
F, Aggregate, Avg, Count, DecimalField, FloatField, Func, IntegerField, F, Aggregate, Avg, Count, DecimalField, FloatField, Func, IntegerField,
Max, Min, Sum, Value, Max, Min, Sum,
) )
from django.test import TestCase, ignore_warnings from django.test import TestCase, ignore_warnings
from django.test.utils import Approximate, CaptureQueriesContext from django.test.utils import Approximate, CaptureQueriesContext
@ -706,14 +706,14 @@ class ComplexAggregateTestCase(TestCase):
Book.objects.aggregate(fail=F('price')) Book.objects.aggregate(fail=F('price'))
def test_nonfield_annotation(self): def test_nonfield_annotation(self):
book = Book.objects.annotate(val=Max(Value(2, output_field=IntegerField())))[0] book = Book.objects.annotate(val=Max(2, output_field=IntegerField()))[0]
self.assertEqual(book.val, 2) self.assertEqual(book.val, 2)
book = Book.objects.annotate(val=Max(Value(2), output_field=IntegerField()))[0] book = Book.objects.annotate(val=Max(2, output_field=IntegerField()))[0]
self.assertEqual(book.val, 2) self.assertEqual(book.val, 2)
def test_missing_output_field_raises_error(self): def test_missing_output_field_raises_error(self):
with six.assertRaisesRegex(self, FieldError, 'Cannot resolve expression type, unknown output_field'): with six.assertRaisesRegex(self, FieldError, 'Cannot resolve expression type, unknown output_field'):
Book.objects.annotate(val=Max(Value(2)))[0] Book.objects.annotate(val=Max(2))[0]
def test_annotation_expressions(self): def test_annotation_expressions(self):
authors = Author.objects.annotate(combined_ages=Sum(F('age') + F('friends__age'))).order_by('name') authors = Author.objects.annotate(combined_ages=Sum(F('age') + F('friends__age'))).order_by('name')
@ -772,7 +772,7 @@ class ComplexAggregateTestCase(TestCase):
with six.assertRaisesRegex(self, TypeError, 'Complex aggregates require an alias'): with six.assertRaisesRegex(self, TypeError, 'Complex aggregates require an alias'):
Author.objects.aggregate(Sum('age') / Count('age')) Author.objects.aggregate(Sum('age') / Count('age'))
with six.assertRaisesRegex(self, TypeError, 'Complex aggregates require an alias'): with six.assertRaisesRegex(self, TypeError, 'Complex aggregates require an alias'):
Author.objects.aggregate(Sum(Value(1))) Author.objects.aggregate(Sum(1))
def test_aggregate_over_complex_annotation(self): def test_aggregate_over_complex_annotation(self):
qs = Author.objects.annotate( qs = Author.objects.annotate(

View File

@ -281,7 +281,7 @@ class FunctionTests(TestCase):
def test_substr_with_expressions(self): def test_substr_with_expressions(self):
Author.objects.create(name='John Smith', alias='smithj') Author.objects.create(name='John Smith', alias='smithj')
Author.objects.create(name='Rhonda') Author.objects.create(name='Rhonda')
authors = Author.objects.annotate(name_part=Substr('name', V(5), V(3))) authors = Author.objects.annotate(name_part=Substr('name', 5, 3))
self.assertQuerysetEqual( self.assertQuerysetEqual(
authors.order_by('name'), [ authors.order_by('name'), [
' Sm', ' Sm',

View File

@ -78,8 +78,8 @@ class CaseExpressionTests(TestCase):
def test_annotate_without_default(self): def test_annotate_without_default(self):
self.assertQuerysetEqual( self.assertQuerysetEqual(
CaseTestModel.objects.annotate(test=Case( CaseTestModel.objects.annotate(test=Case(
When(integer=1, then=Value(1)), When(integer=1, then=1),
When(integer=2, then=Value(2)), When(integer=2, then=2),
output_field=models.IntegerField(), output_field=models.IntegerField(),
)).order_by('pk'), )).order_by('pk'),
[(1, 1), (2, 2), (3, None), (2, 2), (3, None), (3, None), (4, None)], [(1, 1), (2, 2), (3, None), (2, 2), (3, None), (3, None), (4, None)],
@ -244,9 +244,9 @@ class CaseExpressionTests(TestCase):
self.assertQuerysetEqual( self.assertQuerysetEqual(
CaseTestModel.objects.annotate( CaseTestModel.objects.annotate(
test=Case( test=Case(
When(integer=1, then=Value(2)), When(integer=1, then=2),
When(integer=2, then=Value(1)), When(integer=2, then=1),
default=Value(3), default=3,
output_field=models.IntegerField(), output_field=models.IntegerField(),
) + 1, ) + 1,
).order_by('pk'), ).order_by('pk'),
@ -278,19 +278,19 @@ class CaseExpressionTests(TestCase):
self.assertEqual( self.assertEqual(
CaseTestModel.objects.aggregate( CaseTestModel.objects.aggregate(
one=models.Sum(Case( one=models.Sum(Case(
When(integer=1, then=Value(1)), When(integer=1, then=1),
output_field=models.IntegerField(), output_field=models.IntegerField(),
)), )),
two=models.Sum(Case( two=models.Sum(Case(
When(integer=2, then=Value(1)), When(integer=2, then=1),
output_field=models.IntegerField(), output_field=models.IntegerField(),
)), )),
three=models.Sum(Case( three=models.Sum(Case(
When(integer=3, then=Value(1)), When(integer=3, then=1),
output_field=models.IntegerField(), output_field=models.IntegerField(),
)), )),
four=models.Sum(Case( four=models.Sum(Case(
When(integer=4, then=Value(1)), When(integer=4, then=1),
output_field=models.IntegerField(), output_field=models.IntegerField(),
)), )),
), ),
@ -311,11 +311,11 @@ class CaseExpressionTests(TestCase):
self.assertEqual( self.assertEqual(
CaseTestModel.objects.aggregate( CaseTestModel.objects.aggregate(
equal=models.Sum(Case( equal=models.Sum(Case(
When(integer2=F('integer'), then=Value(1)), When(integer2=F('integer'), then=1),
output_field=models.IntegerField(), output_field=models.IntegerField(),
)), )),
plus_one=models.Sum(Case( plus_one=models.Sum(Case(
When(integer2=F('integer') + 1, then=Value(1)), When(integer2=F('integer') + 1, then=1),
output_field=models.IntegerField(), output_field=models.IntegerField(),
)), )),
), ),
@ -325,9 +325,9 @@ class CaseExpressionTests(TestCase):
def test_filter(self): def test_filter(self):
self.assertQuerysetEqual( self.assertQuerysetEqual(
CaseTestModel.objects.filter(integer2=Case( CaseTestModel.objects.filter(integer2=Case(
When(integer=2, then=Value(3)), When(integer=2, then=3),
When(integer=3, then=Value(4)), When(integer=3, then=4),
default=Value(1), default=1,
output_field=models.IntegerField(), output_field=models.IntegerField(),
)).order_by('pk'), )).order_by('pk'),
[(1, 1), (2, 3), (3, 4), (3, 4)], [(1, 1), (2, 3), (3, 4), (3, 4)],
@ -337,8 +337,8 @@ class CaseExpressionTests(TestCase):
def test_filter_without_default(self): def test_filter_without_default(self):
self.assertQuerysetEqual( self.assertQuerysetEqual(
CaseTestModel.objects.filter(integer2=Case( CaseTestModel.objects.filter(integer2=Case(
When(integer=2, then=Value(3)), When(integer=2, then=3),
When(integer=3, then=Value(4)), When(integer=3, then=4),
output_field=models.IntegerField(), output_field=models.IntegerField(),
)).order_by('pk'), )).order_by('pk'),
[(2, 3), (3, 4), (3, 4)], [(2, 3), (3, 4), (3, 4)],
@ -381,8 +381,8 @@ class CaseExpressionTests(TestCase):
def test_filter_with_join_in_condition(self): def test_filter_with_join_in_condition(self):
self.assertQuerysetEqual( self.assertQuerysetEqual(
CaseTestModel.objects.filter(integer=Case( CaseTestModel.objects.filter(integer=Case(
When(integer2=F('o2o_rel__integer') + 1, then=Value(2)), When(integer2=F('o2o_rel__integer') + 1, then=2),
When(integer2=F('o2o_rel__integer'), then=Value(3)), When(integer2=F('o2o_rel__integer'), then=3),
output_field=models.IntegerField(), output_field=models.IntegerField(),
)).order_by('pk'), )).order_by('pk'),
[(2, 3), (3, 3)], [(2, 3), (3, 3)],
@ -392,9 +392,9 @@ class CaseExpressionTests(TestCase):
def test_filter_with_join_in_predicate(self): def test_filter_with_join_in_predicate(self):
self.assertQuerysetEqual( self.assertQuerysetEqual(
CaseTestModel.objects.filter(integer2=Case( CaseTestModel.objects.filter(integer2=Case(
When(o2o_rel__integer=1, then=Value(1)), When(o2o_rel__integer=1, then=1),
When(o2o_rel__integer=2, then=Value(3)), When(o2o_rel__integer=2, then=3),
When(o2o_rel__integer=3, then=Value(4)), When(o2o_rel__integer=3, then=4),
output_field=models.IntegerField(), output_field=models.IntegerField(),
)).order_by('pk'), )).order_by('pk'),
[(1, 1), (2, 3), (3, 4), (3, 4)], [(1, 1), (2, 3), (3, 4), (3, 4)],
@ -422,8 +422,8 @@ class CaseExpressionTests(TestCase):
f_plus_1=F('integer') + 1, f_plus_1=F('integer') + 1,
).filter( ).filter(
integer=Case( integer=Case(
When(integer2=F('integer'), then=Value(2)), When(integer2=F('integer'), then=2),
When(integer2=F('f_plus_1'), then=Value(3)), When(integer2=F('f_plus_1'), then=3),
output_field=models.IntegerField(), output_field=models.IntegerField(),
), ),
).order_by('pk'), ).order_by('pk'),
@ -437,9 +437,9 @@ class CaseExpressionTests(TestCase):
f_plus_1=F('integer') + 1, f_plus_1=F('integer') + 1,
).filter( ).filter(
integer2=Case( integer2=Case(
When(f_plus_1=3, then=Value(3)), When(f_plus_1=3, then=3),
When(f_plus_1=4, then=Value(4)), When(f_plus_1=4, then=4),
default=Value(1), default=1,
output_field=models.IntegerField(), output_field=models.IntegerField(),
), ),
).order_by('pk'), ).order_by('pk'),
@ -469,8 +469,8 @@ class CaseExpressionTests(TestCase):
max=Max('fk_rel__integer'), max=Max('fk_rel__integer'),
).filter( ).filter(
integer=Case( integer=Case(
When(integer2=F('min'), then=Value(2)), When(integer2=F('min'), then=2),
When(integer2=F('max'), then=Value(3)), When(integer2=F('max'), then=3),
), ),
).order_by('pk'), ).order_by('pk'),
[(3, 4, 3, 4), (2, 2, 2, 3), (3, 4, 3, 4)], [(3, 4, 3, 4), (2, 2, 2, 3), (3, 4, 3, 4)],
@ -483,8 +483,8 @@ class CaseExpressionTests(TestCase):
max=Max('fk_rel__integer'), max=Max('fk_rel__integer'),
).filter( ).filter(
integer=Case( integer=Case(
When(max=3, then=Value(2)), When(max=3, then=2),
When(max=4, then=Value(3)), When(max=4, then=3),
), ),
).order_by('pk'), ).order_by('pk'),
[(2, 3, 3), (3, 4, 4), (2, 2, 3), (3, 4, 4), (3, 3, 4)], [(2, 3, 3), (3, 4, 4), (2, 2, 3), (3, 4, 4), (3, 3, 4)],
@ -508,8 +508,8 @@ class CaseExpressionTests(TestCase):
def test_update_without_default(self): def test_update_without_default(self):
CaseTestModel.objects.update( CaseTestModel.objects.update(
integer2=Case( integer2=Case(
When(integer=1, then=Value(1)), When(integer=1, then=1),
When(integer=2, then=Value(2)), When(integer=2, then=2),
), ),
) )
self.assertQuerysetEqual( self.assertQuerysetEqual(
@ -549,8 +549,8 @@ class CaseExpressionTests(TestCase):
with self.assertRaisesMessage(FieldError, 'Joined field references are not permitted in this query'): with self.assertRaisesMessage(FieldError, 'Joined field references are not permitted in this query'):
CaseTestModel.objects.update( CaseTestModel.objects.update(
integer=Case( integer=Case(
When(integer2=F('o2o_rel__integer') + 1, then=Value(2)), When(integer2=F('o2o_rel__integer') + 1, then=2),
When(integer2=F('o2o_rel__integer'), then=Value(3)), When(integer2=F('o2o_rel__integer'), then=3),
output_field=models.IntegerField(), output_field=models.IntegerField(),
), ),
) )
@ -570,8 +570,8 @@ class CaseExpressionTests(TestCase):
def test_update_big_integer(self): def test_update_big_integer(self):
CaseTestModel.objects.update( CaseTestModel.objects.update(
big_integer=Case( big_integer=Case(
When(integer=1, then=Value(1)), When(integer=1, then=1),
When(integer=2, then=Value(2)), When(integer=2, then=2),
), ),
) )
self.assertQuerysetEqual( self.assertQuerysetEqual(
@ -599,9 +599,9 @@ class CaseExpressionTests(TestCase):
def test_update_boolean(self): def test_update_boolean(self):
CaseTestModel.objects.update( CaseTestModel.objects.update(
boolean=Case( boolean=Case(
When(integer=1, then=Value(True)), When(integer=1, then=True),
When(integer=2, then=Value(True)), When(integer=2, then=True),
default=Value(False), default=False,
), ),
) )
self.assertQuerysetEqual( self.assertQuerysetEqual(
@ -627,8 +627,8 @@ class CaseExpressionTests(TestCase):
def test_update_date(self): def test_update_date(self):
CaseTestModel.objects.update( CaseTestModel.objects.update(
date=Case( date=Case(
When(integer=1, then=Value(date(2015, 1, 1))), When(integer=1, then=date(2015, 1, 1)),
When(integer=2, then=Value(date(2015, 1, 2))), When(integer=2, then=date(2015, 1, 2)),
), ),
) )
self.assertQuerysetEqual( self.assertQuerysetEqual(
@ -643,8 +643,8 @@ class CaseExpressionTests(TestCase):
def test_update_date_time(self): def test_update_date_time(self):
CaseTestModel.objects.update( CaseTestModel.objects.update(
date_time=Case( date_time=Case(
When(integer=1, then=Value(datetime(2015, 1, 1))), When(integer=1, then=datetime(2015, 1, 1)),
When(integer=2, then=Value(datetime(2015, 1, 2))), When(integer=2, then=datetime(2015, 1, 2)),
), ),
) )
self.assertQuerysetEqual( self.assertQuerysetEqual(
@ -659,8 +659,8 @@ class CaseExpressionTests(TestCase):
def test_update_decimal(self): def test_update_decimal(self):
CaseTestModel.objects.update( CaseTestModel.objects.update(
decimal=Case( decimal=Case(
When(integer=1, then=Value(Decimal('1.1'))), When(integer=1, then=Decimal('1.1')),
When(integer=2, then=Value(Decimal('2.2'))), When(integer=2, then=Decimal('2.2')),
), ),
) )
self.assertQuerysetEqual( self.assertQuerysetEqual(
@ -728,8 +728,8 @@ class CaseExpressionTests(TestCase):
def test_update_float(self): def test_update_float(self):
CaseTestModel.objects.update( CaseTestModel.objects.update(
float=Case( float=Case(
When(integer=1, then=Value(1.1)), When(integer=1, then=1.1),
When(integer=2, then=Value(2.2)), When(integer=2, then=2.2),
), ),
) )
self.assertQuerysetEqual( self.assertQuerysetEqual(
@ -785,8 +785,8 @@ class CaseExpressionTests(TestCase):
def test_update_null_boolean(self): def test_update_null_boolean(self):
CaseTestModel.objects.update( CaseTestModel.objects.update(
null_boolean=Case( null_boolean=Case(
When(integer=1, then=Value(True)), When(integer=1, then=True),
When(integer=2, then=Value(False)), When(integer=2, then=False),
), ),
) )
self.assertQuerysetEqual( self.assertQuerysetEqual(
@ -798,8 +798,8 @@ class CaseExpressionTests(TestCase):
def test_update_positive_integer(self): def test_update_positive_integer(self):
CaseTestModel.objects.update( CaseTestModel.objects.update(
positive_integer=Case( positive_integer=Case(
When(integer=1, then=Value(1)), When(integer=1, then=1),
When(integer=2, then=Value(2)), When(integer=2, then=2),
), ),
) )
self.assertQuerysetEqual( self.assertQuerysetEqual(
@ -811,8 +811,8 @@ class CaseExpressionTests(TestCase):
def test_update_positive_small_integer(self): def test_update_positive_small_integer(self):
CaseTestModel.objects.update( CaseTestModel.objects.update(
positive_small_integer=Case( positive_small_integer=Case(
When(integer=1, then=Value(1)), When(integer=1, then=1),
When(integer=2, then=Value(2)), When(integer=2, then=2),
), ),
) )
self.assertQuerysetEqual( self.assertQuerysetEqual(
@ -838,8 +838,8 @@ class CaseExpressionTests(TestCase):
def test_update_small_integer(self): def test_update_small_integer(self):
CaseTestModel.objects.update( CaseTestModel.objects.update(
small_integer=Case( small_integer=Case(
When(integer=1, then=Value(1)), When(integer=1, then=1),
When(integer=2, then=Value(2)), When(integer=2, then=2),
), ),
) )
self.assertQuerysetEqual( self.assertQuerysetEqual(
@ -936,8 +936,8 @@ class CaseExpressionTests(TestCase):
CaseTestModel.objects.update( CaseTestModel.objects.update(
fk=Case( fk=Case(
When(integer=1, then=Value(obj1.pk)), When(integer=1, then=obj1.pk),
When(integer=2, then=Value(obj2.pk)), When(integer=2, then=obj2.pk),
), ),
) )
self.assertQuerysetEqual( self.assertQuerysetEqual(
@ -1080,15 +1080,15 @@ class CaseDocumentationExamples(TestCase):
self.assertEqual( self.assertEqual(
Client.objects.aggregate( Client.objects.aggregate(
regular=models.Sum(Case( regular=models.Sum(Case(
When(account_type=Client.REGULAR, then=Value(1)), When(account_type=Client.REGULAR, then=1),
output_field=models.IntegerField(), output_field=models.IntegerField(),
)), )),
gold=models.Sum(Case( gold=models.Sum(Case(
When(account_type=Client.GOLD, then=Value(1)), When(account_type=Client.GOLD, then=1),
output_field=models.IntegerField(), output_field=models.IntegerField(),
)), )),
platinum=models.Sum(Case( platinum=models.Sum(Case(
When(account_type=Client.PLATINUM, then=Value(1)), When(account_type=Client.PLATINUM, then=1),
output_field=models.IntegerField(), output_field=models.IntegerField(),
)), )),
), ),