Refs #28643 -- Extracted DurationField logic for Avg() and Sum() into mixin.

Also addresses Sum() not handling the filter option correctly.
This commit is contained in:
Nick Pope 2018-12-01 23:49:38 +00:00 committed by Tim Graham
parent 6d4efa8e6a
commit 846624ed08
2 changed files with 24 additions and 33 deletions

View File

@ -4,7 +4,9 @@ Classes to represent the definitions of aggregate functions.
from django.core.exceptions import FieldError
from django.db.models.expressions import Case, Func, Star, When
from django.db.models.fields import IntegerField
from django.db.models.functions.mixins import NumericOutputFieldMixin
from django.db.models.functions.mixins import (
FixDurationInputMixin, NumericOutputFieldMixin,
)
__all__ = [
'Aggregate', 'Avg', 'Count', 'Max', 'Min', 'StdDev', 'Sum', 'Variance',
@ -94,25 +96,10 @@ class Aggregate(Func):
return options
class Avg(NumericOutputFieldMixin, Aggregate):
class Avg(FixDurationInputMixin, NumericOutputFieldMixin, Aggregate):
function = 'AVG'
name = 'Avg'
def as_mysql(self, compiler, connection, **extra_context):
sql, params = super().as_sql(compiler, connection, **extra_context)
if self.output_field.get_internal_type() == 'DurationField':
sql = 'CAST(%s as SIGNED)' % sql
return sql, params
def as_oracle(self, compiler, connection, **extra_context):
if self.output_field.get_internal_type() == 'DurationField':
expression = self.get_source_expressions()[0]
from django.db.backends.oracle.functions import IntervalToSeconds, SecondsToInterval
return compiler.compile(
SecondsToInterval(Avg(IntervalToSeconds(expression), filter=self.filter))
)
return super().as_sql(compiler, connection, **extra_context)
class Count(Aggregate):
function = 'COUNT'
@ -152,25 +139,10 @@ class StdDev(NumericOutputFieldMixin, Aggregate):
return {**super()._get_repr_options(), 'sample': self.function == 'STDDEV_SAMP'}
class Sum(Aggregate):
class Sum(FixDurationInputMixin, Aggregate):
function = 'SUM'
name = 'Sum'
def as_mysql(self, compiler, connection, **extra_context):
sql, params = super().as_sql(compiler, connection, **extra_context)
if self.output_field.get_internal_type() == 'DurationField':
sql = 'CAST(%s as SIGNED)' % sql
return sql, params
def as_oracle(self, compiler, connection, **extra_context):
if self.output_field.get_internal_type() == 'DurationField':
expression = self.get_source_expressions()[0]
from django.db.backends.oracle.functions import IntervalToSeconds, SecondsToInterval
return compiler.compile(
SecondsToInterval(Sum(IntervalToSeconds(expression)))
)
return super().as_sql(compiler, connection, **extra_context)
class Variance(NumericOutputFieldMixin, Aggregate):
name = 'Variance'

View File

@ -20,6 +20,25 @@ class FixDecimalInputMixin:
return clone.as_sql(compiler, connection, **extra_context)
class FixDurationInputMixin:
def as_mysql(self, compiler, connection, **extra_context):
sql, params = super().as_sql(compiler, connection, **extra_context)
if self.output_field.get_internal_type() == 'DurationField':
sql = 'CAST(%s AS SIGNED)' % sql
return sql, params
def as_oracle(self, compiler, connection, **extra_context):
if self.output_field.get_internal_type() == 'DurationField':
expression = self.get_source_expressions()[0]
options = self._get_repr_options()
from django.db.backends.oracle.functions import IntervalToSeconds, SecondsToInterval
return compiler.compile(
SecondsToInterval(self.__class__(IntervalToSeconds(expression), **options))
)
return super().as_sql(compiler, connection, **extra_context)
class NumericOutputFieldMixin:
def _resolve_output_field(self):