mirror of https://github.com/django/django.git
Fixed #24485 -- Allowed combined expressions to set output_field
This commit is contained in:
parent
127b3873d0
commit
02a2943e4c
|
@ -2,7 +2,9 @@ from functools import wraps
|
|||
|
||||
from django.core.exceptions import ObjectDoesNotExist, ImproperlyConfigured # NOQA
|
||||
from django.db.models.query import Q, QuerySet, Prefetch # NOQA
|
||||
from django.db.models.expressions import Expression, F, Value, Func, Case, When # NOQA
|
||||
from django.db.models.expressions import ( # NOQA
|
||||
Expression, ExpressionWrapper, F, Value, Func, Case, When,
|
||||
)
|
||||
from django.db.models.manager import Manager # NOQA
|
||||
from django.db.models.base import Model # NOQA
|
||||
from django.db.models.aggregates import * # NOQA
|
||||
|
|
|
@ -126,12 +126,12 @@ class BaseExpression(object):
|
|||
# aggregate specific fields
|
||||
is_summary = False
|
||||
|
||||
def get_db_converters(self, connection):
|
||||
return [self.convert_value] + self.output_field.get_db_converters(connection)
|
||||
|
||||
def __init__(self, output_field=None):
|
||||
self._output_field = output_field
|
||||
|
||||
def get_db_converters(self, connection):
|
||||
return [self.convert_value] + self.output_field.get_db_converters(connection)
|
||||
|
||||
def get_source_expressions(self):
|
||||
return []
|
||||
|
||||
|
@ -656,6 +656,29 @@ class Ref(Expression):
|
|||
return [self]
|
||||
|
||||
|
||||
class ExpressionWrapper(Expression):
|
||||
"""
|
||||
An expression that can wrap another expression so that it can provide
|
||||
extra context to the inner expression, such as the output_field.
|
||||
"""
|
||||
|
||||
def __init__(self, expression, output_field):
|
||||
super(ExpressionWrapper, self).__init__(output_field=output_field)
|
||||
self.expression = expression
|
||||
|
||||
def set_source_expressions(self, exprs):
|
||||
self.expression = exprs[0]
|
||||
|
||||
def get_source_expressions(self):
|
||||
return [self.expression]
|
||||
|
||||
def as_sql(self, compiler, connection):
|
||||
return self.expression.as_sql(compiler, connection)
|
||||
|
||||
def __repr__(self):
|
||||
return "{}({})".format(self.__class__.__name__, self.expression)
|
||||
|
||||
|
||||
class When(Expression):
|
||||
template = 'WHEN %(condition)s THEN %(result)s'
|
||||
|
||||
|
|
|
@ -161,6 +161,27 @@ values, rather than on Python values.
|
|||
This is documented in :ref:`using F() expressions in queries
|
||||
<using-f-expressions-in-filters>`.
|
||||
|
||||
.. _using-f-with-annotations:
|
||||
|
||||
Using ``F()`` with annotations
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
``F()`` can be used to create dynamic fields on your models by combining
|
||||
different fields with arithmetic::
|
||||
|
||||
company = Company.objects.annotate(
|
||||
chairs_needed=F('num_employees') - F('num_chairs'))
|
||||
|
||||
If the fields that you're combining are of different types you'll need
|
||||
to tell Django what kind of field will be returned. Since ``F()`` does not
|
||||
directly support ``output_field`` you will need to wrap the expression with
|
||||
:class:`ExpressionWrapper`::
|
||||
|
||||
from django.db.models import DateTimeField, ExpressionWrapper, F
|
||||
|
||||
Ticket.objects.annotate(
|
||||
expires=ExpressionWrapper(
|
||||
F('active_at') + F('duration'), output_field=DateTimeField()))
|
||||
|
||||
.. _func-expressions:
|
||||
|
||||
|
@ -274,17 +295,6 @@ should define the desired ``output_field``. For example, adding an
|
|||
``IntegerField()`` and a ``FloatField()`` together should probably have
|
||||
``output_field=FloatField()`` defined.
|
||||
|
||||
.. note::
|
||||
|
||||
When you need to define the ``output_field`` for ``F`` expression
|
||||
arithmetic between different types, it's necessary to surround the
|
||||
expression in another expression::
|
||||
|
||||
from django.db.models import DateTimeField, Expression, F
|
||||
|
||||
Race.objects.annotate(finish=Expression(
|
||||
F('start') + F('duration'), output_field=DateTimeField()))
|
||||
|
||||
.. versionchanged:: 1.8
|
||||
|
||||
``output_field`` is a new parameter.
|
||||
|
@ -343,6 +353,19 @@ instantiating the model field as any arguments relating to data validation
|
|||
(``max_length``, ``max_digits``, etc.) will not be enforced on the expression's
|
||||
output value.
|
||||
|
||||
``ExpressionWrapper()`` expressions
|
||||
-----------------------------------
|
||||
|
||||
.. class:: ExpressionWrapper(expression, output_field)
|
||||
|
||||
.. versionadded:: 1.8
|
||||
|
||||
``ExpressionWrapper`` simply surrounds another expression and provides access
|
||||
to properties, such as ``output_field``, that may not be available on other
|
||||
expressions. ``ExpressionWrapper`` is necessary when using arithmetic on
|
||||
``F()`` expressions with different types as described in
|
||||
:ref:`using-f-with-annotations`.
|
||||
|
||||
Conditional expressions
|
||||
-----------------------
|
||||
|
||||
|
|
|
@ -84,3 +84,12 @@ class Company(models.Model):
|
|||
return ('Company(name=%s, motto=%s, ticker_name=%s, description=%s)'
|
||||
% (self.name, self.motto, self.ticker_name, self.description)
|
||||
)
|
||||
|
||||
|
||||
@python_2_unicode_compatible
|
||||
class Ticket(models.Model):
|
||||
active_at = models.DateTimeField()
|
||||
duration = models.DurationField()
|
||||
|
||||
def __str__(self):
|
||||
return '{} - {}'.format(self.active_at, self.duration)
|
||||
|
|
|
@ -5,13 +5,14 @@ from decimal import Decimal
|
|||
|
||||
from django.core.exceptions import FieldDoesNotExist, FieldError
|
||||
from django.db.models import (
|
||||
F, BooleanField, CharField, Count, Func, IntegerField, Sum, Value,
|
||||
F, BooleanField, CharField, Count, DateTimeField, ExpressionWrapper, Func,
|
||||
IntegerField, Sum, Value,
|
||||
)
|
||||
from django.test import TestCase
|
||||
from django.utils import six
|
||||
|
||||
from .models import (
|
||||
Author, Book, Company, DepartmentStore, Employee, Publisher, Store,
|
||||
Author, Book, Company, DepartmentStore, Employee, Publisher, Store, Ticket,
|
||||
)
|
||||
|
||||
|
||||
|
@ -135,6 +136,24 @@ class NonAggregateAnnotationTestCase(TestCase):
|
|||
for book in books:
|
||||
self.assertEqual(book.num_awards, book.publisher.num_awards)
|
||||
|
||||
def test_mixed_type_annotation_date_interval(self):
|
||||
active = datetime.datetime(2015, 3, 20, 14, 0, 0)
|
||||
duration = datetime.timedelta(hours=1)
|
||||
expires = datetime.datetime(2015, 3, 20, 14, 0, 0) + duration
|
||||
Ticket.objects.create(active_at=active, duration=duration)
|
||||
t = Ticket.objects.annotate(
|
||||
expires=ExpressionWrapper(F('active_at') + F('duration'), output_field=DateTimeField())
|
||||
).first()
|
||||
self.assertEqual(t.expires, expires)
|
||||
|
||||
def test_mixed_type_annotation_numbers(self):
|
||||
test = self.b1
|
||||
b = Book.objects.annotate(
|
||||
combined=ExpressionWrapper(F('pages') + F('rating'), output_field=IntegerField())
|
||||
).get(isbn=test.isbn)
|
||||
combined = int(test.pages + test.rating)
|
||||
self.assertEqual(b.combined, combined)
|
||||
|
||||
def test_annotate_with_aggregation(self):
|
||||
books = Book.objects.annotate(
|
||||
is_book=Value(1, output_field=IntegerField()),
|
||||
|
|
|
@ -11,8 +11,8 @@ from django.db.models.aggregates import (
|
|||
Avg, Count, Max, Min, StdDev, Sum, Variance,
|
||||
)
|
||||
from django.db.models.expressions import (
|
||||
F, Case, Col, Date, DateTime, Func, OrderBy, Random, RawSQL, Ref, Value,
|
||||
When,
|
||||
F, Case, Col, Date, DateTime, ExpressionWrapper, Func, OrderBy, Random,
|
||||
RawSQL, Ref, Value, When,
|
||||
)
|
||||
from django.db.models.functions import (
|
||||
Coalesce, Concat, Length, Lower, Substr, Upper,
|
||||
|
@ -855,6 +855,10 @@ class ReprTests(TestCase):
|
|||
self.assertEqual(repr(DateTime('published', 'exact', utc)), "DateTime(published, exact, %s)" % utc)
|
||||
self.assertEqual(repr(F('published')), "F(published)")
|
||||
self.assertEqual(repr(F('cost') + F('tax')), "<CombinedExpression: F(cost) + F(tax)>")
|
||||
self.assertEqual(
|
||||
repr(ExpressionWrapper(F('cost') + F('tax'), models.IntegerField())),
|
||||
"ExpressionWrapper(F(cost) + F(tax))"
|
||||
)
|
||||
self.assertEqual(repr(Func('published', function='TO_CHAR')), "Func(F(published), function=TO_CHAR)")
|
||||
self.assertEqual(repr(OrderBy(Value(1))), 'OrderBy(Value(1), descending=False)')
|
||||
self.assertEqual(repr(Random()), "Random()")
|
||||
|
|
Loading…
Reference in New Issue