mirror of https://github.com/django/django.git
Fixed #30446 -- Resolved Value.output_field for stdlib types.
This required implementing a limited form of dynamic dispatch to combine expressions with numerical output. Refs #26355 should eventually provide a better interface for that.
This commit is contained in:
parent
d08e6f55e3
commit
1e38f1191d
|
@ -101,10 +101,13 @@ class SQLiteDecimalToFloatMixin:
|
||||||
is not acceptable by the GIS functions expecting numeric values.
|
is not acceptable by the GIS functions expecting numeric values.
|
||||||
"""
|
"""
|
||||||
def as_sqlite(self, compiler, connection, **extra_context):
|
def as_sqlite(self, compiler, connection, **extra_context):
|
||||||
for expr in self.get_source_expressions():
|
copy = self.copy()
|
||||||
if hasattr(expr, 'value') and isinstance(expr.value, Decimal):
|
copy.set_source_expressions([
|
||||||
expr.value = float(expr.value)
|
Value(float(expr.value)) if hasattr(expr, 'value') and isinstance(expr.value, Decimal)
|
||||||
return super().as_sql(compiler, connection, **extra_context)
|
else expr
|
||||||
|
for expr in copy.get_source_expressions()
|
||||||
|
])
|
||||||
|
return copy.as_sql(compiler, connection, **extra_context)
|
||||||
|
|
||||||
|
|
||||||
class OracleToleranceMixin:
|
class OracleToleranceMixin:
|
||||||
|
|
|
@ -173,8 +173,7 @@ class DateTimeRangeContains(PostgresOperatorLookup):
|
||||||
def process_rhs(self, compiler, connection):
|
def process_rhs(self, compiler, connection):
|
||||||
# Transform rhs value for db lookup.
|
# Transform rhs value for db lookup.
|
||||||
if isinstance(self.rhs, datetime.date):
|
if isinstance(self.rhs, datetime.date):
|
||||||
output_field = models.DateTimeField() if isinstance(self.rhs, datetime.datetime) else models.DateField()
|
value = models.Value(self.rhs)
|
||||||
value = models.Value(self.rhs, output_field=output_field)
|
|
||||||
self.rhs = value.resolve_expression(compiler.query)
|
self.rhs = value.resolve_expression(compiler.query)
|
||||||
return super().process_rhs(compiler, connection)
|
return super().process_rhs(compiler, connection)
|
||||||
|
|
||||||
|
|
|
@ -1,7 +1,9 @@
|
||||||
import copy
|
import copy
|
||||||
import datetime
|
import datetime
|
||||||
|
import functools
|
||||||
import inspect
|
import inspect
|
||||||
from decimal import Decimal
|
from decimal import Decimal
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
from django.core.exceptions import EmptyResultSet, FieldError
|
from django.core.exceptions import EmptyResultSet, FieldError
|
||||||
from django.db import NotSupportedError, connection
|
from django.db import NotSupportedError, connection
|
||||||
|
@ -56,12 +58,7 @@ class Combinable:
|
||||||
def _combine(self, other, connector, reversed):
|
def _combine(self, other, connector, reversed):
|
||||||
if not hasattr(other, 'resolve_expression'):
|
if not hasattr(other, 'resolve_expression'):
|
||||||
# everything must be resolvable to an expression
|
# everything must be resolvable to an expression
|
||||||
output_field = (
|
other = Value(other)
|
||||||
fields.DurationField()
|
|
||||||
if isinstance(other, datetime.timedelta) else
|
|
||||||
None
|
|
||||||
)
|
|
||||||
other = Value(other, output_field=output_field)
|
|
||||||
|
|
||||||
if reversed:
|
if reversed:
|
||||||
return CombinedExpression(other, connector, self)
|
return CombinedExpression(other, connector, self)
|
||||||
|
@ -422,6 +419,25 @@ class Expression(BaseExpression, Combinable):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
_connector_combinators = {
|
||||||
|
connector: [
|
||||||
|
(fields.IntegerField, fields.DecimalField, fields.DecimalField),
|
||||||
|
(fields.DecimalField, fields.IntegerField, fields.DecimalField),
|
||||||
|
(fields.IntegerField, fields.FloatField, fields.FloatField),
|
||||||
|
(fields.FloatField, fields.IntegerField, fields.FloatField),
|
||||||
|
]
|
||||||
|
for connector in (Combinable.ADD, Combinable.SUB, Combinable.MUL, Combinable.DIV)
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@functools.lru_cache(maxsize=128)
|
||||||
|
def _resolve_combined_type(connector, lhs_type, rhs_type):
|
||||||
|
combinators = _connector_combinators.get(connector, ())
|
||||||
|
for combinator_lhs_type, combinator_rhs_type, combined_type in combinators:
|
||||||
|
if issubclass(lhs_type, combinator_lhs_type) and issubclass(rhs_type, combinator_rhs_type):
|
||||||
|
return combined_type
|
||||||
|
|
||||||
|
|
||||||
class CombinedExpression(SQLiteNumericMixin, Expression):
|
class CombinedExpression(SQLiteNumericMixin, Expression):
|
||||||
|
|
||||||
def __init__(self, lhs, connector, rhs, output_field=None):
|
def __init__(self, lhs, connector, rhs, output_field=None):
|
||||||
|
@ -442,6 +458,19 @@ class CombinedExpression(SQLiteNumericMixin, Expression):
|
||||||
def set_source_expressions(self, exprs):
|
def set_source_expressions(self, exprs):
|
||||||
self.lhs, self.rhs = exprs
|
self.lhs, self.rhs = exprs
|
||||||
|
|
||||||
|
def _resolve_output_field(self):
|
||||||
|
try:
|
||||||
|
return super()._resolve_output_field()
|
||||||
|
except FieldError:
|
||||||
|
combined_type = _resolve_combined_type(
|
||||||
|
self.connector,
|
||||||
|
type(self.lhs.output_field),
|
||||||
|
type(self.rhs.output_field),
|
||||||
|
)
|
||||||
|
if combined_type is None:
|
||||||
|
raise
|
||||||
|
return combined_type()
|
||||||
|
|
||||||
def as_sql(self, compiler, connection):
|
def as_sql(self, compiler, connection):
|
||||||
expressions = []
|
expressions = []
|
||||||
expression_params = []
|
expression_params = []
|
||||||
|
@ -721,6 +750,30 @@ class Value(Expression):
|
||||||
def get_group_by_cols(self, alias=None):
|
def get_group_by_cols(self, alias=None):
|
||||||
return []
|
return []
|
||||||
|
|
||||||
|
def _resolve_output_field(self):
|
||||||
|
if isinstance(self.value, str):
|
||||||
|
return fields.CharField()
|
||||||
|
if isinstance(self.value, bool):
|
||||||
|
return fields.BooleanField()
|
||||||
|
if isinstance(self.value, int):
|
||||||
|
return fields.IntegerField()
|
||||||
|
if isinstance(self.value, float):
|
||||||
|
return fields.FloatField()
|
||||||
|
if isinstance(self.value, datetime.datetime):
|
||||||
|
return fields.DateTimeField()
|
||||||
|
if isinstance(self.value, datetime.date):
|
||||||
|
return fields.DateField()
|
||||||
|
if isinstance(self.value, datetime.time):
|
||||||
|
return fields.TimeField()
|
||||||
|
if isinstance(self.value, datetime.timedelta):
|
||||||
|
return fields.DurationField()
|
||||||
|
if isinstance(self.value, Decimal):
|
||||||
|
return fields.DecimalField()
|
||||||
|
if isinstance(self.value, bytes):
|
||||||
|
return fields.BinaryField()
|
||||||
|
if isinstance(self.value, UUID):
|
||||||
|
return fields.UUIDField()
|
||||||
|
|
||||||
|
|
||||||
class RawSQL(Expression):
|
class RawSQL(Expression):
|
||||||
def __init__(self, sql, params, output_field=None):
|
def __init__(self, sql, params, output_field=None):
|
||||||
|
@ -1177,7 +1230,6 @@ class OrderBy(BaseExpression):
|
||||||
copy.expression = Case(
|
copy.expression = Case(
|
||||||
When(self.expression, then=True),
|
When(self.expression, then=True),
|
||||||
default=False,
|
default=False,
|
||||||
output_field=fields.BooleanField(),
|
|
||||||
)
|
)
|
||||||
return copy.as_sql(compiler, connection)
|
return copy.as_sql(compiler, connection)
|
||||||
return self.as_sql(compiler, connection)
|
return self.as_sql(compiler, connection)
|
||||||
|
|
|
@ -6,7 +6,7 @@ from copy import copy
|
||||||
from django.core.exceptions import EmptyResultSet
|
from django.core.exceptions import EmptyResultSet
|
||||||
from django.db.models.expressions import Case, Exists, Func, Value, When
|
from django.db.models.expressions import Case, Exists, Func, Value, When
|
||||||
from django.db.models.fields import (
|
from django.db.models.fields import (
|
||||||
BooleanField, CharField, DateTimeField, Field, IntegerField, UUIDField,
|
CharField, DateTimeField, Field, IntegerField, UUIDField,
|
||||||
)
|
)
|
||||||
from django.db.models.query_utils import RegisterLookupMixin
|
from django.db.models.query_utils import RegisterLookupMixin
|
||||||
from django.utils.datastructures import OrderedSet
|
from django.utils.datastructures import OrderedSet
|
||||||
|
@ -123,7 +123,7 @@ class Lookup:
|
||||||
exprs = []
|
exprs = []
|
||||||
for expr in (self.lhs, self.rhs):
|
for expr in (self.lhs, self.rhs):
|
||||||
if isinstance(expr, Exists):
|
if isinstance(expr, Exists):
|
||||||
expr = Case(When(expr, then=True), default=False, output_field=BooleanField())
|
expr = Case(When(expr, then=True), default=False)
|
||||||
wrapped = True
|
wrapped = True
|
||||||
exprs.append(expr)
|
exprs.append(expr)
|
||||||
lookup = type(self)(*exprs) if wrapped else self
|
lookup = type(self)(*exprs) if wrapped else self
|
||||||
|
|
|
@ -484,7 +484,15 @@ The ``output_field`` argument should be a model field instance, like
|
||||||
after it's retrieved from the database. Usually no arguments are needed when
|
after it's retrieved from the database. Usually no arguments are needed when
|
||||||
instantiating the model field as any arguments relating to data validation
|
instantiating the model field as any arguments relating to data validation
|
||||||
(``max_length``, ``max_digits``, etc.) will not be enforced on the expression's
|
(``max_length``, ``max_digits``, etc.) will not be enforced on the expression's
|
||||||
output value.
|
output value. If no ``output_field`` is specified it will be tentatively
|
||||||
|
inferred from the :py:class:`type` of the provided ``value``, if possible. For
|
||||||
|
example, passing an instance of :py:class:`datetime.datetime` as ``value``
|
||||||
|
would default ``output_field`` to :class:`~django.db.models.DateTimeField`.
|
||||||
|
|
||||||
|
.. versionchanged:: 3.2
|
||||||
|
|
||||||
|
Support for inferring a default ``output_field`` from the type of ``value``
|
||||||
|
was added.
|
||||||
|
|
||||||
``ExpressionWrapper()`` expressions
|
``ExpressionWrapper()`` expressions
|
||||||
-----------------------------------
|
-----------------------------------
|
||||||
|
|
|
@ -233,6 +233,15 @@ Models
|
||||||
* The ``of`` argument of :meth:`.QuerySet.select_for_update()` is now allowed
|
* The ``of`` argument of :meth:`.QuerySet.select_for_update()` is now allowed
|
||||||
on MySQL 8.0.1+.
|
on MySQL 8.0.1+.
|
||||||
|
|
||||||
|
* :class:`Value() <django.db.models.Value>` expression now
|
||||||
|
automatically resolves its ``output_field`` to the appropriate
|
||||||
|
:class:`Field <django.db.models.Field>` subclass based on the type of
|
||||||
|
it's provided ``value`` for :py:class:`bool`, :py:class:`bytes`,
|
||||||
|
:py:class:`float`, :py:class:`int`, :py:class:`str`,
|
||||||
|
:py:class:`datetime.date`, :py:class:`datetime.datetime`,
|
||||||
|
:py:class:`datetime.time`, :py:class:`datetime.timedelta`,
|
||||||
|
:py:class:`decimal.Decimal`, and :py:class:`uuid.UUID` instances.
|
||||||
|
|
||||||
Requests and Responses
|
Requests and Responses
|
||||||
~~~~~~~~~~~~~~~~~~~~~~
|
~~~~~~~~~~~~~~~~~~~~~~
|
||||||
|
|
||||||
|
|
|
@ -848,10 +848,6 @@ class AggregateTestCase(TestCase):
|
||||||
book = Book.objects.annotate(val=Max(2, output_field=IntegerField())).first()
|
book = Book.objects.annotate(val=Max(2, output_field=IntegerField())).first()
|
||||||
self.assertEqual(book.val, 2)
|
self.assertEqual(book.val, 2)
|
||||||
|
|
||||||
def test_missing_output_field_raises_error(self):
|
|
||||||
with self.assertRaisesMessage(FieldError, 'Cannot resolve expression type, unknown output_field'):
|
|
||||||
Book.objects.annotate(val=Max(2)).first()
|
|
||||||
|
|
||||||
def test_annotation_expressions(self):
|
def test_annotation_expressions(self):
|
||||||
authors = Author.objects.annotate(combined_ages=Sum(F('age') + F('friends__age'))).order_by('name')
|
authors = Author.objects.annotate(combined_ages=Sum(F('age') + F('friends__age'))).order_by('name')
|
||||||
authors2 = Author.objects.annotate(combined_ages=Sum('age') + Sum('friends__age')).order_by('name')
|
authors2 = Author.objects.annotate(combined_ages=Sum('age') + Sum('friends__age')).order_by('name')
|
||||||
|
@ -893,7 +889,7 @@ class AggregateTestCase(TestCase):
|
||||||
|
|
||||||
def test_combine_different_types(self):
|
def test_combine_different_types(self):
|
||||||
msg = (
|
msg = (
|
||||||
'Expression contains mixed types: FloatField, IntegerField. '
|
'Expression contains mixed types: FloatField, DecimalField. '
|
||||||
'You must set output_field.'
|
'You must set output_field.'
|
||||||
)
|
)
|
||||||
qs = Book.objects.annotate(sums=Sum('rating') + Sum('pages') + Sum('price'))
|
qs = Book.objects.annotate(sums=Sum('rating') + Sum('pages') + Sum('price'))
|
||||||
|
|
|
@ -388,7 +388,7 @@ class AggregationTests(TestCase):
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_annotated_conditional_aggregate(self):
|
def test_annotated_conditional_aggregate(self):
|
||||||
annotated_qs = Book.objects.annotate(discount_price=F('price') * 0.75)
|
annotated_qs = Book.objects.annotate(discount_price=F('price') * Decimal('0.75'))
|
||||||
self.assertAlmostEqual(
|
self.assertAlmostEqual(
|
||||||
annotated_qs.aggregate(test=Avg(Case(
|
annotated_qs.aggregate(test=Avg(Case(
|
||||||
When(pages__lt=400, then='discount_price'),
|
When(pages__lt=400, then='discount_price'),
|
||||||
|
|
|
@ -3,15 +3,17 @@ import pickle
|
||||||
import unittest
|
import unittest
|
||||||
import uuid
|
import uuid
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
|
from decimal import Decimal
|
||||||
from unittest import mock
|
from unittest import mock
|
||||||
|
|
||||||
from django.core.exceptions import FieldError
|
from django.core.exceptions import FieldError
|
||||||
from django.db import DatabaseError, NotSupportedError, connection
|
from django.db import DatabaseError, NotSupportedError, connection
|
||||||
from django.db.models import (
|
from django.db.models import (
|
||||||
Avg, BooleanField, Case, CharField, Count, DateField, DateTimeField,
|
Avg, BinaryField, BooleanField, Case, CharField, Count, DateField,
|
||||||
DurationField, Exists, Expression, ExpressionList, ExpressionWrapper, F,
|
DateTimeField, DecimalField, DurationField, Exists, Expression,
|
||||||
Func, IntegerField, Max, Min, Model, OrderBy, OuterRef, Q, StdDev,
|
ExpressionList, ExpressionWrapper, F, FloatField, Func, IntegerField, Max,
|
||||||
Subquery, Sum, TimeField, UUIDField, Value, Variance, When,
|
Min, Model, OrderBy, OuterRef, Q, StdDev, Subquery, Sum, TimeField,
|
||||||
|
UUIDField, Value, Variance, When,
|
||||||
)
|
)
|
||||||
from django.db.models.expressions import Col, Combinable, Random, RawSQL, Ref
|
from django.db.models.expressions import Col, Combinable, Random, RawSQL, Ref
|
||||||
from django.db.models.functions import (
|
from django.db.models.functions import (
|
||||||
|
@ -1711,6 +1713,30 @@ class ValueTests(TestCase):
|
||||||
value = Value('foo', output_field=CharField())
|
value = Value('foo', output_field=CharField())
|
||||||
self.assertEqual(value.as_sql(compiler, connection), ('%s', ['foo']))
|
self.assertEqual(value.as_sql(compiler, connection), ('%s', ['foo']))
|
||||||
|
|
||||||
|
def test_resolve_output_field(self):
|
||||||
|
value_types = [
|
||||||
|
('str', CharField),
|
||||||
|
(True, BooleanField),
|
||||||
|
(42, IntegerField),
|
||||||
|
(3.14, FloatField),
|
||||||
|
(datetime.date(2019, 5, 15), DateField),
|
||||||
|
(datetime.datetime(2019, 5, 15), DateTimeField),
|
||||||
|
(datetime.time(3, 16), TimeField),
|
||||||
|
(datetime.timedelta(1), DurationField),
|
||||||
|
(Decimal('3.14'), DecimalField),
|
||||||
|
(b'', BinaryField),
|
||||||
|
(uuid.uuid4(), UUIDField),
|
||||||
|
]
|
||||||
|
for value, ouput_field_type in value_types:
|
||||||
|
with self.subTest(type=type(value)):
|
||||||
|
expr = Value(value)
|
||||||
|
self.assertIsInstance(expr.output_field, ouput_field_type)
|
||||||
|
|
||||||
|
def test_resolve_output_field_failure(self):
|
||||||
|
msg = 'Cannot resolve expression type, unknown output_field'
|
||||||
|
with self.assertRaisesMessage(FieldError, msg):
|
||||||
|
Value(object()).output_field
|
||||||
|
|
||||||
|
|
||||||
class FieldTransformTests(TestCase):
|
class FieldTransformTests(TestCase):
|
||||||
|
|
||||||
|
@ -1848,7 +1874,9 @@ class ExpressionWrapperTests(SimpleTestCase):
|
||||||
self.assertEqual(expr.get_group_by_cols(alias=None), [])
|
self.assertEqual(expr.get_group_by_cols(alias=None), [])
|
||||||
|
|
||||||
def test_non_empty_group_by(self):
|
def test_non_empty_group_by(self):
|
||||||
expr = ExpressionWrapper(Lower(Value('f')), output_field=IntegerField())
|
value = Value('f')
|
||||||
|
value.output_field = None
|
||||||
|
expr = ExpressionWrapper(Lower(value), output_field=IntegerField())
|
||||||
group_by_cols = expr.get_group_by_cols(alias=None)
|
group_by_cols = expr.get_group_by_cols(alias=None)
|
||||||
self.assertEqual(group_by_cols, [expr.expression])
|
self.assertEqual(group_by_cols, [expr.expression])
|
||||||
self.assertEqual(group_by_cols[0].output_field, expr.output_field)
|
self.assertEqual(group_by_cols[0].output_field, expr.output_field)
|
||||||
|
|
|
@ -1,7 +1,6 @@
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from operator import attrgetter
|
from operator import attrgetter
|
||||||
|
|
||||||
from django.core.exceptions import FieldError
|
|
||||||
from django.db.models import (
|
from django.db.models import (
|
||||||
CharField, DateTimeField, F, Max, OuterRef, Subquery, Value,
|
CharField, DateTimeField, F, Max, OuterRef, Subquery, Value,
|
||||||
)
|
)
|
||||||
|
@ -439,17 +438,6 @@ class OrderingTests(TestCase):
|
||||||
qs = Article.objects.order_by(Value('1', output_field=CharField()), '-headline')
|
qs = Article.objects.order_by(Value('1', output_field=CharField()), '-headline')
|
||||||
self.assertSequenceEqual(qs, [self.a4, self.a3, self.a2, self.a1])
|
self.assertSequenceEqual(qs, [self.a4, self.a3, self.a2, self.a1])
|
||||||
|
|
||||||
def test_order_by_constant_value_without_output_field(self):
|
|
||||||
msg = 'Cannot resolve expression type, unknown output_field'
|
|
||||||
qs = Article.objects.annotate(constant=Value('1')).order_by('constant')
|
|
||||||
for ordered_qs in (
|
|
||||||
qs,
|
|
||||||
qs.values('headline'),
|
|
||||||
Article.objects.order_by(Value('1')),
|
|
||||||
):
|
|
||||||
with self.subTest(ordered_qs=ordered_qs), self.assertRaisesMessage(FieldError, msg):
|
|
||||||
ordered_qs.first()
|
|
||||||
|
|
||||||
def test_related_ordering_duplicate_table_reference(self):
|
def test_related_ordering_duplicate_table_reference(self):
|
||||||
"""
|
"""
|
||||||
An ordering referencing a model with an ordering referencing a model
|
An ordering referencing a model with an ordering referencing a model
|
||||||
|
|
Loading…
Reference in New Issue