Fixed #33397 -- Corrected resolving output_field for DateField/DateTimeField/TimeField/DurationFields.
This includes refactoring of CombinedExpression._resolve_output_field() so it no longer uses the behavior inherited from Expression of guessing same output type if argument types match, and instead we explicitly define the output type of all supported operations. This also makes nonsensical operations involving dates (e.g. date + date) raise a FieldError, and adds support for automatically inferring output_field for cases such as: * date - date * date + duration * date - duration * time + duration * time - time
This commit is contained in:
parent
1efea11808
commit
40b8a6174f
|
@ -10,7 +10,7 @@ from django.db.models import (
|
||||||
TextField,
|
TextField,
|
||||||
Value,
|
Value,
|
||||||
)
|
)
|
||||||
from django.db.models.expressions import CombinedExpression
|
from django.db.models.expressions import CombinedExpression, register_combinable_fields
|
||||||
from django.db.models.functions import Cast, Coalesce
|
from django.db.models.functions import Cast, Coalesce
|
||||||
|
|
||||||
|
|
||||||
|
@ -79,6 +79,11 @@ class SearchVectorCombinable:
|
||||||
return CombinedSearchVector(self, connector, other, self.config)
|
return CombinedSearchVector(self, connector, other, self.config)
|
||||||
|
|
||||||
|
|
||||||
|
register_combinable_fields(
|
||||||
|
SearchVectorField, SearchVectorCombinable.ADD, SearchVectorField, SearchVectorField
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class SearchVector(SearchVectorCombinable, Func):
|
class SearchVector(SearchVectorCombinable, Func):
|
||||||
function = "to_tsvector"
|
function = "to_tsvector"
|
||||||
arg_joiner = " || ' ' || "
|
arg_joiner = " || ' ' || "
|
||||||
|
|
|
@ -310,18 +310,18 @@ class BaseExpression:
|
||||||
|
|
||||||
def _resolve_output_field(self):
|
def _resolve_output_field(self):
|
||||||
"""
|
"""
|
||||||
Attempt to infer the output type of the expression. If the output
|
Attempt to infer the output type of the expression.
|
||||||
fields of all source fields match then, simply infer the same type
|
|
||||||
here. This isn't always correct, but it makes sense most of the time.
|
|
||||||
|
|
||||||
Consider the difference between `2 + 2` and `2 / 3`. Inferring
|
As a guess, if the output fields of all source fields match then simply
|
||||||
the type here is a convenience for the common case. The user should
|
infer the same type here.
|
||||||
supply their own output_field with more complex computations.
|
|
||||||
|
|
||||||
If a source's output field resolves to None, exclude it from this check.
|
If a source's output field resolves to None, exclude it from this check.
|
||||||
If all sources are None, then an error is raised higher up the stack in
|
If all sources are None, then an error is raised higher up the stack in
|
||||||
the output_field property.
|
the output_field property.
|
||||||
"""
|
"""
|
||||||
|
# This guess is mostly a bad idea, but there is quite a lot of code
|
||||||
|
# (especially 3rd party Func subclasses) that depend on it, we'd need a
|
||||||
|
# deprecation path to fix it.
|
||||||
sources_iter = (
|
sources_iter = (
|
||||||
source for source in self.get_source_fields() if source is not None
|
source for source in self.get_source_fields() if source is not None
|
||||||
)
|
)
|
||||||
|
@ -467,6 +467,13 @@ class Expression(BaseExpression, Combinable):
|
||||||
|
|
||||||
|
|
||||||
# Type inference for CombinedExpression.output_field.
|
# Type inference for CombinedExpression.output_field.
|
||||||
|
# Missing items will result in FieldError, by design.
|
||||||
|
#
|
||||||
|
# The current approach for NULL is based on lowest common denominator behavior
|
||||||
|
# i.e. if one of the supported databases is raising an error (rather than
|
||||||
|
# return NULL) for `val <op> NULL`, then Django raises FieldError.
|
||||||
|
NoneType = type(None)
|
||||||
|
|
||||||
_connector_combinations = [
|
_connector_combinations = [
|
||||||
# Numeric operations - operands of same type.
|
# Numeric operations - operands of same type.
|
||||||
{
|
{
|
||||||
|
@ -482,6 +489,8 @@ _connector_combinations = [
|
||||||
# Behavior for DIV with integer arguments follows Postgres/SQLite,
|
# Behavior for DIV with integer arguments follows Postgres/SQLite,
|
||||||
# not MySQL/Oracle.
|
# not MySQL/Oracle.
|
||||||
Combinable.DIV,
|
Combinable.DIV,
|
||||||
|
Combinable.MOD,
|
||||||
|
Combinable.POW,
|
||||||
)
|
)
|
||||||
},
|
},
|
||||||
# Numeric operations - operands of different type.
|
# Numeric operations - operands of different type.
|
||||||
|
@ -499,6 +508,66 @@ _connector_combinations = [
|
||||||
Combinable.DIV,
|
Combinable.DIV,
|
||||||
)
|
)
|
||||||
},
|
},
|
||||||
|
# Bitwise operators.
|
||||||
|
{
|
||||||
|
connector: [
|
||||||
|
(fields.IntegerField, fields.IntegerField, fields.IntegerField),
|
||||||
|
]
|
||||||
|
for connector in (
|
||||||
|
Combinable.BITAND,
|
||||||
|
Combinable.BITOR,
|
||||||
|
Combinable.BITLEFTSHIFT,
|
||||||
|
Combinable.BITRIGHTSHIFT,
|
||||||
|
Combinable.BITXOR,
|
||||||
|
)
|
||||||
|
},
|
||||||
|
# Numeric with NULL.
|
||||||
|
{
|
||||||
|
connector: [
|
||||||
|
(field_type, NoneType, field_type),
|
||||||
|
(NoneType, field_type, field_type),
|
||||||
|
]
|
||||||
|
for connector in (
|
||||||
|
Combinable.ADD,
|
||||||
|
Combinable.SUB,
|
||||||
|
Combinable.MUL,
|
||||||
|
Combinable.DIV,
|
||||||
|
Combinable.MOD,
|
||||||
|
Combinable.POW,
|
||||||
|
)
|
||||||
|
for field_type in (fields.IntegerField, fields.DecimalField, fields.FloatField)
|
||||||
|
},
|
||||||
|
# Date/DateTimeField/DurationField/TimeField.
|
||||||
|
{
|
||||||
|
Combinable.ADD: [
|
||||||
|
# Date/DateTimeField.
|
||||||
|
(fields.DateField, fields.DurationField, fields.DateTimeField),
|
||||||
|
(fields.DateTimeField, fields.DurationField, fields.DateTimeField),
|
||||||
|
(fields.DurationField, fields.DateField, fields.DateTimeField),
|
||||||
|
(fields.DurationField, fields.DateTimeField, fields.DateTimeField),
|
||||||
|
# DurationField.
|
||||||
|
(fields.DurationField, fields.DurationField, fields.DurationField),
|
||||||
|
# TimeField.
|
||||||
|
(fields.TimeField, fields.DurationField, fields.TimeField),
|
||||||
|
(fields.DurationField, fields.TimeField, fields.TimeField),
|
||||||
|
],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Combinable.SUB: [
|
||||||
|
# Date/DateTimeField.
|
||||||
|
(fields.DateField, fields.DurationField, fields.DateTimeField),
|
||||||
|
(fields.DateTimeField, fields.DurationField, fields.DateTimeField),
|
||||||
|
(fields.DateField, fields.DateField, fields.DurationField),
|
||||||
|
(fields.DateField, fields.DateTimeField, fields.DurationField),
|
||||||
|
(fields.DateTimeField, fields.DateField, fields.DurationField),
|
||||||
|
(fields.DateTimeField, fields.DateTimeField, fields.DurationField),
|
||||||
|
# DurationField.
|
||||||
|
(fields.DurationField, fields.DurationField, fields.DurationField),
|
||||||
|
# TimeField.
|
||||||
|
(fields.TimeField, fields.DurationField, fields.TimeField),
|
||||||
|
(fields.TimeField, fields.TimeField, fields.DurationField),
|
||||||
|
],
|
||||||
|
},
|
||||||
]
|
]
|
||||||
|
|
||||||
_connector_combinators = defaultdict(list)
|
_connector_combinators = defaultdict(list)
|
||||||
|
@ -552,17 +621,21 @@ class CombinedExpression(SQLiteNumericMixin, Expression):
|
||||||
self.lhs, self.rhs = exprs
|
self.lhs, self.rhs = exprs
|
||||||
|
|
||||||
def _resolve_output_field(self):
|
def _resolve_output_field(self):
|
||||||
try:
|
# We avoid using super() here for reasons given in
|
||||||
return super()._resolve_output_field()
|
# Expression._resolve_output_field()
|
||||||
except FieldError:
|
combined_type = _resolve_combined_type(
|
||||||
combined_type = _resolve_combined_type(
|
self.connector,
|
||||||
self.connector,
|
type(self.lhs._output_field_or_none),
|
||||||
type(self.lhs.output_field),
|
type(self.rhs._output_field_or_none),
|
||||||
type(self.rhs.output_field),
|
)
|
||||||
|
if combined_type is None:
|
||||||
|
raise FieldError(
|
||||||
|
f"Cannot infer type of {self.connector!r} expression involving these "
|
||||||
|
f"types: {self.lhs.output_field.__class__.__name__}, "
|
||||||
|
f"{self.rhs.output_field.__class__.__name__}. You must set "
|
||||||
|
f"output_field."
|
||||||
)
|
)
|
||||||
if combined_type is None:
|
return combined_type()
|
||||||
raise
|
|
||||||
return combined_type()
|
|
||||||
|
|
||||||
def as_sql(self, compiler, connection):
|
def as_sql(self, compiler, connection):
|
||||||
expressions = []
|
expressions = []
|
||||||
|
|
|
@ -1126,8 +1126,8 @@ class AggregateTestCase(TestCase):
|
||||||
|
|
||||||
def test_combine_different_types(self):
|
def test_combine_different_types(self):
|
||||||
msg = (
|
msg = (
|
||||||
"Expression contains mixed types: FloatField, DecimalField. "
|
"Cannot infer type of '+' expression involving these types: FloatField, "
|
||||||
"You must set output_field."
|
"DecimalField. You must set output_field."
|
||||||
)
|
)
|
||||||
qs = Book.objects.annotate(sums=Sum("rating") + Sum("pages") + Sum("price"))
|
qs = Book.objects.annotate(sums=Sum("rating") + Sum("pages") + Sum("price"))
|
||||||
with self.assertRaisesMessage(FieldError, msg):
|
with self.assertRaisesMessage(FieldError, msg):
|
||||||
|
|
|
@ -1736,6 +1736,17 @@ class FTimeDeltaTests(TestCase):
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def test_datetime_and_duration_field_addition_with_annotate_and_no_output_field(
|
||||||
|
self,
|
||||||
|
):
|
||||||
|
test_set = Experiment.objects.annotate(
|
||||||
|
estimated_end=F("start") + F("estimated_time")
|
||||||
|
)
|
||||||
|
self.assertEqual(
|
||||||
|
[e.estimated_end for e in test_set],
|
||||||
|
[e.start + e.estimated_time for e in test_set],
|
||||||
|
)
|
||||||
|
|
||||||
@skipUnlessDBFeature("supports_temporal_subtraction")
|
@skipUnlessDBFeature("supports_temporal_subtraction")
|
||||||
def test_datetime_subtraction_with_annotate_and_no_output_field(self):
|
def test_datetime_subtraction_with_annotate_and_no_output_field(self):
|
||||||
test_set = Experiment.objects.annotate(
|
test_set = Experiment.objects.annotate(
|
||||||
|
@ -2438,8 +2449,10 @@ class CombinedExpressionTests(SimpleTestCase):
|
||||||
(null, Combinable.ADD, DateTimeField),
|
(null, Combinable.ADD, DateTimeField),
|
||||||
(DateField, Combinable.SUB, null),
|
(DateField, Combinable.SUB, null),
|
||||||
]
|
]
|
||||||
msg = "Expression contains mixed types: "
|
|
||||||
for lhs, connector, rhs in tests:
|
for lhs, connector, rhs in tests:
|
||||||
|
msg = (
|
||||||
|
f"Cannot infer type of {connector!r} expression involving these types: "
|
||||||
|
)
|
||||||
with self.subTest(lhs=lhs, connector=connector, rhs=rhs):
|
with self.subTest(lhs=lhs, connector=connector, rhs=rhs):
|
||||||
expr = CombinedExpression(
|
expr = CombinedExpression(
|
||||||
Expression(lhs()),
|
Expression(lhs()),
|
||||||
|
@ -2452,16 +2465,34 @@ class CombinedExpressionTests(SimpleTestCase):
|
||||||
def test_resolve_output_field_dates(self):
|
def test_resolve_output_field_dates(self):
|
||||||
tests = [
|
tests = [
|
||||||
# Add - same type.
|
# Add - same type.
|
||||||
|
(DateField, Combinable.ADD, DateField, FieldError),
|
||||||
|
(DateTimeField, Combinable.ADD, DateTimeField, FieldError),
|
||||||
|
(TimeField, Combinable.ADD, TimeField, FieldError),
|
||||||
(DurationField, Combinable.ADD, DurationField, DurationField),
|
(DurationField, Combinable.ADD, DurationField, DurationField),
|
||||||
|
# Add - different type.
|
||||||
|
(DateField, Combinable.ADD, DurationField, DateTimeField),
|
||||||
|
(DateTimeField, Combinable.ADD, DurationField, DateTimeField),
|
||||||
|
(TimeField, Combinable.ADD, DurationField, TimeField),
|
||||||
|
(DurationField, Combinable.ADD, DateField, DateTimeField),
|
||||||
|
(DurationField, Combinable.ADD, DateTimeField, DateTimeField),
|
||||||
|
(DurationField, Combinable.ADD, TimeField, TimeField),
|
||||||
# Subtract - same type.
|
# Subtract - same type.
|
||||||
|
(DateField, Combinable.SUB, DateField, DurationField),
|
||||||
|
(DateTimeField, Combinable.SUB, DateTimeField, DurationField),
|
||||||
|
(TimeField, Combinable.SUB, TimeField, DurationField),
|
||||||
(DurationField, Combinable.SUB, DurationField, DurationField),
|
(DurationField, Combinable.SUB, DurationField, DurationField),
|
||||||
# Subtract - different type.
|
# Subtract - different type.
|
||||||
|
(DateField, Combinable.SUB, DurationField, DateTimeField),
|
||||||
|
(DateTimeField, Combinable.SUB, DurationField, DateTimeField),
|
||||||
|
(TimeField, Combinable.SUB, DurationField, TimeField),
|
||||||
(DurationField, Combinable.SUB, DateField, FieldError),
|
(DurationField, Combinable.SUB, DateField, FieldError),
|
||||||
(DurationField, Combinable.SUB, DateTimeField, FieldError),
|
(DurationField, Combinable.SUB, DateTimeField, FieldError),
|
||||||
(DurationField, Combinable.SUB, DateTimeField, FieldError),
|
(DurationField, Combinable.SUB, DateTimeField, FieldError),
|
||||||
]
|
]
|
||||||
msg = "Expression contains mixed types: "
|
|
||||||
for lhs, connector, rhs, combined in tests:
|
for lhs, connector, rhs, combined in tests:
|
||||||
|
msg = (
|
||||||
|
f"Cannot infer type of {connector!r} expression involving these types: "
|
||||||
|
)
|
||||||
with self.subTest(lhs=lhs, connector=connector, rhs=rhs, combined=combined):
|
with self.subTest(lhs=lhs, connector=connector, rhs=rhs, combined=combined):
|
||||||
expr = CombinedExpression(
|
expr = CombinedExpression(
|
||||||
Expression(lhs()),
|
Expression(lhs()),
|
||||||
|
@ -2477,8 +2508,8 @@ class CombinedExpressionTests(SimpleTestCase):
|
||||||
def test_mixed_char_date_with_annotate(self):
|
def test_mixed_char_date_with_annotate(self):
|
||||||
queryset = Experiment.objects.annotate(nonsense=F("name") + F("assigned"))
|
queryset = Experiment.objects.annotate(nonsense=F("name") + F("assigned"))
|
||||||
msg = (
|
msg = (
|
||||||
"Expression contains mixed types: CharField, DateField. You must set "
|
"Cannot infer type of '+' expression involving these types: CharField, "
|
||||||
"output_field."
|
"DateField. You must set output_field."
|
||||||
)
|
)
|
||||||
with self.assertRaisesMessage(FieldError, msg):
|
with self.assertRaisesMessage(FieldError, msg):
|
||||||
list(queryset)
|
list(queryset)
|
||||||
|
|
Loading…
Reference in New Issue