Fixed #30446 -- Resolved Value.output_field for stdlib types.

This required implementing a limited form of dynamic dispatch to combine
expressions with numerical output. Refs #26355 should eventually provide
a better interface for that.
This commit is contained in:
Simon Charette 2019-05-12 17:17:47 -04:00 committed by Mariusz Felisiak
parent d08e6f55e3
commit 1e38f1191d
10 changed files with 122 additions and 39 deletions

View File

@ -101,10 +101,13 @@ class SQLiteDecimalToFloatMixin:
is not acceptable by the GIS functions expecting numeric values. is not acceptable by the GIS functions expecting numeric values.
""" """
def as_sqlite(self, compiler, connection, **extra_context): def as_sqlite(self, compiler, connection, **extra_context):
for expr in self.get_source_expressions(): copy = self.copy()
if hasattr(expr, 'value') and isinstance(expr.value, Decimal): copy.set_source_expressions([
expr.value = float(expr.value) Value(float(expr.value)) if hasattr(expr, 'value') and isinstance(expr.value, Decimal)
return super().as_sql(compiler, connection, **extra_context) else expr
for expr in copy.get_source_expressions()
])
return copy.as_sql(compiler, connection, **extra_context)
class OracleToleranceMixin: class OracleToleranceMixin:

View File

@ -173,8 +173,7 @@ class DateTimeRangeContains(PostgresOperatorLookup):
def process_rhs(self, compiler, connection): def process_rhs(self, compiler, connection):
# Transform rhs value for db lookup. # Transform rhs value for db lookup.
if isinstance(self.rhs, datetime.date): if isinstance(self.rhs, datetime.date):
output_field = models.DateTimeField() if isinstance(self.rhs, datetime.datetime) else models.DateField() value = models.Value(self.rhs)
value = models.Value(self.rhs, output_field=output_field)
self.rhs = value.resolve_expression(compiler.query) self.rhs = value.resolve_expression(compiler.query)
return super().process_rhs(compiler, connection) return super().process_rhs(compiler, connection)

View File

@ -1,7 +1,9 @@
import copy import copy
import datetime import datetime
import functools
import inspect import inspect
from decimal import Decimal from decimal import Decimal
from uuid import UUID
from django.core.exceptions import EmptyResultSet, FieldError from django.core.exceptions import EmptyResultSet, FieldError
from django.db import NotSupportedError, connection from django.db import NotSupportedError, connection
@ -56,12 +58,7 @@ class Combinable:
def _combine(self, other, connector, reversed): def _combine(self, other, connector, reversed):
if not hasattr(other, 'resolve_expression'): if not hasattr(other, 'resolve_expression'):
# everything must be resolvable to an expression # everything must be resolvable to an expression
output_field = ( other = Value(other)
fields.DurationField()
if isinstance(other, datetime.timedelta) else
None
)
other = Value(other, output_field=output_field)
if reversed: if reversed:
return CombinedExpression(other, connector, self) return CombinedExpression(other, connector, self)
@ -422,6 +419,25 @@ class Expression(BaseExpression, Combinable):
pass pass
_connector_combinators = {
connector: [
(fields.IntegerField, fields.DecimalField, fields.DecimalField),
(fields.DecimalField, fields.IntegerField, fields.DecimalField),
(fields.IntegerField, fields.FloatField, fields.FloatField),
(fields.FloatField, fields.IntegerField, fields.FloatField),
]
for connector in (Combinable.ADD, Combinable.SUB, Combinable.MUL, Combinable.DIV)
}
@functools.lru_cache(maxsize=128)
def _resolve_combined_type(connector, lhs_type, rhs_type):
combinators = _connector_combinators.get(connector, ())
for combinator_lhs_type, combinator_rhs_type, combined_type in combinators:
if issubclass(lhs_type, combinator_lhs_type) and issubclass(rhs_type, combinator_rhs_type):
return combined_type
class CombinedExpression(SQLiteNumericMixin, Expression): class CombinedExpression(SQLiteNumericMixin, Expression):
def __init__(self, lhs, connector, rhs, output_field=None): def __init__(self, lhs, connector, rhs, output_field=None):
@ -442,6 +458,19 @@ class CombinedExpression(SQLiteNumericMixin, Expression):
def set_source_expressions(self, exprs): def set_source_expressions(self, exprs):
self.lhs, self.rhs = exprs self.lhs, self.rhs = exprs
def _resolve_output_field(self):
try:
return super()._resolve_output_field()
except FieldError:
combined_type = _resolve_combined_type(
self.connector,
type(self.lhs.output_field),
type(self.rhs.output_field),
)
if combined_type is None:
raise
return combined_type()
def as_sql(self, compiler, connection): def as_sql(self, compiler, connection):
expressions = [] expressions = []
expression_params = [] expression_params = []
@ -721,6 +750,30 @@ class Value(Expression):
def get_group_by_cols(self, alias=None): def get_group_by_cols(self, alias=None):
return [] return []
def _resolve_output_field(self):
if isinstance(self.value, str):
return fields.CharField()
if isinstance(self.value, bool):
return fields.BooleanField()
if isinstance(self.value, int):
return fields.IntegerField()
if isinstance(self.value, float):
return fields.FloatField()
if isinstance(self.value, datetime.datetime):
return fields.DateTimeField()
if isinstance(self.value, datetime.date):
return fields.DateField()
if isinstance(self.value, datetime.time):
return fields.TimeField()
if isinstance(self.value, datetime.timedelta):
return fields.DurationField()
if isinstance(self.value, Decimal):
return fields.DecimalField()
if isinstance(self.value, bytes):
return fields.BinaryField()
if isinstance(self.value, UUID):
return fields.UUIDField()
class RawSQL(Expression): class RawSQL(Expression):
def __init__(self, sql, params, output_field=None): def __init__(self, sql, params, output_field=None):
@ -1177,7 +1230,6 @@ class OrderBy(BaseExpression):
copy.expression = Case( copy.expression = Case(
When(self.expression, then=True), When(self.expression, then=True),
default=False, default=False,
output_field=fields.BooleanField(),
) )
return copy.as_sql(compiler, connection) return copy.as_sql(compiler, connection)
return self.as_sql(compiler, connection) return self.as_sql(compiler, connection)

View File

@ -6,7 +6,7 @@ from copy import copy
from django.core.exceptions import EmptyResultSet from django.core.exceptions import EmptyResultSet
from django.db.models.expressions import Case, Exists, Func, Value, When from django.db.models.expressions import Case, Exists, Func, Value, When
from django.db.models.fields import ( from django.db.models.fields import (
BooleanField, CharField, DateTimeField, Field, IntegerField, UUIDField, CharField, DateTimeField, Field, IntegerField, UUIDField,
) )
from django.db.models.query_utils import RegisterLookupMixin from django.db.models.query_utils import RegisterLookupMixin
from django.utils.datastructures import OrderedSet from django.utils.datastructures import OrderedSet
@ -123,7 +123,7 @@ class Lookup:
exprs = [] exprs = []
for expr in (self.lhs, self.rhs): for expr in (self.lhs, self.rhs):
if isinstance(expr, Exists): if isinstance(expr, Exists):
expr = Case(When(expr, then=True), default=False, output_field=BooleanField()) expr = Case(When(expr, then=True), default=False)
wrapped = True wrapped = True
exprs.append(expr) exprs.append(expr)
lookup = type(self)(*exprs) if wrapped else self lookup = type(self)(*exprs) if wrapped else self

View File

@ -484,7 +484,15 @@ The ``output_field`` argument should be a model field instance, like
after it's retrieved from the database. Usually no arguments are needed when after it's retrieved from the database. Usually no arguments are needed when
instantiating the model field as any arguments relating to data validation instantiating the model field as any arguments relating to data validation
(``max_length``, ``max_digits``, etc.) will not be enforced on the expression's (``max_length``, ``max_digits``, etc.) will not be enforced on the expression's
output value. output value. If no ``output_field`` is specified it will be tentatively
inferred from the :py:class:`type` of the provided ``value``, if possible. For
example, passing an instance of :py:class:`datetime.datetime` as ``value``
would default ``output_field`` to :class:`~django.db.models.DateTimeField`.
.. versionchanged:: 3.2
Support for inferring a default ``output_field`` from the type of ``value``
was added.
``ExpressionWrapper()`` expressions ``ExpressionWrapper()`` expressions
----------------------------------- -----------------------------------

View File

@ -233,6 +233,15 @@ Models
* The ``of`` argument of :meth:`.QuerySet.select_for_update()` is now allowed * The ``of`` argument of :meth:`.QuerySet.select_for_update()` is now allowed
on MySQL 8.0.1+. on MySQL 8.0.1+.
* :class:`Value() <django.db.models.Value>` expression now
automatically resolves its ``output_field`` to the appropriate
:class:`Field <django.db.models.Field>` subclass based on the type of
it's provided ``value`` for :py:class:`bool`, :py:class:`bytes`,
:py:class:`float`, :py:class:`int`, :py:class:`str`,
:py:class:`datetime.date`, :py:class:`datetime.datetime`,
:py:class:`datetime.time`, :py:class:`datetime.timedelta`,
:py:class:`decimal.Decimal`, and :py:class:`uuid.UUID` instances.
Requests and Responses Requests and Responses
~~~~~~~~~~~~~~~~~~~~~~ ~~~~~~~~~~~~~~~~~~~~~~

View File

@ -848,10 +848,6 @@ class AggregateTestCase(TestCase):
book = Book.objects.annotate(val=Max(2, output_field=IntegerField())).first() book = Book.objects.annotate(val=Max(2, output_field=IntegerField())).first()
self.assertEqual(book.val, 2) self.assertEqual(book.val, 2)
def test_missing_output_field_raises_error(self):
with self.assertRaisesMessage(FieldError, 'Cannot resolve expression type, unknown output_field'):
Book.objects.annotate(val=Max(2)).first()
def test_annotation_expressions(self): def test_annotation_expressions(self):
authors = Author.objects.annotate(combined_ages=Sum(F('age') + F('friends__age'))).order_by('name') authors = Author.objects.annotate(combined_ages=Sum(F('age') + F('friends__age'))).order_by('name')
authors2 = Author.objects.annotate(combined_ages=Sum('age') + Sum('friends__age')).order_by('name') authors2 = Author.objects.annotate(combined_ages=Sum('age') + Sum('friends__age')).order_by('name')
@ -893,7 +889,7 @@ class AggregateTestCase(TestCase):
def test_combine_different_types(self): def test_combine_different_types(self):
msg = ( msg = (
'Expression contains mixed types: FloatField, IntegerField. ' 'Expression contains mixed types: FloatField, DecimalField. '
'You must set output_field.' 'You must set output_field.'
) )
qs = Book.objects.annotate(sums=Sum('rating') + Sum('pages') + Sum('price')) qs = Book.objects.annotate(sums=Sum('rating') + Sum('pages') + Sum('price'))

View File

@ -388,7 +388,7 @@ class AggregationTests(TestCase):
) )
def test_annotated_conditional_aggregate(self): def test_annotated_conditional_aggregate(self):
annotated_qs = Book.objects.annotate(discount_price=F('price') * 0.75) annotated_qs = Book.objects.annotate(discount_price=F('price') * Decimal('0.75'))
self.assertAlmostEqual( self.assertAlmostEqual(
annotated_qs.aggregate(test=Avg(Case( annotated_qs.aggregate(test=Avg(Case(
When(pages__lt=400, then='discount_price'), When(pages__lt=400, then='discount_price'),

View File

@ -3,15 +3,17 @@ import pickle
import unittest import unittest
import uuid import uuid
from copy import deepcopy from copy import deepcopy
from decimal import Decimal
from unittest import mock from unittest import mock
from django.core.exceptions import FieldError from django.core.exceptions import FieldError
from django.db import DatabaseError, NotSupportedError, connection from django.db import DatabaseError, NotSupportedError, connection
from django.db.models import ( from django.db.models import (
Avg, BooleanField, Case, CharField, Count, DateField, DateTimeField, Avg, BinaryField, BooleanField, Case, CharField, Count, DateField,
DurationField, Exists, Expression, ExpressionList, ExpressionWrapper, F, DateTimeField, DecimalField, DurationField, Exists, Expression,
Func, IntegerField, Max, Min, Model, OrderBy, OuterRef, Q, StdDev, ExpressionList, ExpressionWrapper, F, FloatField, Func, IntegerField, Max,
Subquery, Sum, TimeField, UUIDField, Value, Variance, When, Min, Model, OrderBy, OuterRef, Q, StdDev, Subquery, Sum, TimeField,
UUIDField, Value, Variance, When,
) )
from django.db.models.expressions import Col, Combinable, Random, RawSQL, Ref from django.db.models.expressions import Col, Combinable, Random, RawSQL, Ref
from django.db.models.functions import ( from django.db.models.functions import (
@ -1711,6 +1713,30 @@ class ValueTests(TestCase):
value = Value('foo', output_field=CharField()) value = Value('foo', output_field=CharField())
self.assertEqual(value.as_sql(compiler, connection), ('%s', ['foo'])) self.assertEqual(value.as_sql(compiler, connection), ('%s', ['foo']))
def test_resolve_output_field(self):
value_types = [
('str', CharField),
(True, BooleanField),
(42, IntegerField),
(3.14, FloatField),
(datetime.date(2019, 5, 15), DateField),
(datetime.datetime(2019, 5, 15), DateTimeField),
(datetime.time(3, 16), TimeField),
(datetime.timedelta(1), DurationField),
(Decimal('3.14'), DecimalField),
(b'', BinaryField),
(uuid.uuid4(), UUIDField),
]
for value, ouput_field_type in value_types:
with self.subTest(type=type(value)):
expr = Value(value)
self.assertIsInstance(expr.output_field, ouput_field_type)
def test_resolve_output_field_failure(self):
msg = 'Cannot resolve expression type, unknown output_field'
with self.assertRaisesMessage(FieldError, msg):
Value(object()).output_field
class FieldTransformTests(TestCase): class FieldTransformTests(TestCase):
@ -1848,7 +1874,9 @@ class ExpressionWrapperTests(SimpleTestCase):
self.assertEqual(expr.get_group_by_cols(alias=None), []) self.assertEqual(expr.get_group_by_cols(alias=None), [])
def test_non_empty_group_by(self): def test_non_empty_group_by(self):
expr = ExpressionWrapper(Lower(Value('f')), output_field=IntegerField()) value = Value('f')
value.output_field = None
expr = ExpressionWrapper(Lower(value), output_field=IntegerField())
group_by_cols = expr.get_group_by_cols(alias=None) group_by_cols = expr.get_group_by_cols(alias=None)
self.assertEqual(group_by_cols, [expr.expression]) self.assertEqual(group_by_cols, [expr.expression])
self.assertEqual(group_by_cols[0].output_field, expr.output_field) self.assertEqual(group_by_cols[0].output_field, expr.output_field)

View File

@ -1,7 +1,6 @@
from datetime import datetime from datetime import datetime
from operator import attrgetter from operator import attrgetter
from django.core.exceptions import FieldError
from django.db.models import ( from django.db.models import (
CharField, DateTimeField, F, Max, OuterRef, Subquery, Value, CharField, DateTimeField, F, Max, OuterRef, Subquery, Value,
) )
@ -439,17 +438,6 @@ class OrderingTests(TestCase):
qs = Article.objects.order_by(Value('1', output_field=CharField()), '-headline') qs = Article.objects.order_by(Value('1', output_field=CharField()), '-headline')
self.assertSequenceEqual(qs, [self.a4, self.a3, self.a2, self.a1]) self.assertSequenceEqual(qs, [self.a4, self.a3, self.a2, self.a1])
def test_order_by_constant_value_without_output_field(self):
msg = 'Cannot resolve expression type, unknown output_field'
qs = Article.objects.annotate(constant=Value('1')).order_by('constant')
for ordered_qs in (
qs,
qs.values('headline'),
Article.objects.order_by(Value('1')),
):
with self.subTest(ordered_qs=ordered_qs), self.assertRaisesMessage(FieldError, msg):
ordered_qs.first()
def test_related_ordering_duplicate_table_reference(self): def test_related_ordering_duplicate_table_reference(self):
""" """
An ordering referencing a model with an ordering referencing a model An ordering referencing a model with an ordering referencing a model