Refs #33397 -- Added register_combinable_fields().

This commit is contained in:
Luke Plant 2022-03-31 08:10:22 +02:00 committed by Mariusz Felisiak
parent d7eb500338
commit 1efea11808
1 changed files with 55 additions and 10 deletions

View File

@ -2,6 +2,7 @@ import copy
import datetime
import functools
import inspect
from collections import defaultdict
from decimal import Decimal
from uuid import UUID
@ -465,16 +466,60 @@ class Expression(BaseExpression, Combinable):
return hash(self.identity)
_connector_combinators = {
connector: [
(fields.IntegerField, fields.IntegerField, fields.IntegerField),
(fields.IntegerField, fields.DecimalField, fields.DecimalField),
(fields.DecimalField, fields.IntegerField, fields.DecimalField),
(fields.IntegerField, fields.FloatField, fields.FloatField),
(fields.FloatField, fields.IntegerField, fields.FloatField),
]
for connector in (Combinable.ADD, Combinable.SUB, Combinable.MUL, Combinable.DIV)
}
# Type inference for CombinedExpression.output_field.
_connector_combinations = [
# Numeric operations - operands of same type.
{
connector: [
(fields.IntegerField, fields.IntegerField, fields.IntegerField),
(fields.FloatField, fields.FloatField, fields.FloatField),
(fields.DecimalField, fields.DecimalField, fields.DecimalField),
]
for connector in (
Combinable.ADD,
Combinable.SUB,
Combinable.MUL,
# Behavior for DIV with integer arguments follows Postgres/SQLite,
# not MySQL/Oracle.
Combinable.DIV,
)
},
# Numeric operations - operands of different type.
{
connector: [
(fields.IntegerField, fields.DecimalField, fields.DecimalField),
(fields.DecimalField, fields.IntegerField, fields.DecimalField),
(fields.IntegerField, fields.FloatField, fields.FloatField),
(fields.FloatField, fields.IntegerField, fields.FloatField),
]
for connector in (
Combinable.ADD,
Combinable.SUB,
Combinable.MUL,
Combinable.DIV,
)
},
]
_connector_combinators = defaultdict(list)
def register_combinable_fields(lhs, connector, rhs, result):
"""
Register combinable types:
lhs <connector> rhs -> result
e.g.
register_combinable_fields(
IntegerField, Combinable.ADD, FloatField, FloatField
)
"""
_connector_combinators[connector].append((lhs, rhs, result))
for d in _connector_combinations:
for connector, field_types in d.items():
for lhs, rhs, result in field_types:
register_combinable_fields(lhs, connector, rhs, result)
@functools.lru_cache(maxsize=128)