From a6ea62aeafd4512f6d13aeda908f7622776a4537 Mon Sep 17 00:00:00 2001 From: Josh Smeaton Date: Wed, 11 Feb 2015 16:38:02 +1100 Subject: [PATCH] [1.8.x] Refs #14030 -- Improved expression support for python values Backport of e2d6e14662d780383e18066a3182155fb5b7747b from master --- django/db/backends/base/features.py | 3 +- django/db/models/expressions.py | 26 ++--- docs/ref/models/conditional-expressions.txt | 12 +- docs/ref/models/expressions.txt | 4 + tests/aggregation/tests.py | 10 +- tests/db_functions/tests.py | 2 +- tests/expressions_case/tests.py | 120 ++++++++++---------- 7 files changed, 89 insertions(+), 88 deletions(-) diff --git a/django/db/backends/base/features.py b/django/db/backends/base/features.py index 4b4d5c6d759..e1c42224105 100644 --- a/django/db/backends/base/features.py +++ b/django/db/backends/base/features.py @@ -1,5 +1,4 @@ from django.db.models.aggregates import StdDev -from django.db.models.expressions import Value from django.db.utils import ProgrammingError from django.utils.functional import cached_property @@ -229,7 +228,7 @@ class BaseDatabaseFeatures(object): def supports_stddev(self): """Confirm support for STDDEV and related stats functions.""" try: - self.connection.ops.check_expression_support(StdDev(Value(1))) + self.connection.ops.check_expression_support(StdDev(1)) return True except NotImplementedError: return False diff --git a/django/db/models/expressions.py b/django/db/models/expressions.py index c15b095bc21..8e5c4aa533d 100644 --- a/django/db/models/expressions.py +++ b/django/db/models/expressions.py @@ -7,7 +7,7 @@ from django.db.backends import utils as backend_utils from django.db.models import fields from django.db.models.constants import LOOKUP_SEP from django.db.models.query_utils import Q, refs_aggregate -from django.utils import timezone +from django.utils import six, timezone from django.utils.functional import cached_property @@ -138,6 +138,13 @@ class BaseExpression(object): def set_source_expressions(self, exprs): assert len(exprs) == 0 + def _parse_expressions(self, *expressions): + return [ + arg if hasattr(arg, 'resolve_expression') else ( + F(arg) if isinstance(arg, six.string_types) else Value(arg) + ) for arg in expressions + ] + def as_sql(self, compiler, connection): """ Responsible for returning a (sql, [params]) tuple to be included @@ -466,12 +473,6 @@ class Func(ExpressionNode): def set_source_expressions(self, exprs): self.source_expressions = exprs - def _parse_expressions(self, *expressions): - return [ - arg if hasattr(arg, 'resolve_expression') else F(arg) - for arg in expressions - ] - def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False): c = self.copy() c.is_summary = summarize @@ -639,14 +640,14 @@ class Ref(ExpressionNode): class When(ExpressionNode): template = 'WHEN %(condition)s THEN %(result)s' - def __init__(self, condition=None, then=Value(None), **lookups): + def __init__(self, condition=None, then=None, **lookups): if lookups and condition is None: condition, lookups = Q(**lookups), None if condition is None or not isinstance(condition, Q) or lookups: raise TypeError("__init__() takes either a Q object or lookups as keyword arguments") super(When, self).__init__(output_field=None) self.condition = condition - self.result = self._parse_expression(then) + self.result = self._parse_expressions(then)[0] def __str__(self): return "WHEN %r THEN %r" % (self.condition, self.result) @@ -664,9 +665,6 @@ class When(ExpressionNode): # We're only interested in the fields of the result expressions. return [self.result._output_field_or_none] - def _parse_expression(self, expression): - return expression if hasattr(expression, 'resolve_expression') else F(expression) - def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False): c = self.copy() c.is_summary = summarize @@ -713,11 +711,11 @@ class Case(ExpressionNode): def __init__(self, *cases, **extra): if not all(isinstance(case, When) for case in cases): raise TypeError("Positional arguments must all be When objects.") - default = extra.pop('default', Value(None)) + default = extra.pop('default', None) output_field = extra.pop('output_field', None) super(Case, self).__init__(output_field) self.cases = list(cases) - self.default = default if hasattr(default, 'resolve_expression') else F(default) + self.default = self._parse_expressions(default)[0] def __str__(self): return "CASE %s, ELSE %r" % (', '.join(str(c) for c in self.cases), self.default) diff --git a/docs/ref/models/conditional-expressions.txt b/docs/ref/models/conditional-expressions.txt index 5692bc1968a..b62f8a5c350 100644 --- a/docs/ref/models/conditional-expressions.txt +++ b/docs/ref/models/conditional-expressions.txt @@ -39,7 +39,7 @@ We'll be using the following model in the subsequent examples:: When ---- -.. class:: When(condition=None, then=Value(None), **lookups) +.. class:: When(condition=None, then=None, **lookups) A ``When()`` object is used to encapsulate a condition and its result for use in the conditional expression. Using a ``When()`` object is similar to using @@ -73,8 +73,8 @@ Keep in mind that each of these values can be an expression. resolved in two ways:: >>> from django.db.models import Value - >>> When(then__exact=0, then=Value(1)) - >>> When(Q(then=0), then=Value(1)) + >>> When(then__exact=0, then=1) + >>> When(Q(then=0), then=1) Case ---- @@ -197,15 +197,15 @@ What if we want to find out how many clients there are for each >>> from django.db.models import IntegerField, Sum >>> Client.objects.aggregate( ... regular=Sum( - ... Case(When(account_type=Client.REGULAR, then=Value(1)), + ... Case(When(account_type=Client.REGULAR, then=1), ... output_field=IntegerField()) ... ), ... gold=Sum( - ... Case(When(account_type=Client.GOLD, then=Value(1)), + ... Case(When(account_type=Client.GOLD, then=1), ... output_field=IntegerField()) ... ), ... platinum=Sum( - ... Case(When(account_type=Client.PLATINUM, then=Value(1)), + ... Case(When(account_type=Client.PLATINUM, then=1), ... output_field=IntegerField()) ... ) ... ) diff --git a/docs/ref/models/expressions.txt b/docs/ref/models/expressions.txt index b6b40002788..f170af8510b 100644 --- a/docs/ref/models/expressions.txt +++ b/docs/ref/models/expressions.txt @@ -221,6 +221,10 @@ function will be applied to. The expressions will be converted to strings, joined together with ``arg_joiner``, and then interpolated into the ``template`` as the ``expressions`` placeholder. +Positional arguments can be expressions or Python values. Strings are +assumed to be column references and will be wrapped in ``F()`` expressions +while other values will be wrapped in ``Value()`` expressions. + The ``**extra`` kwargs are ``key=value`` pairs that can be interpolated into the ``template`` attribute. Note that the keywords ``function`` and ``template`` can be used to replace the ``function`` and ``template`` diff --git a/tests/aggregation/tests.py b/tests/aggregation/tests.py index 9282c7e10cc..bc3929d98e7 100644 --- a/tests/aggregation/tests.py +++ b/tests/aggregation/tests.py @@ -8,7 +8,7 @@ from django.core.exceptions import FieldError from django.db import connection from django.db.models import ( F, Aggregate, Avg, Count, DecimalField, FloatField, Func, IntegerField, - Max, Min, Sum, Value, + Max, Min, Sum, ) from django.test import TestCase, ignore_warnings from django.test.utils import Approximate, CaptureQueriesContext @@ -706,14 +706,14 @@ class ComplexAggregateTestCase(TestCase): Book.objects.aggregate(fail=F('price')) def test_nonfield_annotation(self): - book = Book.objects.annotate(val=Max(Value(2, output_field=IntegerField())))[0] + book = Book.objects.annotate(val=Max(2, output_field=IntegerField()))[0] self.assertEqual(book.val, 2) - book = Book.objects.annotate(val=Max(Value(2), output_field=IntegerField()))[0] + book = Book.objects.annotate(val=Max(2, output_field=IntegerField()))[0] self.assertEqual(book.val, 2) def test_missing_output_field_raises_error(self): with six.assertRaisesRegex(self, FieldError, 'Cannot resolve expression type, unknown output_field'): - Book.objects.annotate(val=Max(Value(2)))[0] + Book.objects.annotate(val=Max(2))[0] def test_annotation_expressions(self): authors = Author.objects.annotate(combined_ages=Sum(F('age') + F('friends__age'))).order_by('name') @@ -772,7 +772,7 @@ class ComplexAggregateTestCase(TestCase): with six.assertRaisesRegex(self, TypeError, 'Complex aggregates require an alias'): Author.objects.aggregate(Sum('age') / Count('age')) with six.assertRaisesRegex(self, TypeError, 'Complex aggregates require an alias'): - Author.objects.aggregate(Sum(Value(1))) + Author.objects.aggregate(Sum(1)) def test_aggregate_over_complex_annotation(self): qs = Author.objects.annotate( diff --git a/tests/db_functions/tests.py b/tests/db_functions/tests.py index 9e7d27899a0..f07e8770f0a 100644 --- a/tests/db_functions/tests.py +++ b/tests/db_functions/tests.py @@ -281,7 +281,7 @@ class FunctionTests(TestCase): def test_substr_with_expressions(self): Author.objects.create(name='John Smith', alias='smithj') Author.objects.create(name='Rhonda') - authors = Author.objects.annotate(name_part=Substr('name', V(5), V(3))) + authors = Author.objects.annotate(name_part=Substr('name', 5, 3)) self.assertQuerysetEqual( authors.order_by('name'), [ ' Sm', diff --git a/tests/expressions_case/tests.py b/tests/expressions_case/tests.py index 176c969c298..b1b4c779480 100644 --- a/tests/expressions_case/tests.py +++ b/tests/expressions_case/tests.py @@ -78,8 +78,8 @@ class CaseExpressionTests(TestCase): def test_annotate_without_default(self): self.assertQuerysetEqual( CaseTestModel.objects.annotate(test=Case( - When(integer=1, then=Value(1)), - When(integer=2, then=Value(2)), + When(integer=1, then=1), + When(integer=2, then=2), output_field=models.IntegerField(), )).order_by('pk'), [(1, 1), (2, 2), (3, None), (2, 2), (3, None), (3, None), (4, None)], @@ -244,9 +244,9 @@ class CaseExpressionTests(TestCase): self.assertQuerysetEqual( CaseTestModel.objects.annotate( test=Case( - When(integer=1, then=Value(2)), - When(integer=2, then=Value(1)), - default=Value(3), + When(integer=1, then=2), + When(integer=2, then=1), + default=3, output_field=models.IntegerField(), ) + 1, ).order_by('pk'), @@ -278,19 +278,19 @@ class CaseExpressionTests(TestCase): self.assertEqual( CaseTestModel.objects.aggregate( one=models.Sum(Case( - When(integer=1, then=Value(1)), + When(integer=1, then=1), output_field=models.IntegerField(), )), two=models.Sum(Case( - When(integer=2, then=Value(1)), + When(integer=2, then=1), output_field=models.IntegerField(), )), three=models.Sum(Case( - When(integer=3, then=Value(1)), + When(integer=3, then=1), output_field=models.IntegerField(), )), four=models.Sum(Case( - When(integer=4, then=Value(1)), + When(integer=4, then=1), output_field=models.IntegerField(), )), ), @@ -311,11 +311,11 @@ class CaseExpressionTests(TestCase): self.assertEqual( CaseTestModel.objects.aggregate( equal=models.Sum(Case( - When(integer2=F('integer'), then=Value(1)), + When(integer2=F('integer'), then=1), output_field=models.IntegerField(), )), plus_one=models.Sum(Case( - When(integer2=F('integer') + 1, then=Value(1)), + When(integer2=F('integer') + 1, then=1), output_field=models.IntegerField(), )), ), @@ -325,9 +325,9 @@ class CaseExpressionTests(TestCase): def test_filter(self): self.assertQuerysetEqual( CaseTestModel.objects.filter(integer2=Case( - When(integer=2, then=Value(3)), - When(integer=3, then=Value(4)), - default=Value(1), + When(integer=2, then=3), + When(integer=3, then=4), + default=1, output_field=models.IntegerField(), )).order_by('pk'), [(1, 1), (2, 3), (3, 4), (3, 4)], @@ -337,8 +337,8 @@ class CaseExpressionTests(TestCase): def test_filter_without_default(self): self.assertQuerysetEqual( CaseTestModel.objects.filter(integer2=Case( - When(integer=2, then=Value(3)), - When(integer=3, then=Value(4)), + When(integer=2, then=3), + When(integer=3, then=4), output_field=models.IntegerField(), )).order_by('pk'), [(2, 3), (3, 4), (3, 4)], @@ -381,8 +381,8 @@ class CaseExpressionTests(TestCase): def test_filter_with_join_in_condition(self): self.assertQuerysetEqual( CaseTestModel.objects.filter(integer=Case( - When(integer2=F('o2o_rel__integer') + 1, then=Value(2)), - When(integer2=F('o2o_rel__integer'), then=Value(3)), + When(integer2=F('o2o_rel__integer') + 1, then=2), + When(integer2=F('o2o_rel__integer'), then=3), output_field=models.IntegerField(), )).order_by('pk'), [(2, 3), (3, 3)], @@ -392,9 +392,9 @@ class CaseExpressionTests(TestCase): def test_filter_with_join_in_predicate(self): self.assertQuerysetEqual( CaseTestModel.objects.filter(integer2=Case( - When(o2o_rel__integer=1, then=Value(1)), - When(o2o_rel__integer=2, then=Value(3)), - When(o2o_rel__integer=3, then=Value(4)), + When(o2o_rel__integer=1, then=1), + When(o2o_rel__integer=2, then=3), + When(o2o_rel__integer=3, then=4), output_field=models.IntegerField(), )).order_by('pk'), [(1, 1), (2, 3), (3, 4), (3, 4)], @@ -422,8 +422,8 @@ class CaseExpressionTests(TestCase): f_plus_1=F('integer') + 1, ).filter( integer=Case( - When(integer2=F('integer'), then=Value(2)), - When(integer2=F('f_plus_1'), then=Value(3)), + When(integer2=F('integer'), then=2), + When(integer2=F('f_plus_1'), then=3), output_field=models.IntegerField(), ), ).order_by('pk'), @@ -437,9 +437,9 @@ class CaseExpressionTests(TestCase): f_plus_1=F('integer') + 1, ).filter( integer2=Case( - When(f_plus_1=3, then=Value(3)), - When(f_plus_1=4, then=Value(4)), - default=Value(1), + When(f_plus_1=3, then=3), + When(f_plus_1=4, then=4), + default=1, output_field=models.IntegerField(), ), ).order_by('pk'), @@ -469,8 +469,8 @@ class CaseExpressionTests(TestCase): max=Max('fk_rel__integer'), ).filter( integer=Case( - When(integer2=F('min'), then=Value(2)), - When(integer2=F('max'), then=Value(3)), + When(integer2=F('min'), then=2), + When(integer2=F('max'), then=3), ), ).order_by('pk'), [(3, 4, 3, 4), (2, 2, 2, 3), (3, 4, 3, 4)], @@ -483,8 +483,8 @@ class CaseExpressionTests(TestCase): max=Max('fk_rel__integer'), ).filter( integer=Case( - When(max=3, then=Value(2)), - When(max=4, then=Value(3)), + When(max=3, then=2), + When(max=4, then=3), ), ).order_by('pk'), [(2, 3, 3), (3, 4, 4), (2, 2, 3), (3, 4, 4), (3, 3, 4)], @@ -508,8 +508,8 @@ class CaseExpressionTests(TestCase): def test_update_without_default(self): CaseTestModel.objects.update( integer2=Case( - When(integer=1, then=Value(1)), - When(integer=2, then=Value(2)), + When(integer=1, then=1), + When(integer=2, then=2), ), ) self.assertQuerysetEqual( @@ -549,8 +549,8 @@ class CaseExpressionTests(TestCase): with self.assertRaisesMessage(FieldError, 'Joined field references are not permitted in this query'): CaseTestModel.objects.update( integer=Case( - When(integer2=F('o2o_rel__integer') + 1, then=Value(2)), - When(integer2=F('o2o_rel__integer'), then=Value(3)), + When(integer2=F('o2o_rel__integer') + 1, then=2), + When(integer2=F('o2o_rel__integer'), then=3), output_field=models.IntegerField(), ), ) @@ -570,8 +570,8 @@ class CaseExpressionTests(TestCase): def test_update_big_integer(self): CaseTestModel.objects.update( big_integer=Case( - When(integer=1, then=Value(1)), - When(integer=2, then=Value(2)), + When(integer=1, then=1), + When(integer=2, then=2), ), ) self.assertQuerysetEqual( @@ -599,9 +599,9 @@ class CaseExpressionTests(TestCase): def test_update_boolean(self): CaseTestModel.objects.update( boolean=Case( - When(integer=1, then=Value(True)), - When(integer=2, then=Value(True)), - default=Value(False), + When(integer=1, then=True), + When(integer=2, then=True), + default=False, ), ) self.assertQuerysetEqual( @@ -627,8 +627,8 @@ class CaseExpressionTests(TestCase): def test_update_date(self): CaseTestModel.objects.update( date=Case( - When(integer=1, then=Value(date(2015, 1, 1))), - When(integer=2, then=Value(date(2015, 1, 2))), + When(integer=1, then=date(2015, 1, 1)), + When(integer=2, then=date(2015, 1, 2)), ), ) self.assertQuerysetEqual( @@ -643,8 +643,8 @@ class CaseExpressionTests(TestCase): def test_update_date_time(self): CaseTestModel.objects.update( date_time=Case( - When(integer=1, then=Value(datetime(2015, 1, 1))), - When(integer=2, then=Value(datetime(2015, 1, 2))), + When(integer=1, then=datetime(2015, 1, 1)), + When(integer=2, then=datetime(2015, 1, 2)), ), ) self.assertQuerysetEqual( @@ -659,8 +659,8 @@ class CaseExpressionTests(TestCase): def test_update_decimal(self): CaseTestModel.objects.update( decimal=Case( - When(integer=1, then=Value(Decimal('1.1'))), - When(integer=2, then=Value(Decimal('2.2'))), + When(integer=1, then=Decimal('1.1')), + When(integer=2, then=Decimal('2.2')), ), ) self.assertQuerysetEqual( @@ -728,8 +728,8 @@ class CaseExpressionTests(TestCase): def test_update_float(self): CaseTestModel.objects.update( float=Case( - When(integer=1, then=Value(1.1)), - When(integer=2, then=Value(2.2)), + When(integer=1, then=1.1), + When(integer=2, then=2.2), ), ) self.assertQuerysetEqual( @@ -785,8 +785,8 @@ class CaseExpressionTests(TestCase): def test_update_null_boolean(self): CaseTestModel.objects.update( null_boolean=Case( - When(integer=1, then=Value(True)), - When(integer=2, then=Value(False)), + When(integer=1, then=True), + When(integer=2, then=False), ), ) self.assertQuerysetEqual( @@ -798,8 +798,8 @@ class CaseExpressionTests(TestCase): def test_update_positive_integer(self): CaseTestModel.objects.update( positive_integer=Case( - When(integer=1, then=Value(1)), - When(integer=2, then=Value(2)), + When(integer=1, then=1), + When(integer=2, then=2), ), ) self.assertQuerysetEqual( @@ -811,8 +811,8 @@ class CaseExpressionTests(TestCase): def test_update_positive_small_integer(self): CaseTestModel.objects.update( positive_small_integer=Case( - When(integer=1, then=Value(1)), - When(integer=2, then=Value(2)), + When(integer=1, then=1), + When(integer=2, then=2), ), ) self.assertQuerysetEqual( @@ -838,8 +838,8 @@ class CaseExpressionTests(TestCase): def test_update_small_integer(self): CaseTestModel.objects.update( small_integer=Case( - When(integer=1, then=Value(1)), - When(integer=2, then=Value(2)), + When(integer=1, then=1), + When(integer=2, then=2), ), ) self.assertQuerysetEqual( @@ -936,8 +936,8 @@ class CaseExpressionTests(TestCase): CaseTestModel.objects.update( fk=Case( - When(integer=1, then=Value(obj1.pk)), - When(integer=2, then=Value(obj2.pk)), + When(integer=1, then=obj1.pk), + When(integer=2, then=obj2.pk), ), ) self.assertQuerysetEqual( @@ -1080,15 +1080,15 @@ class CaseDocumentationExamples(TestCase): self.assertEqual( Client.objects.aggregate( regular=models.Sum(Case( - When(account_type=Client.REGULAR, then=Value(1)), + When(account_type=Client.REGULAR, then=1), output_field=models.IntegerField(), )), gold=models.Sum(Case( - When(account_type=Client.GOLD, then=Value(1)), + When(account_type=Client.GOLD, then=1), output_field=models.IntegerField(), )), platinum=models.Sum(Case( - When(account_type=Client.PLATINUM, then=Value(1)), + When(account_type=Client.PLATINUM, then=1), output_field=models.IntegerField(), )), ),