Fixed #24485 -- Allowed combined expressions to set output_field

This commit is contained in:
Josh Smeaton 2015-03-19 14:07:53 +11:00
parent 3a1886d111
commit e654123f7f
6 changed files with 101 additions and 19 deletions

View File

@ -4,7 +4,9 @@ import warnings
from django.core.exceptions import ObjectDoesNotExist, ImproperlyConfigured # NOQA
from django.db.models.query import Q, QuerySet, Prefetch # NOQA
from django.db.models.expressions import Expression, F, Value, Func, Case, When # NOQA
from django.db.models.expressions import ( # NOQA
Expression, ExpressionWrapper, F, Value, Func, Case, When,
)
from django.db.models.manager import Manager # NOQA
from django.db.models.base import Model # NOQA
from django.db.models.aggregates import * # NOQA

View File

@ -126,12 +126,12 @@ class BaseExpression(object):
# aggregate specific fields
is_summary = False
def get_db_converters(self, connection):
return [self.convert_value] + self.output_field.get_db_converters(connection)
def __init__(self, output_field=None):
self._output_field = output_field
def get_db_converters(self, connection):
return [self.convert_value] + self.output_field.get_db_converters(connection)
def get_source_expressions(self):
return []
@ -656,6 +656,29 @@ class Ref(Expression):
return [self]
class ExpressionWrapper(Expression):
"""
An expression that can wrap another expression so that it can provide
extra context to the inner expression, such as the output_field.
"""
def __init__(self, expression, output_field):
super(ExpressionWrapper, self).__init__(output_field=output_field)
self.expression = expression
def set_source_expressions(self, exprs):
self.expression = exprs[0]
def get_source_expressions(self):
return [self.expression]
def as_sql(self, compiler, connection):
return self.expression.as_sql(compiler, connection)
def __repr__(self):
return "{}({})".format(self.__class__.__name__, self.expression)
class When(Expression):
template = 'WHEN %(condition)s THEN %(result)s'

View File

@ -165,6 +165,27 @@ values, rather than on Python values.
This is documented in :ref:`using F() expressions in queries
<using-f-expressions-in-filters>`.
.. _using-f-with-annotations:
Using ``F()`` with annotations
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
``F()`` can be used to create dynamic fields on your models by combining
different fields with arithmetic::
company = Company.objects.annotate(
chairs_needed=F('num_employees') - F('num_chairs'))
If the fields that you're combining are of different types you'll need
to tell Django what kind of field will be returned. Since ``F()`` does not
directly support ``output_field`` you will need to wrap the expression with
:class:`ExpressionWrapper`::
from django.db.models import DateTimeField, ExpressionWrapper, F
Ticket.objects.annotate(
expires=ExpressionWrapper(
F('active_at') + F('duration'), output_field=DateTimeField()))
.. _func-expressions:
@ -278,17 +299,6 @@ should define the desired ``output_field``. For example, adding an
``IntegerField()`` and a ``FloatField()`` together should probably have
``output_field=FloatField()`` defined.
.. note::
When you need to define the ``output_field`` for ``F`` expression
arithmetic between different types, it's necessary to surround the
expression in another expression::
from django.db.models import DateTimeField, Expression, F
Race.objects.annotate(finish=Expression(
F('start') + F('duration'), output_field=DateTimeField()))
.. versionchanged:: 1.8
``output_field`` is a new parameter.
@ -347,6 +357,19 @@ instantiating the model field as any arguments relating to data validation
(``max_length``, ``max_digits``, etc.) will not be enforced on the expression's
output value.
``ExpressionWrapper()`` expressions
-----------------------------------
.. class:: ExpressionWrapper(expression, output_field)
.. versionadded:: 1.8
``ExpressionWrapper`` simply surrounds another expression and provides access
to properties, such as ``output_field``, that may not be available on other
expressions. ``ExpressionWrapper`` is necessary when using arithmetic on
``F()`` expressions with different types as described in
:ref:`using-f-with-annotations`.
Conditional expressions
-----------------------

View File

@ -84,3 +84,12 @@ class Company(models.Model):
return ('Company(name=%s, motto=%s, ticker_name=%s, description=%s)'
% (self.name, self.motto, self.ticker_name, self.description)
)
@python_2_unicode_compatible
class Ticket(models.Model):
active_at = models.DateTimeField()
duration = models.DurationField()
def __str__(self):
return '{} - {}'.format(self.active_at, self.duration)

View File

@ -5,12 +5,15 @@ from decimal import Decimal
from django.core.exceptions import FieldDoesNotExist, FieldError
from django.db.models import (
F, BooleanField, CharField, Count, Func, IntegerField, Sum, Value,
F, BooleanField, CharField, Count, DateTimeField, ExpressionWrapper, Func,
IntegerField, Sum, Value,
)
from django.test import TestCase
from django.utils import six
from .models import Author, Book, Company, DepartmentStore, Employee, Store
from .models import (
Author, Book, Company, DepartmentStore, Employee, Publisher, Store, Ticket,
)
def cxOracle_513_py3_bug(func):
@ -52,6 +55,24 @@ class NonAggregateAnnotationTestCase(TestCase):
for book in books:
self.assertEqual(book.num_awards, book.publisher.num_awards)
def test_mixed_type_annotation_date_interval(self):
active = datetime.datetime(2015, 3, 20, 14, 0, 0)
duration = datetime.timedelta(hours=1)
expires = datetime.datetime(2015, 3, 20, 14, 0, 0) + duration
Ticket.objects.create(active_at=active, duration=duration)
t = Ticket.objects.annotate(
expires=ExpressionWrapper(F('active_at') + F('duration'), output_field=DateTimeField())
).first()
self.assertEqual(t.expires, expires)
def test_mixed_type_annotation_numbers(self):
test = self.b1
b = Book.objects.annotate(
combined=ExpressionWrapper(F('pages') + F('rating'), output_field=IntegerField())
).get(isbn=test.isbn)
combined = int(test.pages + test.rating)
self.assertEqual(b.combined, combined)
def test_annotate_with_aggregation(self):
books = Book.objects.annotate(
is_book=Value(1, output_field=IntegerField()),

View File

@ -11,8 +11,8 @@ from django.db.models.aggregates import (
Avg, Count, Max, Min, StdDev, Sum, Variance,
)
from django.db.models.expressions import (
F, Case, Col, Date, DateTime, Func, OrderBy, Random, RawSQL, Ref, Value,
When,
F, Case, Col, Date, DateTime, ExpressionWrapper, Func, OrderBy, Random,
RawSQL, Ref, Value, When,
)
from django.db.models.functions import (
Coalesce, Concat, Length, Lower, Substr, Upper,
@ -855,6 +855,10 @@ class ReprTests(TestCase):
self.assertEqual(repr(DateTime('published', 'exact', utc)), "DateTime(published, exact, %s)" % utc)
self.assertEqual(repr(F('published')), "F(published)")
self.assertEqual(repr(F('cost') + F('tax')), "<CombinedExpression: F(cost) + F(tax)>")
self.assertEqual(
repr(ExpressionWrapper(F('cost') + F('tax'), models.IntegerField())),
"ExpressionWrapper(F(cost) + F(tax))"
)
self.assertEqual(repr(Func('published', function='TO_CHAR')), "Func(F(published), function=TO_CHAR)")
self.assertEqual(repr(OrderBy(Value(1))), 'OrderBy(Value(1), descending=False)')
self.assertEqual(repr(Random()), "Random()")