Refs #28643 -- Extracted DurationField logic for Avg() and Sum() into mixin.

Also addresses Sum() not handling the filter option correctly.
This commit is contained in:
Nick Pope 2018-12-01 23:49:38 +00:00 committed by Tim Graham
parent 6d4efa8e6a
commit 846624ed08
2 changed files with 24 additions and 33 deletions

View File

@ -4,7 +4,9 @@ Classes to represent the definitions of aggregate functions.
from django.core.exceptions import FieldError from django.core.exceptions import FieldError
from django.db.models.expressions import Case, Func, Star, When from django.db.models.expressions import Case, Func, Star, When
from django.db.models.fields import IntegerField from django.db.models.fields import IntegerField
from django.db.models.functions.mixins import NumericOutputFieldMixin from django.db.models.functions.mixins import (
FixDurationInputMixin, NumericOutputFieldMixin,
)
__all__ = [ __all__ = [
'Aggregate', 'Avg', 'Count', 'Max', 'Min', 'StdDev', 'Sum', 'Variance', 'Aggregate', 'Avg', 'Count', 'Max', 'Min', 'StdDev', 'Sum', 'Variance',
@ -94,25 +96,10 @@ class Aggregate(Func):
return options return options
class Avg(NumericOutputFieldMixin, Aggregate): class Avg(FixDurationInputMixin, NumericOutputFieldMixin, Aggregate):
function = 'AVG' function = 'AVG'
name = 'Avg' name = 'Avg'
def as_mysql(self, compiler, connection, **extra_context):
sql, params = super().as_sql(compiler, connection, **extra_context)
if self.output_field.get_internal_type() == 'DurationField':
sql = 'CAST(%s as SIGNED)' % sql
return sql, params
def as_oracle(self, compiler, connection, **extra_context):
if self.output_field.get_internal_type() == 'DurationField':
expression = self.get_source_expressions()[0]
from django.db.backends.oracle.functions import IntervalToSeconds, SecondsToInterval
return compiler.compile(
SecondsToInterval(Avg(IntervalToSeconds(expression), filter=self.filter))
)
return super().as_sql(compiler, connection, **extra_context)
class Count(Aggregate): class Count(Aggregate):
function = 'COUNT' function = 'COUNT'
@ -152,25 +139,10 @@ class StdDev(NumericOutputFieldMixin, Aggregate):
return {**super()._get_repr_options(), 'sample': self.function == 'STDDEV_SAMP'} return {**super()._get_repr_options(), 'sample': self.function == 'STDDEV_SAMP'}
class Sum(Aggregate): class Sum(FixDurationInputMixin, Aggregate):
function = 'SUM' function = 'SUM'
name = 'Sum' name = 'Sum'
def as_mysql(self, compiler, connection, **extra_context):
sql, params = super().as_sql(compiler, connection, **extra_context)
if self.output_field.get_internal_type() == 'DurationField':
sql = 'CAST(%s as SIGNED)' % sql
return sql, params
def as_oracle(self, compiler, connection, **extra_context):
if self.output_field.get_internal_type() == 'DurationField':
expression = self.get_source_expressions()[0]
from django.db.backends.oracle.functions import IntervalToSeconds, SecondsToInterval
return compiler.compile(
SecondsToInterval(Sum(IntervalToSeconds(expression)))
)
return super().as_sql(compiler, connection, **extra_context)
class Variance(NumericOutputFieldMixin, Aggregate): class Variance(NumericOutputFieldMixin, Aggregate):
name = 'Variance' name = 'Variance'

View File

@ -20,6 +20,25 @@ class FixDecimalInputMixin:
return clone.as_sql(compiler, connection, **extra_context) return clone.as_sql(compiler, connection, **extra_context)
class FixDurationInputMixin:
def as_mysql(self, compiler, connection, **extra_context):
sql, params = super().as_sql(compiler, connection, **extra_context)
if self.output_field.get_internal_type() == 'DurationField':
sql = 'CAST(%s AS SIGNED)' % sql
return sql, params
def as_oracle(self, compiler, connection, **extra_context):
if self.output_field.get_internal_type() == 'DurationField':
expression = self.get_source_expressions()[0]
options = self._get_repr_options()
from django.db.backends.oracle.functions import IntervalToSeconds, SecondsToInterval
return compiler.compile(
SecondsToInterval(self.__class__(IntervalToSeconds(expression), **options))
)
return super().as_sql(compiler, connection, **extra_context)
class NumericOutputFieldMixin: class NumericOutputFieldMixin:
def _resolve_output_field(self): def _resolve_output_field(self):