Refs #33397 -- Added register_combinable_fields().

This commit is contained in:
Luke Plant 2022-03-31 08:10:22 +02:00 committed by Mariusz Felisiak
parent d7eb500338
commit 1efea11808
1 changed files with 55 additions and 10 deletions

View File

@ -2,6 +2,7 @@ import copy
import datetime import datetime
import functools import functools
import inspect import inspect
from collections import defaultdict
from decimal import Decimal from decimal import Decimal
from uuid import UUID from uuid import UUID
@ -465,16 +466,60 @@ class Expression(BaseExpression, Combinable):
return hash(self.identity) return hash(self.identity)
_connector_combinators = { # Type inference for CombinedExpression.output_field.
_connector_combinations = [
# Numeric operations - operands of same type.
{
connector: [ connector: [
(fields.IntegerField, fields.IntegerField, fields.IntegerField), (fields.IntegerField, fields.IntegerField, fields.IntegerField),
(fields.FloatField, fields.FloatField, fields.FloatField),
(fields.DecimalField, fields.DecimalField, fields.DecimalField),
]
for connector in (
Combinable.ADD,
Combinable.SUB,
Combinable.MUL,
# Behavior for DIV with integer arguments follows Postgres/SQLite,
# not MySQL/Oracle.
Combinable.DIV,
)
},
# Numeric operations - operands of different type.
{
connector: [
(fields.IntegerField, fields.DecimalField, fields.DecimalField), (fields.IntegerField, fields.DecimalField, fields.DecimalField),
(fields.DecimalField, fields.IntegerField, fields.DecimalField), (fields.DecimalField, fields.IntegerField, fields.DecimalField),
(fields.IntegerField, fields.FloatField, fields.FloatField), (fields.IntegerField, fields.FloatField, fields.FloatField),
(fields.FloatField, fields.IntegerField, fields.FloatField), (fields.FloatField, fields.IntegerField, fields.FloatField),
] ]
for connector in (Combinable.ADD, Combinable.SUB, Combinable.MUL, Combinable.DIV) for connector in (
} Combinable.ADD,
Combinable.SUB,
Combinable.MUL,
Combinable.DIV,
)
},
]
_connector_combinators = defaultdict(list)
def register_combinable_fields(lhs, connector, rhs, result):
"""
Register combinable types:
lhs <connector> rhs -> result
e.g.
register_combinable_fields(
IntegerField, Combinable.ADD, FloatField, FloatField
)
"""
_connector_combinators[connector].append((lhs, rhs, result))
for d in _connector_combinations:
for connector, field_types in d.items():
for lhs, rhs, result in field_types:
register_combinable_fields(lhs, connector, rhs, result)
@functools.lru_cache(maxsize=128) @functools.lru_cache(maxsize=128)