Fixed #30446 -- Resolved Value.output_field for stdlib types.

This required implementing a limited form of dynamic dispatch to combine
expressions with numerical output. Refs #26355 should eventually provide
a better interface for that.
This commit is contained in:
Simon Charette 2019-05-12 17:17:47 -04:00 committed by Mariusz Felisiak
parent d08e6f55e3
commit 1e38f1191d
10 changed files with 122 additions and 39 deletions

View File

@ -101,10 +101,13 @@ class SQLiteDecimalToFloatMixin:
is not acceptable by the GIS functions expecting numeric values.
"""
def as_sqlite(self, compiler, connection, **extra_context):
for expr in self.get_source_expressions():
if hasattr(expr, 'value') and isinstance(expr.value, Decimal):
expr.value = float(expr.value)
return super().as_sql(compiler, connection, **extra_context)
copy = self.copy()
copy.set_source_expressions([
Value(float(expr.value)) if hasattr(expr, 'value') and isinstance(expr.value, Decimal)
else expr
for expr in copy.get_source_expressions()
])
return copy.as_sql(compiler, connection, **extra_context)
class OracleToleranceMixin:

View File

@ -173,8 +173,7 @@ class DateTimeRangeContains(PostgresOperatorLookup):
def process_rhs(self, compiler, connection):
# Transform rhs value for db lookup.
if isinstance(self.rhs, datetime.date):
output_field = models.DateTimeField() if isinstance(self.rhs, datetime.datetime) else models.DateField()
value = models.Value(self.rhs, output_field=output_field)
value = models.Value(self.rhs)
self.rhs = value.resolve_expression(compiler.query)
return super().process_rhs(compiler, connection)

View File

@ -1,7 +1,9 @@
import copy
import datetime
import functools
import inspect
from decimal import Decimal
from uuid import UUID
from django.core.exceptions import EmptyResultSet, FieldError
from django.db import NotSupportedError, connection
@ -56,12 +58,7 @@ class Combinable:
def _combine(self, other, connector, reversed):
if not hasattr(other, 'resolve_expression'):
# everything must be resolvable to an expression
output_field = (
fields.DurationField()
if isinstance(other, datetime.timedelta) else
None
)
other = Value(other, output_field=output_field)
other = Value(other)
if reversed:
return CombinedExpression(other, connector, self)
@ -422,6 +419,25 @@ class Expression(BaseExpression, Combinable):
pass
_connector_combinators = {
connector: [
(fields.IntegerField, fields.DecimalField, fields.DecimalField),
(fields.DecimalField, fields.IntegerField, fields.DecimalField),
(fields.IntegerField, fields.FloatField, fields.FloatField),
(fields.FloatField, fields.IntegerField, fields.FloatField),
]
for connector in (Combinable.ADD, Combinable.SUB, Combinable.MUL, Combinable.DIV)
}
@functools.lru_cache(maxsize=128)
def _resolve_combined_type(connector, lhs_type, rhs_type):
combinators = _connector_combinators.get(connector, ())
for combinator_lhs_type, combinator_rhs_type, combined_type in combinators:
if issubclass(lhs_type, combinator_lhs_type) and issubclass(rhs_type, combinator_rhs_type):
return combined_type
class CombinedExpression(SQLiteNumericMixin, Expression):
def __init__(self, lhs, connector, rhs, output_field=None):
@ -442,6 +458,19 @@ class CombinedExpression(SQLiteNumericMixin, Expression):
def set_source_expressions(self, exprs):
self.lhs, self.rhs = exprs
def _resolve_output_field(self):
try:
return super()._resolve_output_field()
except FieldError:
combined_type = _resolve_combined_type(
self.connector,
type(self.lhs.output_field),
type(self.rhs.output_field),
)
if combined_type is None:
raise
return combined_type()
def as_sql(self, compiler, connection):
expressions = []
expression_params = []
@ -721,6 +750,30 @@ class Value(Expression):
def get_group_by_cols(self, alias=None):
return []
def _resolve_output_field(self):
if isinstance(self.value, str):
return fields.CharField()
if isinstance(self.value, bool):
return fields.BooleanField()
if isinstance(self.value, int):
return fields.IntegerField()
if isinstance(self.value, float):
return fields.FloatField()
if isinstance(self.value, datetime.datetime):
return fields.DateTimeField()
if isinstance(self.value, datetime.date):
return fields.DateField()
if isinstance(self.value, datetime.time):
return fields.TimeField()
if isinstance(self.value, datetime.timedelta):
return fields.DurationField()
if isinstance(self.value, Decimal):
return fields.DecimalField()
if isinstance(self.value, bytes):
return fields.BinaryField()
if isinstance(self.value, UUID):
return fields.UUIDField()
class RawSQL(Expression):
def __init__(self, sql, params, output_field=None):
@ -1177,7 +1230,6 @@ class OrderBy(BaseExpression):
copy.expression = Case(
When(self.expression, then=True),
default=False,
output_field=fields.BooleanField(),
)
return copy.as_sql(compiler, connection)
return self.as_sql(compiler, connection)

View File

@ -6,7 +6,7 @@ from copy import copy
from django.core.exceptions import EmptyResultSet
from django.db.models.expressions import Case, Exists, Func, Value, When
from django.db.models.fields import (
BooleanField, CharField, DateTimeField, Field, IntegerField, UUIDField,
CharField, DateTimeField, Field, IntegerField, UUIDField,
)
from django.db.models.query_utils import RegisterLookupMixin
from django.utils.datastructures import OrderedSet
@ -123,7 +123,7 @@ class Lookup:
exprs = []
for expr in (self.lhs, self.rhs):
if isinstance(expr, Exists):
expr = Case(When(expr, then=True), default=False, output_field=BooleanField())
expr = Case(When(expr, then=True), default=False)
wrapped = True
exprs.append(expr)
lookup = type(self)(*exprs) if wrapped else self

View File

@ -484,7 +484,15 @@ The ``output_field`` argument should be a model field instance, like
after it's retrieved from the database. Usually no arguments are needed when
instantiating the model field as any arguments relating to data validation
(``max_length``, ``max_digits``, etc.) will not be enforced on the expression's
output value.
output value. If no ``output_field`` is specified it will be tentatively
inferred from the :py:class:`type` of the provided ``value``, if possible. For
example, passing an instance of :py:class:`datetime.datetime` as ``value``
would default ``output_field`` to :class:`~django.db.models.DateTimeField`.
.. versionchanged:: 3.2
Support for inferring a default ``output_field`` from the type of ``value``
was added.
``ExpressionWrapper()`` expressions
-----------------------------------

View File

@ -233,6 +233,15 @@ Models
* The ``of`` argument of :meth:`.QuerySet.select_for_update()` is now allowed
on MySQL 8.0.1+.
* :class:`Value() <django.db.models.Value>` expression now
automatically resolves its ``output_field`` to the appropriate
:class:`Field <django.db.models.Field>` subclass based on the type of
it's provided ``value`` for :py:class:`bool`, :py:class:`bytes`,
:py:class:`float`, :py:class:`int`, :py:class:`str`,
:py:class:`datetime.date`, :py:class:`datetime.datetime`,
:py:class:`datetime.time`, :py:class:`datetime.timedelta`,
:py:class:`decimal.Decimal`, and :py:class:`uuid.UUID` instances.
Requests and Responses
~~~~~~~~~~~~~~~~~~~~~~

View File

@ -848,10 +848,6 @@ class AggregateTestCase(TestCase):
book = Book.objects.annotate(val=Max(2, output_field=IntegerField())).first()
self.assertEqual(book.val, 2)
def test_missing_output_field_raises_error(self):
with self.assertRaisesMessage(FieldError, 'Cannot resolve expression type, unknown output_field'):
Book.objects.annotate(val=Max(2)).first()
def test_annotation_expressions(self):
authors = Author.objects.annotate(combined_ages=Sum(F('age') + F('friends__age'))).order_by('name')
authors2 = Author.objects.annotate(combined_ages=Sum('age') + Sum('friends__age')).order_by('name')
@ -893,7 +889,7 @@ class AggregateTestCase(TestCase):
def test_combine_different_types(self):
msg = (
'Expression contains mixed types: FloatField, IntegerField. '
'Expression contains mixed types: FloatField, DecimalField. '
'You must set output_field.'
)
qs = Book.objects.annotate(sums=Sum('rating') + Sum('pages') + Sum('price'))

View File

@ -388,7 +388,7 @@ class AggregationTests(TestCase):
)
def test_annotated_conditional_aggregate(self):
annotated_qs = Book.objects.annotate(discount_price=F('price') * 0.75)
annotated_qs = Book.objects.annotate(discount_price=F('price') * Decimal('0.75'))
self.assertAlmostEqual(
annotated_qs.aggregate(test=Avg(Case(
When(pages__lt=400, then='discount_price'),

View File

@ -3,15 +3,17 @@ import pickle
import unittest
import uuid
from copy import deepcopy
from decimal import Decimal
from unittest import mock
from django.core.exceptions import FieldError
from django.db import DatabaseError, NotSupportedError, connection
from django.db.models import (
Avg, BooleanField, Case, CharField, Count, DateField, DateTimeField,
DurationField, Exists, Expression, ExpressionList, ExpressionWrapper, F,
Func, IntegerField, Max, Min, Model, OrderBy, OuterRef, Q, StdDev,
Subquery, Sum, TimeField, UUIDField, Value, Variance, When,
Avg, BinaryField, BooleanField, Case, CharField, Count, DateField,
DateTimeField, DecimalField, DurationField, Exists, Expression,
ExpressionList, ExpressionWrapper, F, FloatField, Func, IntegerField, Max,
Min, Model, OrderBy, OuterRef, Q, StdDev, Subquery, Sum, TimeField,
UUIDField, Value, Variance, When,
)
from django.db.models.expressions import Col, Combinable, Random, RawSQL, Ref
from django.db.models.functions import (
@ -1711,6 +1713,30 @@ class ValueTests(TestCase):
value = Value('foo', output_field=CharField())
self.assertEqual(value.as_sql(compiler, connection), ('%s', ['foo']))
def test_resolve_output_field(self):
value_types = [
('str', CharField),
(True, BooleanField),
(42, IntegerField),
(3.14, FloatField),
(datetime.date(2019, 5, 15), DateField),
(datetime.datetime(2019, 5, 15), DateTimeField),
(datetime.time(3, 16), TimeField),
(datetime.timedelta(1), DurationField),
(Decimal('3.14'), DecimalField),
(b'', BinaryField),
(uuid.uuid4(), UUIDField),
]
for value, ouput_field_type in value_types:
with self.subTest(type=type(value)):
expr = Value(value)
self.assertIsInstance(expr.output_field, ouput_field_type)
def test_resolve_output_field_failure(self):
msg = 'Cannot resolve expression type, unknown output_field'
with self.assertRaisesMessage(FieldError, msg):
Value(object()).output_field
class FieldTransformTests(TestCase):
@ -1848,7 +1874,9 @@ class ExpressionWrapperTests(SimpleTestCase):
self.assertEqual(expr.get_group_by_cols(alias=None), [])
def test_non_empty_group_by(self):
expr = ExpressionWrapper(Lower(Value('f')), output_field=IntegerField())
value = Value('f')
value.output_field = None
expr = ExpressionWrapper(Lower(value), output_field=IntegerField())
group_by_cols = expr.get_group_by_cols(alias=None)
self.assertEqual(group_by_cols, [expr.expression])
self.assertEqual(group_by_cols[0].output_field, expr.output_field)

View File

@ -1,7 +1,6 @@
from datetime import datetime
from operator import attrgetter
from django.core.exceptions import FieldError
from django.db.models import (
CharField, DateTimeField, F, Max, OuterRef, Subquery, Value,
)
@ -439,17 +438,6 @@ class OrderingTests(TestCase):
qs = Article.objects.order_by(Value('1', output_field=CharField()), '-headline')
self.assertSequenceEqual(qs, [self.a4, self.a3, self.a2, self.a1])
def test_order_by_constant_value_without_output_field(self):
msg = 'Cannot resolve expression type, unknown output_field'
qs = Article.objects.annotate(constant=Value('1')).order_by('constant')
for ordered_qs in (
qs,
qs.values('headline'),
Article.objects.order_by(Value('1')),
):
with self.subTest(ordered_qs=ordered_qs), self.assertRaisesMessage(FieldError, msg):
ordered_qs.first()
def test_related_ordering_duplicate_table_reference(self):
"""
An ordering referencing a model with an ordering referencing a model