Refs #14030 -- Improved expression support for python values

This commit is contained in:
Josh Smeaton 2015-02-11 16:38:02 +11:00
parent 07cfe1bd82
commit e2d6e14662
7 changed files with 89 additions and 88 deletions

View File

@ -1,5 +1,4 @@
from django.db.models.aggregates import StdDev
from django.db.models.expressions import Value
from django.db.utils import ProgrammingError
from django.utils.functional import cached_property
@ -232,7 +231,7 @@ class BaseDatabaseFeatures(object):
def supports_stddev(self):
"""Confirm support for STDDEV and related stats functions."""
try:
self.connection.ops.check_expression_support(StdDev(Value(1)))
self.connection.ops.check_expression_support(StdDev(1))
return True
except NotImplementedError:
return False

View File

@ -7,7 +7,7 @@ from django.db.backends import utils as backend_utils
from django.db.models import fields
from django.db.models.constants import LOOKUP_SEP
from django.db.models.query_utils import Q, refs_aggregate
from django.utils import timezone
from django.utils import six, timezone
from django.utils.functional import cached_property
@ -138,6 +138,13 @@ class BaseExpression(object):
def set_source_expressions(self, exprs):
assert len(exprs) == 0
def _parse_expressions(self, *expressions):
return [
arg if hasattr(arg, 'resolve_expression') else (
F(arg) if isinstance(arg, six.string_types) else Value(arg)
) for arg in expressions
]
def as_sql(self, compiler, connection):
"""
Responsible for returning a (sql, [params]) tuple to be included
@ -466,12 +473,6 @@ class Func(ExpressionNode):
def set_source_expressions(self, exprs):
self.source_expressions = exprs
def _parse_expressions(self, *expressions):
return [
arg if hasattr(arg, 'resolve_expression') else F(arg)
for arg in expressions
]
def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False):
c = self.copy()
c.is_summary = summarize
@ -639,14 +640,14 @@ class Ref(ExpressionNode):
class When(ExpressionNode):
template = 'WHEN %(condition)s THEN %(result)s'
def __init__(self, condition=None, then=Value(None), **lookups):
def __init__(self, condition=None, then=None, **lookups):
if lookups and condition is None:
condition, lookups = Q(**lookups), None
if condition is None or not isinstance(condition, Q) or lookups:
raise TypeError("__init__() takes either a Q object or lookups as keyword arguments")
super(When, self).__init__(output_field=None)
self.condition = condition
self.result = self._parse_expression(then)
self.result = self._parse_expressions(then)[0]
def __str__(self):
return "WHEN %r THEN %r" % (self.condition, self.result)
@ -664,9 +665,6 @@ class When(ExpressionNode):
# We're only interested in the fields of the result expressions.
return [self.result._output_field_or_none]
def _parse_expression(self, expression):
return expression if hasattr(expression, 'resolve_expression') else F(expression)
def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False):
c = self.copy()
c.is_summary = summarize
@ -713,11 +711,11 @@ class Case(ExpressionNode):
def __init__(self, *cases, **extra):
if not all(isinstance(case, When) for case in cases):
raise TypeError("Positional arguments must all be When objects.")
default = extra.pop('default', Value(None))
default = extra.pop('default', None)
output_field = extra.pop('output_field', None)
super(Case, self).__init__(output_field)
self.cases = list(cases)
self.default = default if hasattr(default, 'resolve_expression') else F(default)
self.default = self._parse_expressions(default)[0]
def __str__(self):
return "CASE %s, ELSE %r" % (', '.join(str(c) for c in self.cases), self.default)

View File

@ -39,7 +39,7 @@ We'll be using the following model in the subsequent examples::
When
----
.. class:: When(condition=None, then=Value(None), **lookups)
.. class:: When(condition=None, then=None, **lookups)
A ``When()`` object is used to encapsulate a condition and its result for use
in the conditional expression. Using a ``When()`` object is similar to using
@ -73,8 +73,8 @@ Keep in mind that each of these values can be an expression.
resolved in two ways::
>>> from django.db.models import Value
>>> When(then__exact=0, then=Value(1))
>>> When(Q(then=0), then=Value(1))
>>> When(then__exact=0, then=1)
>>> When(Q(then=0), then=1)
Case
----
@ -197,15 +197,15 @@ What if we want to find out how many clients there are for each
>>> from django.db.models import IntegerField, Sum
>>> Client.objects.aggregate(
... regular=Sum(
... Case(When(account_type=Client.REGULAR, then=Value(1)),
... Case(When(account_type=Client.REGULAR, then=1),
... output_field=IntegerField())
... ),
... gold=Sum(
... Case(When(account_type=Client.GOLD, then=Value(1)),
... Case(When(account_type=Client.GOLD, then=1),
... output_field=IntegerField())
... ),
... platinum=Sum(
... Case(When(account_type=Client.PLATINUM, then=Value(1)),
... Case(When(account_type=Client.PLATINUM, then=1),
... output_field=IntegerField())
... )
... )

View File

@ -217,6 +217,10 @@ function will be applied to. The expressions will be converted to strings,
joined together with ``arg_joiner``, and then interpolated into the ``template``
as the ``expressions`` placeholder.
Positional arguments can be expressions or Python values. Strings are
assumed to be column references and will be wrapped in ``F()`` expressions
while other values will be wrapped in ``Value()`` expressions.
The ``**extra`` kwargs are ``key=value`` pairs that can be interpolated
into the ``template`` attribute. Note that the keywords ``function`` and
``template`` can be used to replace the ``function`` and ``template``

View File

@ -8,7 +8,7 @@ from django.core.exceptions import FieldError
from django.db import connection
from django.db.models import (
F, Aggregate, Avg, Count, DecimalField, FloatField, Func, IntegerField,
Max, Min, Sum, Value,
Max, Min, Sum,
)
from django.test import TestCase, ignore_warnings
from django.test.utils import Approximate, CaptureQueriesContext
@ -706,14 +706,14 @@ class ComplexAggregateTestCase(TestCase):
Book.objects.aggregate(fail=F('price'))
def test_nonfield_annotation(self):
book = Book.objects.annotate(val=Max(Value(2, output_field=IntegerField())))[0]
book = Book.objects.annotate(val=Max(2, output_field=IntegerField()))[0]
self.assertEqual(book.val, 2)
book = Book.objects.annotate(val=Max(Value(2), output_field=IntegerField()))[0]
book = Book.objects.annotate(val=Max(2, output_field=IntegerField()))[0]
self.assertEqual(book.val, 2)
def test_missing_output_field_raises_error(self):
with six.assertRaisesRegex(self, FieldError, 'Cannot resolve expression type, unknown output_field'):
Book.objects.annotate(val=Max(Value(2)))[0]
Book.objects.annotate(val=Max(2))[0]
def test_annotation_expressions(self):
authors = Author.objects.annotate(combined_ages=Sum(F('age') + F('friends__age'))).order_by('name')
@ -772,7 +772,7 @@ class ComplexAggregateTestCase(TestCase):
with six.assertRaisesRegex(self, TypeError, 'Complex aggregates require an alias'):
Author.objects.aggregate(Sum('age') / Count('age'))
with six.assertRaisesRegex(self, TypeError, 'Complex aggregates require an alias'):
Author.objects.aggregate(Sum(Value(1)))
Author.objects.aggregate(Sum(1))
def test_aggregate_over_complex_annotation(self):
qs = Author.objects.annotate(

View File

@ -281,7 +281,7 @@ class FunctionTests(TestCase):
def test_substr_with_expressions(self):
Author.objects.create(name='John Smith', alias='smithj')
Author.objects.create(name='Rhonda')
authors = Author.objects.annotate(name_part=Substr('name', V(5), V(3)))
authors = Author.objects.annotate(name_part=Substr('name', 5, 3))
self.assertQuerysetEqual(
authors.order_by('name'), [
' Sm',

View File

@ -78,8 +78,8 @@ class CaseExpressionTests(TestCase):
def test_annotate_without_default(self):
self.assertQuerysetEqual(
CaseTestModel.objects.annotate(test=Case(
When(integer=1, then=Value(1)),
When(integer=2, then=Value(2)),
When(integer=1, then=1),
When(integer=2, then=2),
output_field=models.IntegerField(),
)).order_by('pk'),
[(1, 1), (2, 2), (3, None), (2, 2), (3, None), (3, None), (4, None)],
@ -244,9 +244,9 @@ class CaseExpressionTests(TestCase):
self.assertQuerysetEqual(
CaseTestModel.objects.annotate(
test=Case(
When(integer=1, then=Value(2)),
When(integer=2, then=Value(1)),
default=Value(3),
When(integer=1, then=2),
When(integer=2, then=1),
default=3,
output_field=models.IntegerField(),
) + 1,
).order_by('pk'),
@ -278,19 +278,19 @@ class CaseExpressionTests(TestCase):
self.assertEqual(
CaseTestModel.objects.aggregate(
one=models.Sum(Case(
When(integer=1, then=Value(1)),
When(integer=1, then=1),
output_field=models.IntegerField(),
)),
two=models.Sum(Case(
When(integer=2, then=Value(1)),
When(integer=2, then=1),
output_field=models.IntegerField(),
)),
three=models.Sum(Case(
When(integer=3, then=Value(1)),
When(integer=3, then=1),
output_field=models.IntegerField(),
)),
four=models.Sum(Case(
When(integer=4, then=Value(1)),
When(integer=4, then=1),
output_field=models.IntegerField(),
)),
),
@ -311,11 +311,11 @@ class CaseExpressionTests(TestCase):
self.assertEqual(
CaseTestModel.objects.aggregate(
equal=models.Sum(Case(
When(integer2=F('integer'), then=Value(1)),
When(integer2=F('integer'), then=1),
output_field=models.IntegerField(),
)),
plus_one=models.Sum(Case(
When(integer2=F('integer') + 1, then=Value(1)),
When(integer2=F('integer') + 1, then=1),
output_field=models.IntegerField(),
)),
),
@ -325,9 +325,9 @@ class CaseExpressionTests(TestCase):
def test_filter(self):
self.assertQuerysetEqual(
CaseTestModel.objects.filter(integer2=Case(
When(integer=2, then=Value(3)),
When(integer=3, then=Value(4)),
default=Value(1),
When(integer=2, then=3),
When(integer=3, then=4),
default=1,
output_field=models.IntegerField(),
)).order_by('pk'),
[(1, 1), (2, 3), (3, 4), (3, 4)],
@ -337,8 +337,8 @@ class CaseExpressionTests(TestCase):
def test_filter_without_default(self):
self.assertQuerysetEqual(
CaseTestModel.objects.filter(integer2=Case(
When(integer=2, then=Value(3)),
When(integer=3, then=Value(4)),
When(integer=2, then=3),
When(integer=3, then=4),
output_field=models.IntegerField(),
)).order_by('pk'),
[(2, 3), (3, 4), (3, 4)],
@ -381,8 +381,8 @@ class CaseExpressionTests(TestCase):
def test_filter_with_join_in_condition(self):
self.assertQuerysetEqual(
CaseTestModel.objects.filter(integer=Case(
When(integer2=F('o2o_rel__integer') + 1, then=Value(2)),
When(integer2=F('o2o_rel__integer'), then=Value(3)),
When(integer2=F('o2o_rel__integer') + 1, then=2),
When(integer2=F('o2o_rel__integer'), then=3),
output_field=models.IntegerField(),
)).order_by('pk'),
[(2, 3), (3, 3)],
@ -392,9 +392,9 @@ class CaseExpressionTests(TestCase):
def test_filter_with_join_in_predicate(self):
self.assertQuerysetEqual(
CaseTestModel.objects.filter(integer2=Case(
When(o2o_rel__integer=1, then=Value(1)),
When(o2o_rel__integer=2, then=Value(3)),
When(o2o_rel__integer=3, then=Value(4)),
When(o2o_rel__integer=1, then=1),
When(o2o_rel__integer=2, then=3),
When(o2o_rel__integer=3, then=4),
output_field=models.IntegerField(),
)).order_by('pk'),
[(1, 1), (2, 3), (3, 4), (3, 4)],
@ -422,8 +422,8 @@ class CaseExpressionTests(TestCase):
f_plus_1=F('integer') + 1,
).filter(
integer=Case(
When(integer2=F('integer'), then=Value(2)),
When(integer2=F('f_plus_1'), then=Value(3)),
When(integer2=F('integer'), then=2),
When(integer2=F('f_plus_1'), then=3),
output_field=models.IntegerField(),
),
).order_by('pk'),
@ -437,9 +437,9 @@ class CaseExpressionTests(TestCase):
f_plus_1=F('integer') + 1,
).filter(
integer2=Case(
When(f_plus_1=3, then=Value(3)),
When(f_plus_1=4, then=Value(4)),
default=Value(1),
When(f_plus_1=3, then=3),
When(f_plus_1=4, then=4),
default=1,
output_field=models.IntegerField(),
),
).order_by('pk'),
@ -469,8 +469,8 @@ class CaseExpressionTests(TestCase):
max=Max('fk_rel__integer'),
).filter(
integer=Case(
When(integer2=F('min'), then=Value(2)),
When(integer2=F('max'), then=Value(3)),
When(integer2=F('min'), then=2),
When(integer2=F('max'), then=3),
),
).order_by('pk'),
[(3, 4, 3, 4), (2, 2, 2, 3), (3, 4, 3, 4)],
@ -483,8 +483,8 @@ class CaseExpressionTests(TestCase):
max=Max('fk_rel__integer'),
).filter(
integer=Case(
When(max=3, then=Value(2)),
When(max=4, then=Value(3)),
When(max=3, then=2),
When(max=4, then=3),
),
).order_by('pk'),
[(2, 3, 3), (3, 4, 4), (2, 2, 3), (3, 4, 4), (3, 3, 4)],
@ -508,8 +508,8 @@ class CaseExpressionTests(TestCase):
def test_update_without_default(self):
CaseTestModel.objects.update(
integer2=Case(
When(integer=1, then=Value(1)),
When(integer=2, then=Value(2)),
When(integer=1, then=1),
When(integer=2, then=2),
),
)
self.assertQuerysetEqual(
@ -549,8 +549,8 @@ class CaseExpressionTests(TestCase):
with self.assertRaisesMessage(FieldError, 'Joined field references are not permitted in this query'):
CaseTestModel.objects.update(
integer=Case(
When(integer2=F('o2o_rel__integer') + 1, then=Value(2)),
When(integer2=F('o2o_rel__integer'), then=Value(3)),
When(integer2=F('o2o_rel__integer') + 1, then=2),
When(integer2=F('o2o_rel__integer'), then=3),
output_field=models.IntegerField(),
),
)
@ -570,8 +570,8 @@ class CaseExpressionTests(TestCase):
def test_update_big_integer(self):
CaseTestModel.objects.update(
big_integer=Case(
When(integer=1, then=Value(1)),
When(integer=2, then=Value(2)),
When(integer=1, then=1),
When(integer=2, then=2),
),
)
self.assertQuerysetEqual(
@ -599,9 +599,9 @@ class CaseExpressionTests(TestCase):
def test_update_boolean(self):
CaseTestModel.objects.update(
boolean=Case(
When(integer=1, then=Value(True)),
When(integer=2, then=Value(True)),
default=Value(False),
When(integer=1, then=True),
When(integer=2, then=True),
default=False,
),
)
self.assertQuerysetEqual(
@ -627,8 +627,8 @@ class CaseExpressionTests(TestCase):
def test_update_date(self):
CaseTestModel.objects.update(
date=Case(
When(integer=1, then=Value(date(2015, 1, 1))),
When(integer=2, then=Value(date(2015, 1, 2))),
When(integer=1, then=date(2015, 1, 1)),
When(integer=2, then=date(2015, 1, 2)),
),
)
self.assertQuerysetEqual(
@ -643,8 +643,8 @@ class CaseExpressionTests(TestCase):
def test_update_date_time(self):
CaseTestModel.objects.update(
date_time=Case(
When(integer=1, then=Value(datetime(2015, 1, 1))),
When(integer=2, then=Value(datetime(2015, 1, 2))),
When(integer=1, then=datetime(2015, 1, 1)),
When(integer=2, then=datetime(2015, 1, 2)),
),
)
self.assertQuerysetEqual(
@ -659,8 +659,8 @@ class CaseExpressionTests(TestCase):
def test_update_decimal(self):
CaseTestModel.objects.update(
decimal=Case(
When(integer=1, then=Value(Decimal('1.1'))),
When(integer=2, then=Value(Decimal('2.2'))),
When(integer=1, then=Decimal('1.1')),
When(integer=2, then=Decimal('2.2')),
),
)
self.assertQuerysetEqual(
@ -728,8 +728,8 @@ class CaseExpressionTests(TestCase):
def test_update_float(self):
CaseTestModel.objects.update(
float=Case(
When(integer=1, then=Value(1.1)),
When(integer=2, then=Value(2.2)),
When(integer=1, then=1.1),
When(integer=2, then=2.2),
),
)
self.assertQuerysetEqual(
@ -770,8 +770,8 @@ class CaseExpressionTests(TestCase):
def test_update_null_boolean(self):
CaseTestModel.objects.update(
null_boolean=Case(
When(integer=1, then=Value(True)),
When(integer=2, then=Value(False)),
When(integer=1, then=True),
When(integer=2, then=False),
),
)
self.assertQuerysetEqual(
@ -783,8 +783,8 @@ class CaseExpressionTests(TestCase):
def test_update_positive_integer(self):
CaseTestModel.objects.update(
positive_integer=Case(
When(integer=1, then=Value(1)),
When(integer=2, then=Value(2)),
When(integer=1, then=1),
When(integer=2, then=2),
),
)
self.assertQuerysetEqual(
@ -796,8 +796,8 @@ class CaseExpressionTests(TestCase):
def test_update_positive_small_integer(self):
CaseTestModel.objects.update(
positive_small_integer=Case(
When(integer=1, then=Value(1)),
When(integer=2, then=Value(2)),
When(integer=1, then=1),
When(integer=2, then=2),
),
)
self.assertQuerysetEqual(
@ -823,8 +823,8 @@ class CaseExpressionTests(TestCase):
def test_update_small_integer(self):
CaseTestModel.objects.update(
small_integer=Case(
When(integer=1, then=Value(1)),
When(integer=2, then=Value(2)),
When(integer=1, then=1),
When(integer=2, then=2),
),
)
self.assertQuerysetEqual(
@ -921,8 +921,8 @@ class CaseExpressionTests(TestCase):
CaseTestModel.objects.update(
fk=Case(
When(integer=1, then=Value(obj1.pk)),
When(integer=2, then=Value(obj2.pk)),
When(integer=1, then=obj1.pk),
When(integer=2, then=obj2.pk),
),
)
self.assertQuerysetEqual(
@ -1065,15 +1065,15 @@ class CaseDocumentationExamples(TestCase):
self.assertEqual(
Client.objects.aggregate(
regular=models.Sum(Case(
When(account_type=Client.REGULAR, then=Value(1)),
When(account_type=Client.REGULAR, then=1),
output_field=models.IntegerField(),
)),
gold=models.Sum(Case(
When(account_type=Client.GOLD, then=Value(1)),
When(account_type=Client.GOLD, then=1),
output_field=models.IntegerField(),
)),
platinum=models.Sum(Case(
When(account_type=Client.PLATINUM, then=Value(1)),
When(account_type=Client.PLATINUM, then=1),
output_field=models.IntegerField(),
)),
),