Refs #28643 -- Moved db function mixins to a separate module.

This commit is contained in:
Nick Pope 2018-12-01 23:43:27 +00:00 committed by Tim Graham
parent 7f1577d1ef
commit 3d5e0f8394
2 changed files with 48 additions and 42 deletions

View File

@ -1,56 +1,35 @@
import math import math
import sys
from django.db.models.expressions import Func from django.db.models.expressions import Func
from django.db.models.fields import DecimalField, FloatField, IntegerField from django.db.models.fields import FloatField, IntegerField
from django.db.models.functions import Cast from django.db.models.functions import Cast
from django.db.models.functions.mixins import (
FixDecimalInputMixin, NumericOutputFieldMixin,
)
from django.db.models.lookups import Transform from django.db.models.lookups import Transform
class DecimalInputMixin:
def as_postgresql(self, compiler, connection, **extra_context):
# Cast FloatField to DecimalField as PostgreSQL doesn't support the
# following function signatures:
# - LOG(double, double)
# - MOD(double, double)
output_field = DecimalField(decimal_places=sys.float_info.dig, max_digits=1000)
clone = self.copy()
clone.set_source_expressions([
Cast(expression, output_field) if isinstance(expression.output_field, FloatField)
else expression for expression in self.get_source_expressions()
])
return clone.as_sql(compiler, connection, **extra_context)
class OutputFieldMixin:
def _resolve_output_field(self):
has_decimals = any(isinstance(s.output_field, DecimalField) for s in self.get_source_expressions())
return DecimalField() if has_decimals else FloatField()
class Abs(Transform): class Abs(Transform):
function = 'ABS' function = 'ABS'
lookup_name = 'abs' lookup_name = 'abs'
class ACos(OutputFieldMixin, Transform): class ACos(NumericOutputFieldMixin, Transform):
function = 'ACOS' function = 'ACOS'
lookup_name = 'acos' lookup_name = 'acos'
class ASin(OutputFieldMixin, Transform): class ASin(NumericOutputFieldMixin, Transform):
function = 'ASIN' function = 'ASIN'
lookup_name = 'asin' lookup_name = 'asin'
class ATan(OutputFieldMixin, Transform): class ATan(NumericOutputFieldMixin, Transform):
function = 'ATAN' function = 'ATAN'
lookup_name = 'atan' lookup_name = 'atan'
class ATan2(OutputFieldMixin, Func): class ATan2(NumericOutputFieldMixin, Func):
function = 'ATAN2' function = 'ATAN2'
arity = 2 arity = 2
@ -80,12 +59,12 @@ class Ceil(Transform):
return super().as_sql(compiler, connection, function='CEIL', **extra_context) return super().as_sql(compiler, connection, function='CEIL', **extra_context)
class Cos(OutputFieldMixin, Transform): class Cos(NumericOutputFieldMixin, Transform):
function = 'COS' function = 'COS'
lookup_name = 'cos' lookup_name = 'cos'
class Cot(OutputFieldMixin, Transform): class Cot(NumericOutputFieldMixin, Transform):
function = 'COT' function = 'COT'
lookup_name = 'cot' lookup_name = 'cot'
@ -93,7 +72,7 @@ class Cot(OutputFieldMixin, Transform):
return super().as_sql(compiler, connection, template='(1 / TAN(%(expressions)s))', **extra_context) return super().as_sql(compiler, connection, template='(1 / TAN(%(expressions)s))', **extra_context)
class Degrees(OutputFieldMixin, Transform): class Degrees(NumericOutputFieldMixin, Transform):
function = 'DEGREES' function = 'DEGREES'
lookup_name = 'degrees' lookup_name = 'degrees'
@ -105,7 +84,7 @@ class Degrees(OutputFieldMixin, Transform):
) )
class Exp(OutputFieldMixin, Transform): class Exp(NumericOutputFieldMixin, Transform):
function = 'EXP' function = 'EXP'
lookup_name = 'exp' lookup_name = 'exp'
@ -115,12 +94,12 @@ class Floor(Transform):
lookup_name = 'floor' lookup_name = 'floor'
class Ln(OutputFieldMixin, Transform): class Ln(NumericOutputFieldMixin, Transform):
function = 'LN' function = 'LN'
lookup_name = 'ln' lookup_name = 'ln'
class Log(DecimalInputMixin, OutputFieldMixin, Func): class Log(FixDecimalInputMixin, NumericOutputFieldMixin, Func):
function = 'LOG' function = 'LOG'
arity = 2 arity = 2
@ -134,12 +113,12 @@ class Log(DecimalInputMixin, OutputFieldMixin, Func):
return clone.as_sql(compiler, connection, **extra_context) return clone.as_sql(compiler, connection, **extra_context)
class Mod(DecimalInputMixin, OutputFieldMixin, Func): class Mod(FixDecimalInputMixin, NumericOutputFieldMixin, Func):
function = 'MOD' function = 'MOD'
arity = 2 arity = 2
class Pi(OutputFieldMixin, Func): class Pi(NumericOutputFieldMixin, Func):
function = 'PI' function = 'PI'
arity = 0 arity = 0
@ -147,12 +126,12 @@ class Pi(OutputFieldMixin, Func):
return super().as_sql(compiler, connection, template=str(math.pi), **extra_context) return super().as_sql(compiler, connection, template=str(math.pi), **extra_context)
class Power(OutputFieldMixin, Func): class Power(NumericOutputFieldMixin, Func):
function = 'POWER' function = 'POWER'
arity = 2 arity = 2
class Radians(OutputFieldMixin, Transform): class Radians(NumericOutputFieldMixin, Transform):
function = 'RADIANS' function = 'RADIANS'
lookup_name = 'radians' lookup_name = 'radians'
@ -169,16 +148,16 @@ class Round(Transform):
lookup_name = 'round' lookup_name = 'round'
class Sin(OutputFieldMixin, Transform): class Sin(NumericOutputFieldMixin, Transform):
function = 'SIN' function = 'SIN'
lookup_name = 'sin' lookup_name = 'sin'
class Sqrt(OutputFieldMixin, Transform): class Sqrt(NumericOutputFieldMixin, Transform):
function = 'SQRT' function = 'SQRT'
lookup_name = 'sqrt' lookup_name = 'sqrt'
class Tan(OutputFieldMixin, Transform): class Tan(NumericOutputFieldMixin, Transform):
function = 'TAN' function = 'TAN'
lookup_name = 'tan' lookup_name = 'tan'

View File

@ -0,0 +1,27 @@
import sys
from django.db.models.fields import DecimalField, FloatField
from django.db.models.functions import Cast
class FixDecimalInputMixin:
def as_postgresql(self, compiler, connection, **extra_context):
# Cast FloatField to DecimalField as PostgreSQL doesn't support the
# following function signatures:
# - LOG(double, double)
# - MOD(double, double)
output_field = DecimalField(decimal_places=sys.float_info.dig, max_digits=1000)
clone = self.copy()
clone.set_source_expressions([
Cast(expression, output_field) if isinstance(expression.output_field, FloatField)
else expression for expression in self.get_source_expressions()
])
return clone.as_sql(compiler, connection, **extra_context)
class NumericOutputFieldMixin:
def _resolve_output_field(self):
has_decimals = any(isinstance(s.output_field, DecimalField) for s in self.get_source_expressions())
return DecimalField() if has_decimals else FloatField()