From 846624ed0858aec0e51baebaa5b397e135c6d1dc Mon Sep 17 00:00:00 2001 From: Nick Pope Date: Sat, 1 Dec 2018 23:49:38 +0000 Subject: [PATCH] Refs #28643 -- Extracted DurationField logic for Avg() and Sum() into mixin. Also addresses Sum() not handling the filter option correctly. --- django/db/models/aggregates.py | 38 ++++------------------------ django/db/models/functions/mixins.py | 19 ++++++++++++++ 2 files changed, 24 insertions(+), 33 deletions(-) diff --git a/django/db/models/aggregates.py b/django/db/models/aggregates.py index 1b0f9d98af..ac0b62d0bf 100644 --- a/django/db/models/aggregates.py +++ b/django/db/models/aggregates.py @@ -4,7 +4,9 @@ Classes to represent the definitions of aggregate functions. from django.core.exceptions import FieldError from django.db.models.expressions import Case, Func, Star, When from django.db.models.fields import IntegerField -from django.db.models.functions.mixins import NumericOutputFieldMixin +from django.db.models.functions.mixins import ( + FixDurationInputMixin, NumericOutputFieldMixin, +) __all__ = [ 'Aggregate', 'Avg', 'Count', 'Max', 'Min', 'StdDev', 'Sum', 'Variance', @@ -94,25 +96,10 @@ class Aggregate(Func): return options -class Avg(NumericOutputFieldMixin, Aggregate): +class Avg(FixDurationInputMixin, NumericOutputFieldMixin, Aggregate): function = 'AVG' name = 'Avg' - def as_mysql(self, compiler, connection, **extra_context): - sql, params = super().as_sql(compiler, connection, **extra_context) - if self.output_field.get_internal_type() == 'DurationField': - sql = 'CAST(%s as SIGNED)' % sql - return sql, params - - def as_oracle(self, compiler, connection, **extra_context): - if self.output_field.get_internal_type() == 'DurationField': - expression = self.get_source_expressions()[0] - from django.db.backends.oracle.functions import IntervalToSeconds, SecondsToInterval - return compiler.compile( - SecondsToInterval(Avg(IntervalToSeconds(expression), filter=self.filter)) - ) - return super().as_sql(compiler, connection, **extra_context) - class Count(Aggregate): function = 'COUNT' @@ -152,25 +139,10 @@ class StdDev(NumericOutputFieldMixin, Aggregate): return {**super()._get_repr_options(), 'sample': self.function == 'STDDEV_SAMP'} -class Sum(Aggregate): +class Sum(FixDurationInputMixin, Aggregate): function = 'SUM' name = 'Sum' - def as_mysql(self, compiler, connection, **extra_context): - sql, params = super().as_sql(compiler, connection, **extra_context) - if self.output_field.get_internal_type() == 'DurationField': - sql = 'CAST(%s as SIGNED)' % sql - return sql, params - - def as_oracle(self, compiler, connection, **extra_context): - if self.output_field.get_internal_type() == 'DurationField': - expression = self.get_source_expressions()[0] - from django.db.backends.oracle.functions import IntervalToSeconds, SecondsToInterval - return compiler.compile( - SecondsToInterval(Sum(IntervalToSeconds(expression))) - ) - return super().as_sql(compiler, connection, **extra_context) - class Variance(NumericOutputFieldMixin, Aggregate): name = 'Variance' diff --git a/django/db/models/functions/mixins.py b/django/db/models/functions/mixins.py index 9b46987788..8486ddb005 100644 --- a/django/db/models/functions/mixins.py +++ b/django/db/models/functions/mixins.py @@ -20,6 +20,25 @@ class FixDecimalInputMixin: return clone.as_sql(compiler, connection, **extra_context) +class FixDurationInputMixin: + + def as_mysql(self, compiler, connection, **extra_context): + sql, params = super().as_sql(compiler, connection, **extra_context) + if self.output_field.get_internal_type() == 'DurationField': + sql = 'CAST(%s AS SIGNED)' % sql + return sql, params + + def as_oracle(self, compiler, connection, **extra_context): + if self.output_field.get_internal_type() == 'DurationField': + expression = self.get_source_expressions()[0] + options = self._get_repr_options() + from django.db.backends.oracle.functions import IntervalToSeconds, SecondsToInterval + return compiler.compile( + SecondsToInterval(self.__class__(IntervalToSeconds(expression), **options)) + ) + return super().as_sql(compiler, connection, **extra_context) + + class NumericOutputFieldMixin: def _resolve_output_field(self):