diff --git a/django/db/models/functions/math.py b/django/db/models/functions/math.py index c8760652b4..43cbc17a1e 100644 --- a/django/db/models/functions/math.py +++ b/django/db/models/functions/math.py @@ -1,56 +1,35 @@ import math -import sys from django.db.models.expressions import Func -from django.db.models.fields import DecimalField, FloatField, IntegerField +from django.db.models.fields import FloatField, IntegerField from django.db.models.functions import Cast +from django.db.models.functions.mixins import ( + FixDecimalInputMixin, NumericOutputFieldMixin, +) from django.db.models.lookups import Transform -class DecimalInputMixin: - - def as_postgresql(self, compiler, connection, **extra_context): - # Cast FloatField to DecimalField as PostgreSQL doesn't support the - # following function signatures: - # - LOG(double, double) - # - MOD(double, double) - output_field = DecimalField(decimal_places=sys.float_info.dig, max_digits=1000) - clone = self.copy() - clone.set_source_expressions([ - Cast(expression, output_field) if isinstance(expression.output_field, FloatField) - else expression for expression in self.get_source_expressions() - ]) - return clone.as_sql(compiler, connection, **extra_context) - - -class OutputFieldMixin: - - def _resolve_output_field(self): - has_decimals = any(isinstance(s.output_field, DecimalField) for s in self.get_source_expressions()) - return DecimalField() if has_decimals else FloatField() - - class Abs(Transform): function = 'ABS' lookup_name = 'abs' -class ACos(OutputFieldMixin, Transform): +class ACos(NumericOutputFieldMixin, Transform): function = 'ACOS' lookup_name = 'acos' -class ASin(OutputFieldMixin, Transform): +class ASin(NumericOutputFieldMixin, Transform): function = 'ASIN' lookup_name = 'asin' -class ATan(OutputFieldMixin, Transform): +class ATan(NumericOutputFieldMixin, Transform): function = 'ATAN' lookup_name = 'atan' -class ATan2(OutputFieldMixin, Func): +class ATan2(NumericOutputFieldMixin, Func): function = 'ATAN2' arity = 2 @@ -80,12 +59,12 @@ class Ceil(Transform): return super().as_sql(compiler, connection, function='CEIL', **extra_context) -class Cos(OutputFieldMixin, Transform): +class Cos(NumericOutputFieldMixin, Transform): function = 'COS' lookup_name = 'cos' -class Cot(OutputFieldMixin, Transform): +class Cot(NumericOutputFieldMixin, Transform): function = 'COT' lookup_name = 'cot' @@ -93,7 +72,7 @@ class Cot(OutputFieldMixin, Transform): return super().as_sql(compiler, connection, template='(1 / TAN(%(expressions)s))', **extra_context) -class Degrees(OutputFieldMixin, Transform): +class Degrees(NumericOutputFieldMixin, Transform): function = 'DEGREES' lookup_name = 'degrees' @@ -105,7 +84,7 @@ class Degrees(OutputFieldMixin, Transform): ) -class Exp(OutputFieldMixin, Transform): +class Exp(NumericOutputFieldMixin, Transform): function = 'EXP' lookup_name = 'exp' @@ -115,12 +94,12 @@ class Floor(Transform): lookup_name = 'floor' -class Ln(OutputFieldMixin, Transform): +class Ln(NumericOutputFieldMixin, Transform): function = 'LN' lookup_name = 'ln' -class Log(DecimalInputMixin, OutputFieldMixin, Func): +class Log(FixDecimalInputMixin, NumericOutputFieldMixin, Func): function = 'LOG' arity = 2 @@ -134,12 +113,12 @@ class Log(DecimalInputMixin, OutputFieldMixin, Func): return clone.as_sql(compiler, connection, **extra_context) -class Mod(DecimalInputMixin, OutputFieldMixin, Func): +class Mod(FixDecimalInputMixin, NumericOutputFieldMixin, Func): function = 'MOD' arity = 2 -class Pi(OutputFieldMixin, Func): +class Pi(NumericOutputFieldMixin, Func): function = 'PI' arity = 0 @@ -147,12 +126,12 @@ class Pi(OutputFieldMixin, Func): return super().as_sql(compiler, connection, template=str(math.pi), **extra_context) -class Power(OutputFieldMixin, Func): +class Power(NumericOutputFieldMixin, Func): function = 'POWER' arity = 2 -class Radians(OutputFieldMixin, Transform): +class Radians(NumericOutputFieldMixin, Transform): function = 'RADIANS' lookup_name = 'radians' @@ -169,16 +148,16 @@ class Round(Transform): lookup_name = 'round' -class Sin(OutputFieldMixin, Transform): +class Sin(NumericOutputFieldMixin, Transform): function = 'SIN' lookup_name = 'sin' -class Sqrt(OutputFieldMixin, Transform): +class Sqrt(NumericOutputFieldMixin, Transform): function = 'SQRT' lookup_name = 'sqrt' -class Tan(OutputFieldMixin, Transform): +class Tan(NumericOutputFieldMixin, Transform): function = 'TAN' lookup_name = 'tan' diff --git a/django/db/models/functions/mixins.py b/django/db/models/functions/mixins.py new file mode 100644 index 0000000000..1bf3d6cbd0 --- /dev/null +++ b/django/db/models/functions/mixins.py @@ -0,0 +1,27 @@ +import sys + +from django.db.models.fields import DecimalField, FloatField +from django.db.models.functions import Cast + + +class FixDecimalInputMixin: + + def as_postgresql(self, compiler, connection, **extra_context): + # Cast FloatField to DecimalField as PostgreSQL doesn't support the + # following function signatures: + # - LOG(double, double) + # - MOD(double, double) + output_field = DecimalField(decimal_places=sys.float_info.dig, max_digits=1000) + clone = self.copy() + clone.set_source_expressions([ + Cast(expression, output_field) if isinstance(expression.output_field, FloatField) + else expression for expression in self.get_source_expressions() + ]) + return clone.as_sql(compiler, connection, **extra_context) + + +class NumericOutputFieldMixin: + + def _resolve_output_field(self): + has_decimals = any(isinstance(s.output_field, DecimalField) for s in self.get_source_expressions()) + return DecimalField() if has_decimals else FloatField()