mirror of https://github.com/django/django.git
Refs #28643 -- Moved db function mixins to a separate module.
This commit is contained in:
parent
7f1577d1ef
commit
3d5e0f8394
|
@ -1,56 +1,35 @@
|
||||||
import math
|
import math
|
||||||
import sys
|
|
||||||
|
|
||||||
from django.db.models.expressions import Func
|
from django.db.models.expressions import Func
|
||||||
from django.db.models.fields import DecimalField, FloatField, IntegerField
|
from django.db.models.fields import FloatField, IntegerField
|
||||||
from django.db.models.functions import Cast
|
from django.db.models.functions import Cast
|
||||||
|
from django.db.models.functions.mixins import (
|
||||||
|
FixDecimalInputMixin, NumericOutputFieldMixin,
|
||||||
|
)
|
||||||
from django.db.models.lookups import Transform
|
from django.db.models.lookups import Transform
|
||||||
|
|
||||||
|
|
||||||
class DecimalInputMixin:
|
|
||||||
|
|
||||||
def as_postgresql(self, compiler, connection, **extra_context):
|
|
||||||
# Cast FloatField to DecimalField as PostgreSQL doesn't support the
|
|
||||||
# following function signatures:
|
|
||||||
# - LOG(double, double)
|
|
||||||
# - MOD(double, double)
|
|
||||||
output_field = DecimalField(decimal_places=sys.float_info.dig, max_digits=1000)
|
|
||||||
clone = self.copy()
|
|
||||||
clone.set_source_expressions([
|
|
||||||
Cast(expression, output_field) if isinstance(expression.output_field, FloatField)
|
|
||||||
else expression for expression in self.get_source_expressions()
|
|
||||||
])
|
|
||||||
return clone.as_sql(compiler, connection, **extra_context)
|
|
||||||
|
|
||||||
|
|
||||||
class OutputFieldMixin:
|
|
||||||
|
|
||||||
def _resolve_output_field(self):
|
|
||||||
has_decimals = any(isinstance(s.output_field, DecimalField) for s in self.get_source_expressions())
|
|
||||||
return DecimalField() if has_decimals else FloatField()
|
|
||||||
|
|
||||||
|
|
||||||
class Abs(Transform):
|
class Abs(Transform):
|
||||||
function = 'ABS'
|
function = 'ABS'
|
||||||
lookup_name = 'abs'
|
lookup_name = 'abs'
|
||||||
|
|
||||||
|
|
||||||
class ACos(OutputFieldMixin, Transform):
|
class ACos(NumericOutputFieldMixin, Transform):
|
||||||
function = 'ACOS'
|
function = 'ACOS'
|
||||||
lookup_name = 'acos'
|
lookup_name = 'acos'
|
||||||
|
|
||||||
|
|
||||||
class ASin(OutputFieldMixin, Transform):
|
class ASin(NumericOutputFieldMixin, Transform):
|
||||||
function = 'ASIN'
|
function = 'ASIN'
|
||||||
lookup_name = 'asin'
|
lookup_name = 'asin'
|
||||||
|
|
||||||
|
|
||||||
class ATan(OutputFieldMixin, Transform):
|
class ATan(NumericOutputFieldMixin, Transform):
|
||||||
function = 'ATAN'
|
function = 'ATAN'
|
||||||
lookup_name = 'atan'
|
lookup_name = 'atan'
|
||||||
|
|
||||||
|
|
||||||
class ATan2(OutputFieldMixin, Func):
|
class ATan2(NumericOutputFieldMixin, Func):
|
||||||
function = 'ATAN2'
|
function = 'ATAN2'
|
||||||
arity = 2
|
arity = 2
|
||||||
|
|
||||||
|
@ -80,12 +59,12 @@ class Ceil(Transform):
|
||||||
return super().as_sql(compiler, connection, function='CEIL', **extra_context)
|
return super().as_sql(compiler, connection, function='CEIL', **extra_context)
|
||||||
|
|
||||||
|
|
||||||
class Cos(OutputFieldMixin, Transform):
|
class Cos(NumericOutputFieldMixin, Transform):
|
||||||
function = 'COS'
|
function = 'COS'
|
||||||
lookup_name = 'cos'
|
lookup_name = 'cos'
|
||||||
|
|
||||||
|
|
||||||
class Cot(OutputFieldMixin, Transform):
|
class Cot(NumericOutputFieldMixin, Transform):
|
||||||
function = 'COT'
|
function = 'COT'
|
||||||
lookup_name = 'cot'
|
lookup_name = 'cot'
|
||||||
|
|
||||||
|
@ -93,7 +72,7 @@ class Cot(OutputFieldMixin, Transform):
|
||||||
return super().as_sql(compiler, connection, template='(1 / TAN(%(expressions)s))', **extra_context)
|
return super().as_sql(compiler, connection, template='(1 / TAN(%(expressions)s))', **extra_context)
|
||||||
|
|
||||||
|
|
||||||
class Degrees(OutputFieldMixin, Transform):
|
class Degrees(NumericOutputFieldMixin, Transform):
|
||||||
function = 'DEGREES'
|
function = 'DEGREES'
|
||||||
lookup_name = 'degrees'
|
lookup_name = 'degrees'
|
||||||
|
|
||||||
|
@ -105,7 +84,7 @@ class Degrees(OutputFieldMixin, Transform):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class Exp(OutputFieldMixin, Transform):
|
class Exp(NumericOutputFieldMixin, Transform):
|
||||||
function = 'EXP'
|
function = 'EXP'
|
||||||
lookup_name = 'exp'
|
lookup_name = 'exp'
|
||||||
|
|
||||||
|
@ -115,12 +94,12 @@ class Floor(Transform):
|
||||||
lookup_name = 'floor'
|
lookup_name = 'floor'
|
||||||
|
|
||||||
|
|
||||||
class Ln(OutputFieldMixin, Transform):
|
class Ln(NumericOutputFieldMixin, Transform):
|
||||||
function = 'LN'
|
function = 'LN'
|
||||||
lookup_name = 'ln'
|
lookup_name = 'ln'
|
||||||
|
|
||||||
|
|
||||||
class Log(DecimalInputMixin, OutputFieldMixin, Func):
|
class Log(FixDecimalInputMixin, NumericOutputFieldMixin, Func):
|
||||||
function = 'LOG'
|
function = 'LOG'
|
||||||
arity = 2
|
arity = 2
|
||||||
|
|
||||||
|
@ -134,12 +113,12 @@ class Log(DecimalInputMixin, OutputFieldMixin, Func):
|
||||||
return clone.as_sql(compiler, connection, **extra_context)
|
return clone.as_sql(compiler, connection, **extra_context)
|
||||||
|
|
||||||
|
|
||||||
class Mod(DecimalInputMixin, OutputFieldMixin, Func):
|
class Mod(FixDecimalInputMixin, NumericOutputFieldMixin, Func):
|
||||||
function = 'MOD'
|
function = 'MOD'
|
||||||
arity = 2
|
arity = 2
|
||||||
|
|
||||||
|
|
||||||
class Pi(OutputFieldMixin, Func):
|
class Pi(NumericOutputFieldMixin, Func):
|
||||||
function = 'PI'
|
function = 'PI'
|
||||||
arity = 0
|
arity = 0
|
||||||
|
|
||||||
|
@ -147,12 +126,12 @@ class Pi(OutputFieldMixin, Func):
|
||||||
return super().as_sql(compiler, connection, template=str(math.pi), **extra_context)
|
return super().as_sql(compiler, connection, template=str(math.pi), **extra_context)
|
||||||
|
|
||||||
|
|
||||||
class Power(OutputFieldMixin, Func):
|
class Power(NumericOutputFieldMixin, Func):
|
||||||
function = 'POWER'
|
function = 'POWER'
|
||||||
arity = 2
|
arity = 2
|
||||||
|
|
||||||
|
|
||||||
class Radians(OutputFieldMixin, Transform):
|
class Radians(NumericOutputFieldMixin, Transform):
|
||||||
function = 'RADIANS'
|
function = 'RADIANS'
|
||||||
lookup_name = 'radians'
|
lookup_name = 'radians'
|
||||||
|
|
||||||
|
@ -169,16 +148,16 @@ class Round(Transform):
|
||||||
lookup_name = 'round'
|
lookup_name = 'round'
|
||||||
|
|
||||||
|
|
||||||
class Sin(OutputFieldMixin, Transform):
|
class Sin(NumericOutputFieldMixin, Transform):
|
||||||
function = 'SIN'
|
function = 'SIN'
|
||||||
lookup_name = 'sin'
|
lookup_name = 'sin'
|
||||||
|
|
||||||
|
|
||||||
class Sqrt(OutputFieldMixin, Transform):
|
class Sqrt(NumericOutputFieldMixin, Transform):
|
||||||
function = 'SQRT'
|
function = 'SQRT'
|
||||||
lookup_name = 'sqrt'
|
lookup_name = 'sqrt'
|
||||||
|
|
||||||
|
|
||||||
class Tan(OutputFieldMixin, Transform):
|
class Tan(NumericOutputFieldMixin, Transform):
|
||||||
function = 'TAN'
|
function = 'TAN'
|
||||||
lookup_name = 'tan'
|
lookup_name = 'tan'
|
||||||
|
|
|
@ -0,0 +1,27 @@
|
||||||
|
import sys
|
||||||
|
|
||||||
|
from django.db.models.fields import DecimalField, FloatField
|
||||||
|
from django.db.models.functions import Cast
|
||||||
|
|
||||||
|
|
||||||
|
class FixDecimalInputMixin:
|
||||||
|
|
||||||
|
def as_postgresql(self, compiler, connection, **extra_context):
|
||||||
|
# Cast FloatField to DecimalField as PostgreSQL doesn't support the
|
||||||
|
# following function signatures:
|
||||||
|
# - LOG(double, double)
|
||||||
|
# - MOD(double, double)
|
||||||
|
output_field = DecimalField(decimal_places=sys.float_info.dig, max_digits=1000)
|
||||||
|
clone = self.copy()
|
||||||
|
clone.set_source_expressions([
|
||||||
|
Cast(expression, output_field) if isinstance(expression.output_field, FloatField)
|
||||||
|
else expression for expression in self.get_source_expressions()
|
||||||
|
])
|
||||||
|
return clone.as_sql(compiler, connection, **extra_context)
|
||||||
|
|
||||||
|
|
||||||
|
class NumericOutputFieldMixin:
|
||||||
|
|
||||||
|
def _resolve_output_field(self):
|
||||||
|
has_decimals = any(isinstance(s.output_field, DecimalField) for s in self.get_source_expressions())
|
||||||
|
return DecimalField() if has_decimals else FloatField()
|
Loading…
Reference in New Issue