Fixed #10929 -- Added default argument to aggregates.

Thanks to Simon Charette and Adam Johnson for the reviews.
This commit is contained in:
Nick Pope 2021-02-21 01:38:55 +00:00 committed by Mariusz Felisiak
parent 59942a66ce
commit 501a8db465
11 changed files with 393 additions and 64 deletions

View File

@ -18,7 +18,7 @@ class ArrayAgg(OrderableAggMixin, Aggregate):
return ArrayField(self.source_expressions[0].output_field) return ArrayField(self.source_expressions[0].output_field)
def convert_value(self, value, expression, connection): def convert_value(self, value, expression, connection):
if not value: if value is None and self.default is None:
return [] return []
return value return value
@ -48,7 +48,7 @@ class JSONBAgg(OrderableAggMixin, Aggregate):
output_field = JSONField() output_field = JSONField()
def convert_value(self, value, expression, connection): def convert_value(self, value, expression, connection):
if not value: if value is None and self.default is None:
return '[]' return '[]'
return value return value
@ -63,6 +63,6 @@ class StringAgg(OrderableAggMixin, Aggregate):
super().__init__(expression, delimiter_expr, **extra) super().__init__(expression, delimiter_expr, **extra)
def convert_value(self, value, expression, connection): def convert_value(self, value, expression, connection):
if not value: if value is None and self.default is None:
return '' return ''
return value return value

View File

@ -9,10 +9,10 @@ __all__ = [
class StatAggregate(Aggregate): class StatAggregate(Aggregate):
output_field = FloatField() output_field = FloatField()
def __init__(self, y, x, output_field=None, filter=None): def __init__(self, y, x, output_field=None, filter=None, default=None):
if not x or not y: if not x or not y:
raise ValueError('Both y and x must be provided.') raise ValueError('Both y and x must be provided.')
super().__init__(y, x, output_field=output_field, filter=filter) super().__init__(y, x, output_field=output_field, filter=filter, default=default)
class Corr(StatAggregate): class Corr(StatAggregate):
@ -20,9 +20,9 @@ class Corr(StatAggregate):
class CovarPop(StatAggregate): class CovarPop(StatAggregate):
def __init__(self, y, x, sample=False, filter=None): def __init__(self, y, x, sample=False, filter=None, default=None):
self.function = 'COVAR_SAMP' if sample else 'COVAR_POP' self.function = 'COVAR_SAMP' if sample else 'COVAR_POP'
super().__init__(y, x, filter=filter) super().__init__(y, x, filter=filter, default=default)
class RegrAvgX(StatAggregate): class RegrAvgX(StatAggregate):

View File

@ -88,6 +88,17 @@ class DatabaseFeatures(BaseDatabaseFeatures):
'annotations.tests.NonAggregateAnnotationTestCase.test_annotation_aggregate_with_m2o', 'annotations.tests.NonAggregateAnnotationTestCase.test_annotation_aggregate_with_m2o',
}, },
}) })
if not self.connection.mysql_is_mariadb and self.connection.mysql_version < (8,):
skips.update({
'Casting to datetime/time is not supported by MySQL < 8.0. (#30224)': {
'aggregation.tests.AggregateTestCase.test_aggregation_default_using_time_from_python',
'aggregation.tests.AggregateTestCase.test_aggregation_default_using_datetime_from_python',
},
'MySQL < 8.0 returns string type instead of datetime/time. (#30224)': {
'aggregation.tests.AggregateTestCase.test_aggregation_default_using_time_from_database',
'aggregation.tests.AggregateTestCase.test_aggregation_default_using_datetime_from_database',
},
})
if ( if (
self.connection.mysql_is_mariadb and self.connection.mysql_is_mariadb and
(10, 4, 3) < self.connection.mysql_version < (10, 5, 2) (10, 4, 3) < self.connection.mysql_version < (10, 5, 2)

View File

@ -4,6 +4,7 @@ Classes to represent the definitions of aggregate functions.
from django.core.exceptions import FieldError from django.core.exceptions import FieldError
from django.db.models.expressions import Case, Func, Star, When from django.db.models.expressions import Case, Func, Star, When
from django.db.models.fields import IntegerField from django.db.models.fields import IntegerField
from django.db.models.functions.comparison import Coalesce
from django.db.models.functions.mixins import ( from django.db.models.functions.mixins import (
FixDurationInputMixin, NumericOutputFieldMixin, FixDurationInputMixin, NumericOutputFieldMixin,
) )
@ -22,11 +23,14 @@ class Aggregate(Func):
allow_distinct = False allow_distinct = False
empty_aggregate_value = None empty_aggregate_value = None
def __init__(self, *expressions, distinct=False, filter=None, **extra): def __init__(self, *expressions, distinct=False, filter=None, default=None, **extra):
if distinct and not self.allow_distinct: if distinct and not self.allow_distinct:
raise TypeError("%s does not allow distinct." % self.__class__.__name__) raise TypeError("%s does not allow distinct." % self.__class__.__name__)
if default is not None and self.empty_aggregate_value is not None:
raise TypeError(f'{self.__class__.__name__} does not allow default.')
self.distinct = distinct self.distinct = distinct
self.filter = filter self.filter = filter
self.default = default
super().__init__(*expressions, **extra) super().__init__(*expressions, **extra)
def get_source_fields(self): def get_source_fields(self):
@ -56,7 +60,12 @@ class Aggregate(Func):
before_resolved = self.get_source_expressions()[index] before_resolved = self.get_source_expressions()[index]
name = before_resolved.name if hasattr(before_resolved, 'name') else repr(before_resolved) name = before_resolved.name if hasattr(before_resolved, 'name') else repr(before_resolved)
raise FieldError("Cannot compute %s('%s'): '%s' is an aggregate" % (c.name, name, name)) raise FieldError("Cannot compute %s('%s'): '%s' is an aggregate" % (c.name, name, name))
return c if (default := c.default) is None:
return c
if hasattr(default, 'resolve_expression'):
default = default.resolve_expression(query, allow_joins, reuse, summarize)
c.default = None # Reset the default argument before wrapping.
return Coalesce(c, default, output_field=c._output_field_or_none)
@property @property
def default_alias(self): def default_alias(self):

View File

@ -19,8 +19,8 @@ module. They are described in more detail in the `PostgreSQL docs
.. admonition:: Common aggregate options .. admonition:: Common aggregate options
All aggregates have the :ref:`filter <aggregate-filter>` keyword All aggregates have the :ref:`filter <aggregate-filter>` keyword argument
argument. and most also have the :ref:`default <aggregate-default>` keyword argument.
General-purpose aggregation functions General-purpose aggregation functions
===================================== =====================================
@ -28,9 +28,10 @@ General-purpose aggregation functions
``ArrayAgg`` ``ArrayAgg``
------------ ------------
.. class:: ArrayAgg(expression, distinct=False, filter=None, ordering=(), **extra) .. class:: ArrayAgg(expression, distinct=False, filter=None, default=None, ordering=(), **extra)
Returns a list of values, including nulls, concatenated into an array. Returns a list of values, including nulls, concatenated into an array, or
``default`` if there are no values.
.. attribute:: distinct .. attribute:: distinct
@ -54,26 +55,26 @@ General-purpose aggregation functions
``BitAnd`` ``BitAnd``
---------- ----------
.. class:: BitAnd(expression, filter=None, **extra) .. class:: BitAnd(expression, filter=None, default=None, **extra)
Returns an ``int`` of the bitwise ``AND`` of all non-null input values, or Returns an ``int`` of the bitwise ``AND`` of all non-null input values, or
``None`` if all values are null. ``default`` if all values are null.
``BitOr`` ``BitOr``
--------- ---------
.. class:: BitOr(expression, filter=None, **extra) .. class:: BitOr(expression, filter=None, default=None, **extra)
Returns an ``int`` of the bitwise ``OR`` of all non-null input values, or Returns an ``int`` of the bitwise ``OR`` of all non-null input values, or
``None`` if all values are null. ``default`` if all values are null.
``BoolAnd`` ``BoolAnd``
----------- -----------
.. class:: BoolAnd(expression, filter=None, **extra) .. class:: BoolAnd(expression, filter=None, default=None, **extra)
Returns ``True``, if all input values are true, ``None`` if all values are Returns ``True``, if all input values are true, ``default`` if all values
null or if there are no values, otherwise ``False`` . are null or if there are no values, otherwise ``False``.
Usage example:: Usage example::
@ -92,9 +93,9 @@ General-purpose aggregation functions
``BoolOr`` ``BoolOr``
---------- ----------
.. class:: BoolOr(expression, filter=None, **extra) .. class:: BoolOr(expression, filter=None, default=None, **extra)
Returns ``True`` if at least one input value is true, ``None`` if all Returns ``True`` if at least one input value is true, ``default`` if all
values are null or if there are no values, otherwise ``False``. values are null or if there are no values, otherwise ``False``.
Usage example:: Usage example::
@ -114,9 +115,10 @@ General-purpose aggregation functions
``JSONBAgg`` ``JSONBAgg``
------------ ------------
.. class:: JSONBAgg(expressions, distinct=False, filter=None, ordering=(), **extra) .. class:: JSONBAgg(expressions, distinct=False, filter=None, default=None, ordering=(), **extra)
Returns the input values as a ``JSON`` array. Returns the input values as a ``JSON`` array, or ``default`` if there are
no values.
.. attribute:: distinct .. attribute:: distinct
@ -139,10 +141,10 @@ General-purpose aggregation functions
``StringAgg`` ``StringAgg``
------------- -------------
.. class:: StringAgg(expression, delimiter, distinct=False, filter=None, ordering=()) .. class:: StringAgg(expression, delimiter, distinct=False, filter=None, default=None, ordering=())
Returns the input values concatenated into a string, separated by Returns the input values concatenated into a string, separated by
the ``delimiter`` string. the ``delimiter`` string, or ``default`` if there are no values.
.. attribute:: delimiter .. attribute:: delimiter
@ -174,17 +176,17 @@ field or an expression returning a numeric data. Both are required.
``Corr`` ``Corr``
-------- --------
.. class:: Corr(y, x, filter=None) .. class:: Corr(y, x, filter=None, default=None)
Returns the correlation coefficient as a ``float``, or ``None`` if there Returns the correlation coefficient as a ``float``, or ``default`` if there
aren't any matching rows. aren't any matching rows.
``CovarPop`` ``CovarPop``
------------ ------------
.. class:: CovarPop(y, x, sample=False, filter=None) .. class:: CovarPop(y, x, sample=False, filter=None, default=None)
Returns the population covariance as a ``float``, or ``None`` if there Returns the population covariance as a ``float``, or ``default`` if there
aren't any matching rows. aren't any matching rows.
Has one optional argument: Has one optional argument:
@ -198,18 +200,18 @@ field or an expression returning a numeric data. Both are required.
``RegrAvgX`` ``RegrAvgX``
------------ ------------
.. class:: RegrAvgX(y, x, filter=None) .. class:: RegrAvgX(y, x, filter=None, default=None)
Returns the average of the independent variable (``sum(x)/N``) as a Returns the average of the independent variable (``sum(x)/N``) as a
``float``, or ``None`` if there aren't any matching rows. ``float``, or ``default`` if there aren't any matching rows.
``RegrAvgY`` ``RegrAvgY``
------------ ------------
.. class:: RegrAvgY(y, x, filter=None) .. class:: RegrAvgY(y, x, filter=None, default=None)
Returns the average of the dependent variable (``sum(y)/N``) as a Returns the average of the dependent variable (``sum(y)/N``) as a
``float``, or ``None`` if there aren't any matching rows. ``float``, or ``default`` if there aren't any matching rows.
``RegrCount`` ``RegrCount``
------------- -------------
@ -219,56 +221,60 @@ field or an expression returning a numeric data. Both are required.
Returns an ``int`` of the number of input rows in which both expressions Returns an ``int`` of the number of input rows in which both expressions
are not null. are not null.
.. note::
The ``default`` argument is not supported.
``RegrIntercept`` ``RegrIntercept``
----------------- -----------------
.. class:: RegrIntercept(y, x, filter=None) .. class:: RegrIntercept(y, x, filter=None, default=None)
Returns the y-intercept of the least-squares-fit linear equation determined Returns the y-intercept of the least-squares-fit linear equation determined
by the ``(x, y)`` pairs as a ``float``, or ``None`` if there aren't any by the ``(x, y)`` pairs as a ``float``, or ``default`` if there aren't any
matching rows. matching rows.
``RegrR2`` ``RegrR2``
---------- ----------
.. class:: RegrR2(y, x, filter=None) .. class:: RegrR2(y, x, filter=None, default=None)
Returns the square of the correlation coefficient as a ``float``, or Returns the square of the correlation coefficient as a ``float``, or
``None`` if there aren't any matching rows. ``default`` if there aren't any matching rows.
``RegrSlope`` ``RegrSlope``
------------- -------------
.. class:: RegrSlope(y, x, filter=None) .. class:: RegrSlope(y, x, filter=None, default=None)
Returns the slope of the least-squares-fit linear equation determined Returns the slope of the least-squares-fit linear equation determined
by the ``(x, y)`` pairs as a ``float``, or ``None`` if there aren't any by the ``(x, y)`` pairs as a ``float``, or ``default`` if there aren't any
matching rows. matching rows.
``RegrSXX`` ``RegrSXX``
----------- -----------
.. class:: RegrSXX(y, x, filter=None) .. class:: RegrSXX(y, x, filter=None, default=None)
Returns ``sum(x^2) - sum(x)^2/N`` ("sum of squares" of the independent Returns ``sum(x^2) - sum(x)^2/N`` ("sum of squares" of the independent
variable) as a ``float``, or ``None`` if there aren't any matching rows. variable) as a ``float``, or ``default`` if there aren't any matching rows.
``RegrSXY`` ``RegrSXY``
----------- -----------
.. class:: RegrSXY(y, x, filter=None) .. class:: RegrSXY(y, x, filter=None, default=None)
Returns ``sum(x*y) - sum(x) * sum(y)/N`` ("sum of products" of independent Returns ``sum(x*y) - sum(x) * sum(y)/N`` ("sum of products" of independent
times dependent variable) as a ``float``, or ``None`` if there aren't any times dependent variable) as a ``float``, or ``default`` if there aren't
matching rows. any matching rows.
``RegrSYY`` ``RegrSYY``
----------- -----------
.. class:: RegrSYY(y, x, filter=None) .. class:: RegrSYY(y, x, filter=None, default=None)
Returns ``sum(y^2) - sum(y)^2/N`` ("sum of squares" of the dependent Returns ``sum(y^2) - sum(y)^2/N`` ("sum of squares" of the dependent
variable) as a ``float``, or ``None`` if there aren't any matching rows. variable) as a ``float``, or ``default`` if there aren't any matching rows.
Usage examples Usage examples
============== ==============

View File

@ -59,7 +59,7 @@ will result in a database error.
Usage examples:: Usage examples::
>>> # Get a screen name from least to most public >>> # Get a screen name from least to most public
>>> from django.db.models import Sum, Value as V >>> from django.db.models import Sum
>>> from django.db.models.functions import Coalesce >>> from django.db.models.functions import Coalesce
>>> Author.objects.create(name='Margaret Smith', goes_by='Maggie') >>> Author.objects.create(name='Margaret Smith', goes_by='Maggie')
>>> author = Author.objects.annotate( >>> author = Author.objects.annotate(
@ -68,13 +68,18 @@ Usage examples::
Maggie Maggie
>>> # Prevent an aggregate Sum() from returning None >>> # Prevent an aggregate Sum() from returning None
>>> # The aggregate default argument uses Coalesce() under the hood.
>>> aggregated = Author.objects.aggregate( >>> aggregated = Author.objects.aggregate(
... combined_age=Coalesce(Sum('age'), V(0)), ... combined_age=Sum('age'),
... combined_age_default=Sum('age')) ... combined_age_default=Sum('age', default=0),
... combined_age_coalesce=Coalesce(Sum('age'), 0),
... )
>>> print(aggregated['combined_age']) >>> print(aggregated['combined_age'])
0
>>> print(aggregated['combined_age_default'])
None None
>>> print(aggregated['combined_age_default'])
0
>>> print(aggregated['combined_age_coalesce'])
0
.. warning:: .. warning::

View File

@ -393,7 +393,7 @@ some complex computations::
The ``Aggregate`` API is as follows: The ``Aggregate`` API is as follows:
.. class:: Aggregate(*expressions, output_field=None, distinct=False, filter=None, **extra) .. class:: Aggregate(*expressions, output_field=None, distinct=False, filter=None, default=None, **extra)
.. attribute:: template .. attribute:: template
@ -452,6 +452,11 @@ The ``filter`` argument takes a :class:`Q object <django.db.models.Q>` that's
used to filter the rows that are aggregated. See :ref:`conditional-aggregation` used to filter the rows that are aggregated. See :ref:`conditional-aggregation`
and :ref:`filtering-on-annotations` for example usage. and :ref:`filtering-on-annotations` for example usage.
The ``default`` argument takes a value that will be passed along with the
aggregate to :class:`~django.db.models.functions.Coalesce`. This is useful for
specifying a value to be returned other than ``None`` when the queryset (or
grouping) contains no entries.
The ``**extra`` kwargs are ``key=value`` pairs that can be interpolated The ``**extra`` kwargs are ``key=value`` pairs that can be interpolated
into the ``template`` attribute. into the ``template`` attribute.
@ -459,6 +464,10 @@ into the ``template`` attribute.
Support for transforms of the field was added. Support for transforms of the field was added.
.. versionchanged:: 4.0
The ``default`` argument was added.
Creating your own Aggregate Functions Creating your own Aggregate Functions
------------------------------------- -------------------------------------

View File

@ -3540,8 +3540,10 @@ documentation to learn how to create your aggregates.
Aggregation functions return ``None`` when used with an empty Aggregation functions return ``None`` when used with an empty
``QuerySet``. For example, the ``Sum`` aggregation function returns ``None`` ``QuerySet``. For example, the ``Sum`` aggregation function returns ``None``
instead of ``0`` if the ``QuerySet`` contains no entries. An exception is instead of ``0`` if the ``QuerySet`` contains no entries. To return another
``Count``, which does return ``0`` if the ``QuerySet`` is empty. value instead, pass a value to the ``default`` argument. An exception is
``Count``, which does return ``0`` if the ``QuerySet`` is empty. ``Count``
does not support the ``default`` argument.
All aggregates have the following parameters in common: All aggregates have the following parameters in common:
@ -3578,6 +3580,16 @@ rows that are aggregated.
See :ref:`conditional-aggregation` and :ref:`filtering-on-annotations` for See :ref:`conditional-aggregation` and :ref:`filtering-on-annotations` for
example usage. example usage.
.. _aggregate-default:
``default``
~~~~~~~~~~~
.. versionadded:: 4.0
An optional argument that allows specifying a value to use as a default value
when the queryset (or grouping) contains no entries.
``**extra`` ``**extra``
~~~~~~~~~~~ ~~~~~~~~~~~
@ -3587,7 +3599,7 @@ by the aggregate.
``Avg`` ``Avg``
~~~~~~~ ~~~~~~~
.. class:: Avg(expression, output_field=None, distinct=False, filter=None, **extra) .. class:: Avg(expression, output_field=None, distinct=False, filter=None, default=None, **extra)
Returns the mean value of the given expression, which must be numeric Returns the mean value of the given expression, which must be numeric
unless you specify a different ``output_field``. unless you specify a different ``output_field``.
@ -3623,10 +3635,14 @@ by the aggregate.
This is the SQL equivalent of ``COUNT(DISTINCT <field>)``. The default This is the SQL equivalent of ``COUNT(DISTINCT <field>)``. The default
value is ``False``. value is ``False``.
.. note::
The ``default`` argument is not supported.
``Max`` ``Max``
~~~~~~~ ~~~~~~~
.. class:: Max(expression, output_field=None, filter=None, **extra) .. class:: Max(expression, output_field=None, filter=None, default=None, **extra)
Returns the maximum value of the given expression. Returns the maximum value of the given expression.
@ -3636,7 +3652,7 @@ by the aggregate.
``Min`` ``Min``
~~~~~~~ ~~~~~~~
.. class:: Min(expression, output_field=None, filter=None, **extra) .. class:: Min(expression, output_field=None, filter=None, default=None, **extra)
Returns the minimum value of the given expression. Returns the minimum value of the given expression.
@ -3646,7 +3662,7 @@ by the aggregate.
``StdDev`` ``StdDev``
~~~~~~~~~~ ~~~~~~~~~~
.. class:: StdDev(expression, output_field=None, sample=False, filter=None, **extra) .. class:: StdDev(expression, output_field=None, sample=False, filter=None, default=None, **extra)
Returns the standard deviation of the data in the provided expression. Returns the standard deviation of the data in the provided expression.
@ -3664,7 +3680,7 @@ by the aggregate.
``Sum`` ``Sum``
~~~~~~~ ~~~~~~~
.. class:: Sum(expression, output_field=None, distinct=False, filter=None, **extra) .. class:: Sum(expression, output_field=None, distinct=False, filter=None, default=None, **extra)
Computes the sum of all values of the given expression. Computes the sum of all values of the given expression.
@ -3682,7 +3698,7 @@ by the aggregate.
``Variance`` ``Variance``
~~~~~~~~~~~~ ~~~~~~~~~~~~
.. class:: Variance(expression, output_field=None, sample=False, filter=None, **extra) .. class:: Variance(expression, output_field=None, sample=False, filter=None, default=None, **extra)
Returns the variance of the data in the provided expression. Returns the variance of the data in the provided expression.

View File

@ -288,6 +288,10 @@ Models
* :class:`~django.db.models.Lookup` expressions may now be used in ``QuerySet`` * :class:`~django.db.models.Lookup` expressions may now be used in ``QuerySet``
annotations, aggregations, and directly in filters. annotations, aggregations, and directly in filters.
* The new :ref:`default <aggregate-default>` argument for built-in aggregates
allows specifying a value to be returned when the queryset (or grouping)
contains no entries, rather than ``None``.
Requests and Responses Requests and Responses
~~~~~~~~~~~~~~~~~~~~~~ ~~~~~~~~~~~~~~~~~~~~~~

View File

@ -1,15 +1,19 @@
import datetime import datetime
import math
import re import re
from decimal import Decimal from decimal import Decimal
from django.core.exceptions import FieldError from django.core.exceptions import FieldError
from django.db import connection from django.db import connection
from django.db.models import ( from django.db.models import (
Avg, Case, Count, DecimalField, DurationField, Exists, F, FloatField, Avg, Case, Count, DateField, DateTimeField, DecimalField, DurationField,
IntegerField, Max, Min, OuterRef, Subquery, Sum, Value, When, Exists, F, FloatField, IntegerField, Max, Min, OuterRef, Q, StdDev,
Subquery, Sum, TimeField, Value, Variance, When,
) )
from django.db.models.expressions import Func, RawSQL from django.db.models.expressions import Func, RawSQL
from django.db.models.functions import Coalesce, Greatest from django.db.models.functions import (
Cast, Coalesce, Greatest, Now, Pi, TruncDate, TruncHour,
)
from django.test import TestCase from django.test import TestCase
from django.test.testcases import skipUnlessDBFeature from django.test.testcases import skipUnlessDBFeature
from django.test.utils import Approximate, CaptureQueriesContext from django.test.utils import Approximate, CaptureQueriesContext
@ -18,6 +22,20 @@ from django.utils import timezone
from .models import Author, Book, Publisher, Store from .models import Author, Book, Publisher, Store
class NowUTC(Now):
template = 'CURRENT_TIMESTAMP'
output_field = DateTimeField()
def as_mysql(self, compiler, connection, **extra_context):
return self.as_sql(compiler, connection, template='UTC_TIMESTAMP', **extra_context)
def as_oracle(self, compiler, connection, **extra_context):
return self.as_sql(compiler, connection, template="CURRENT_TIMESTAMP AT TIME ZONE 'UTC'", **extra_context)
def as_postgresql(self, compiler, connection, **extra_context):
return self.as_sql(compiler, connection, template="STATEMENT_TIMESTAMP() AT TIME ZONE 'UTC'", **extra_context)
class AggregateTestCase(TestCase): class AggregateTestCase(TestCase):
@classmethod @classmethod
@ -1402,3 +1420,190 @@ class AggregateTestCase(TestCase):
)['latest_opening'], )['latest_opening'],
datetime.datetime, datetime.datetime,
) )
def test_aggregation_default_unsupported_by_count(self):
msg = 'Count does not allow default.'
with self.assertRaisesMessage(TypeError, msg):
Count('age', default=0)
def test_aggregation_default_unset(self):
for Aggregate in [Avg, Max, Min, StdDev, Sum, Variance]:
with self.subTest(Aggregate):
result = Author.objects.filter(age__gt=100).aggregate(
value=Aggregate('age'),
)
self.assertIsNone(result['value'])
def test_aggregation_default_zero(self):
for Aggregate in [Avg, Max, Min, StdDev, Sum, Variance]:
with self.subTest(Aggregate):
result = Author.objects.filter(age__gt=100).aggregate(
value=Aggregate('age', default=0),
)
self.assertEqual(result['value'], 0)
def test_aggregation_default_integer(self):
for Aggregate in [Avg, Max, Min, StdDev, Sum, Variance]:
with self.subTest(Aggregate):
result = Author.objects.filter(age__gt=100).aggregate(
value=Aggregate('age', default=21),
)
self.assertEqual(result['value'], 21)
def test_aggregation_default_expression(self):
for Aggregate in [Avg, Max, Min, StdDev, Sum, Variance]:
with self.subTest(Aggregate):
result = Author.objects.filter(age__gt=100).aggregate(
value=Aggregate('age', default=Value(5) * Value(7)),
)
self.assertEqual(result['value'], 35)
def test_aggregation_default_group_by(self):
qs = Publisher.objects.values('name').annotate(
books=Count('book'),
pages=Sum('book__pages', default=0),
).filter(books=0)
self.assertSequenceEqual(
qs,
[{'name': "Jonno's House of Books", 'books': 0, 'pages': 0}],
)
def test_aggregation_default_compound_expression(self):
# Scale rating to a percentage; default to 50% if no books published.
formula = Avg('book__rating', default=2.5) * 20.0
queryset = Publisher.objects.annotate(rating=formula).order_by('name')
self.assertSequenceEqual(queryset.values('name', 'rating'), [
{'name': 'Apress', 'rating': 85.0},
{'name': "Jonno's House of Books", 'rating': 50.0},
{'name': 'Morgan Kaufmann', 'rating': 100.0},
{'name': 'Prentice Hall', 'rating': 80.0},
{'name': 'Sams', 'rating': 60.0},
])
def test_aggregation_default_using_time_from_python(self):
expr = Min(
'store__friday_night_closing',
filter=~Q(store__name='Amazon.com'),
default=datetime.time(17),
)
if connection.vendor == 'mysql':
# Workaround for #30224 for MySQL 8.0+ & MariaDB.
expr.default = Cast(expr.default, TimeField())
queryset = Book.objects.annotate(oldest_store_opening=expr).order_by('isbn')
self.assertSequenceEqual(queryset.values('isbn', 'oldest_store_opening'), [
{'isbn': '013235613', 'oldest_store_opening': datetime.time(21, 30)},
{'isbn': '013790395', 'oldest_store_opening': datetime.time(23, 59, 59)},
{'isbn': '067232959', 'oldest_store_opening': datetime.time(17)},
{'isbn': '155860191', 'oldest_store_opening': datetime.time(21, 30)},
{'isbn': '159059725', 'oldest_store_opening': datetime.time(23, 59, 59)},
{'isbn': '159059996', 'oldest_store_opening': datetime.time(21, 30)},
])
def test_aggregation_default_using_time_from_database(self):
now = timezone.now().astimezone(timezone.utc)
expr = Min(
'store__friday_night_closing',
filter=~Q(store__name='Amazon.com'),
default=TruncHour(NowUTC(), output_field=TimeField()),
)
queryset = Book.objects.annotate(oldest_store_opening=expr).order_by('isbn')
self.assertSequenceEqual(queryset.values('isbn', 'oldest_store_opening'), [
{'isbn': '013235613', 'oldest_store_opening': datetime.time(21, 30)},
{'isbn': '013790395', 'oldest_store_opening': datetime.time(23, 59, 59)},
{'isbn': '067232959', 'oldest_store_opening': datetime.time(now.hour)},
{'isbn': '155860191', 'oldest_store_opening': datetime.time(21, 30)},
{'isbn': '159059725', 'oldest_store_opening': datetime.time(23, 59, 59)},
{'isbn': '159059996', 'oldest_store_opening': datetime.time(21, 30)},
])
def test_aggregation_default_using_date_from_python(self):
expr = Min('book__pubdate', default=datetime.date(1970, 1, 1))
if connection.vendor == 'mysql':
# Workaround for #30224 for MySQL 5.7+ & MariaDB.
expr.default = Cast(expr.default, DateField())
queryset = Publisher.objects.annotate(earliest_pubdate=expr).order_by('name')
self.assertSequenceEqual(queryset.values('name', 'earliest_pubdate'), [
{'name': 'Apress', 'earliest_pubdate': datetime.date(2007, 12, 6)},
{'name': "Jonno's House of Books", 'earliest_pubdate': datetime.date(1970, 1, 1)},
{'name': 'Morgan Kaufmann', 'earliest_pubdate': datetime.date(1991, 10, 15)},
{'name': 'Prentice Hall', 'earliest_pubdate': datetime.date(1995, 1, 15)},
{'name': 'Sams', 'earliest_pubdate': datetime.date(2008, 3, 3)},
])
def test_aggregation_default_using_date_from_database(self):
now = timezone.now().astimezone(timezone.utc)
expr = Min('book__pubdate', default=TruncDate(NowUTC()))
queryset = Publisher.objects.annotate(earliest_pubdate=expr).order_by('name')
self.assertSequenceEqual(queryset.values('name', 'earliest_pubdate'), [
{'name': 'Apress', 'earliest_pubdate': datetime.date(2007, 12, 6)},
{'name': "Jonno's House of Books", 'earliest_pubdate': now.date()},
{'name': 'Morgan Kaufmann', 'earliest_pubdate': datetime.date(1991, 10, 15)},
{'name': 'Prentice Hall', 'earliest_pubdate': datetime.date(1995, 1, 15)},
{'name': 'Sams', 'earliest_pubdate': datetime.date(2008, 3, 3)},
])
def test_aggregation_default_using_datetime_from_python(self):
expr = Min(
'store__original_opening',
filter=~Q(store__name='Amazon.com'),
default=datetime.datetime(1970, 1, 1),
)
if connection.vendor == 'mysql':
# Workaround for #30224 for MySQL 8.0+ & MariaDB.
expr.default = Cast(expr.default, DateTimeField())
queryset = Book.objects.annotate(oldest_store_opening=expr).order_by('isbn')
self.assertSequenceEqual(queryset.values('isbn', 'oldest_store_opening'), [
{'isbn': '013235613', 'oldest_store_opening': datetime.datetime(1945, 4, 25, 16, 24, 14)},
{'isbn': '013790395', 'oldest_store_opening': datetime.datetime(2001, 3, 15, 11, 23, 37)},
{'isbn': '067232959', 'oldest_store_opening': datetime.datetime(1970, 1, 1)},
{'isbn': '155860191', 'oldest_store_opening': datetime.datetime(1945, 4, 25, 16, 24, 14)},
{'isbn': '159059725', 'oldest_store_opening': datetime.datetime(2001, 3, 15, 11, 23, 37)},
{'isbn': '159059996', 'oldest_store_opening': datetime.datetime(1945, 4, 25, 16, 24, 14)},
])
def test_aggregation_default_using_datetime_from_database(self):
now = timezone.now().astimezone(timezone.utc)
expr = Min(
'store__original_opening',
filter=~Q(store__name='Amazon.com'),
default=TruncHour(NowUTC(), output_field=DateTimeField()),
)
queryset = Book.objects.annotate(oldest_store_opening=expr).order_by('isbn')
self.assertSequenceEqual(queryset.values('isbn', 'oldest_store_opening'), [
{'isbn': '013235613', 'oldest_store_opening': datetime.datetime(1945, 4, 25, 16, 24, 14)},
{'isbn': '013790395', 'oldest_store_opening': datetime.datetime(2001, 3, 15, 11, 23, 37)},
{'isbn': '067232959', 'oldest_store_opening': now.replace(minute=0, second=0, microsecond=0, tzinfo=None)},
{'isbn': '155860191', 'oldest_store_opening': datetime.datetime(1945, 4, 25, 16, 24, 14)},
{'isbn': '159059725', 'oldest_store_opening': datetime.datetime(2001, 3, 15, 11, 23, 37)},
{'isbn': '159059996', 'oldest_store_opening': datetime.datetime(1945, 4, 25, 16, 24, 14)},
])
def test_aggregation_default_using_duration_from_python(self):
result = Publisher.objects.filter(num_awards__gt=3).aggregate(
value=Sum('duration', default=datetime.timedelta(0)),
)
self.assertEqual(result['value'], datetime.timedelta(0))
def test_aggregation_default_using_duration_from_database(self):
result = Publisher.objects.filter(num_awards__gt=3).aggregate(
value=Sum('duration', default=Now() - Now()),
)
self.assertEqual(result['value'], datetime.timedelta(0))
def test_aggregation_default_using_decimal_from_python(self):
result = Book.objects.filter(rating__lt=3.0).aggregate(
value=Sum('price', default=Decimal('0.00')),
)
self.assertEqual(result['value'], Decimal('0.00'))
def test_aggregation_default_using_decimal_from_database(self):
result = Book.objects.filter(rating__lt=3.0).aggregate(
value=Sum('price', default=Pi()),
)
self.assertAlmostEqual(result['value'], Decimal.from_float(math.pi), places=6)
def test_aggregation_default_passed_another_aggregate(self):
result = Book.objects.aggregate(
value=Sum('price', filter=Q(rating__lt=3.0), default=Avg('pages') / 10.0),
)
self.assertAlmostEqual(result['value'], Decimal('61.72'), places=2)

View File

@ -72,6 +72,34 @@ class TestGeneralAggregate(PostgreSQLTestCase):
) )
self.assertEqual(values, {'aggregation': expected_result}) self.assertEqual(values, {'aggregation': expected_result})
def test_default_argument(self):
AggregateTestModel.objects.all().delete()
tests = [
(ArrayAgg('char_field', default=['<empty>']), ['<empty>']),
(ArrayAgg('integer_field', default=[0]), [0]),
(ArrayAgg('boolean_field', default=[False]), [False]),
(BitAnd('integer_field', default=0), 0),
(BitOr('integer_field', default=0), 0),
(BoolAnd('boolean_field', default=False), False),
(BoolOr('boolean_field', default=False), False),
(JSONBAgg('integer_field', default=Value('["<empty>"]')), ['<empty>']),
(StringAgg('char_field', delimiter=';', default=Value('<empty>')), '<empty>'),
]
for aggregation, expected_result in tests:
with self.subTest(aggregation=aggregation):
# Empty result with non-execution optimization.
with self.assertNumQueries(0):
values = AggregateTestModel.objects.none().aggregate(
aggregation=aggregation,
)
self.assertEqual(values, {'aggregation': expected_result})
# Empty result when query must be executed.
with self.assertNumQueries(1):
values = AggregateTestModel.objects.aggregate(
aggregation=aggregation,
)
self.assertEqual(values, {'aggregation': expected_result})
def test_array_agg_charfield(self): def test_array_agg_charfield(self):
values = AggregateTestModel.objects.aggregate(arrayagg=ArrayAgg('char_field')) values = AggregateTestModel.objects.aggregate(arrayagg=ArrayAgg('char_field'))
self.assertEqual(values, {'arrayagg': ['Foo1', 'Foo2', 'Foo4', 'Foo3']}) self.assertEqual(values, {'arrayagg': ['Foo1', 'Foo2', 'Foo4', 'Foo3']})
@ -515,6 +543,37 @@ class TestStatisticsAggregate(PostgreSQLTestCase):
) )
self.assertEqual(values, {'aggregation': expected_result}) self.assertEqual(values, {'aggregation': expected_result})
def test_default_argument(self):
StatTestModel.objects.all().delete()
tests = [
(Corr(y='int2', x='int1', default=0), 0),
(CovarPop(y='int2', x='int1', default=0), 0),
(CovarPop(y='int2', x='int1', sample=True, default=0), 0),
(RegrAvgX(y='int2', x='int1', default=0), 0),
(RegrAvgY(y='int2', x='int1', default=0), 0),
# RegrCount() doesn't support the default argument.
(RegrIntercept(y='int2', x='int1', default=0), 0),
(RegrR2(y='int2', x='int1', default=0), 0),
(RegrSlope(y='int2', x='int1', default=0), 0),
(RegrSXX(y='int2', x='int1', default=0), 0),
(RegrSXY(y='int2', x='int1', default=0), 0),
(RegrSYY(y='int2', x='int1', default=0), 0),
]
for aggregation, expected_result in tests:
with self.subTest(aggregation=aggregation):
# Empty result with non-execution optimization.
with self.assertNumQueries(0):
values = StatTestModel.objects.none().aggregate(
aggregation=aggregation,
)
self.assertEqual(values, {'aggregation': expected_result})
# Empty result when query must be executed.
with self.assertNumQueries(1):
values = StatTestModel.objects.aggregate(
aggregation=aggregation,
)
self.assertEqual(values, {'aggregation': expected_result})
def test_corr_general(self): def test_corr_general(self):
values = StatTestModel.objects.aggregate(corr=Corr(y='int2', x='int1')) values = StatTestModel.objects.aggregate(corr=Corr(y='int2', x='int1'))
self.assertEqual(values, {'corr': -1.0}) self.assertEqual(values, {'corr': -1.0})
@ -539,6 +598,11 @@ class TestStatisticsAggregate(PostgreSQLTestCase):
values = StatTestModel.objects.aggregate(regrcount=RegrCount(y='int2', x='int1')) values = StatTestModel.objects.aggregate(regrcount=RegrCount(y='int2', x='int1'))
self.assertEqual(values, {'regrcount': 3}) self.assertEqual(values, {'regrcount': 3})
def test_regr_count_default(self):
msg = 'RegrCount does not allow default.'
with self.assertRaisesMessage(TypeError, msg):
RegrCount(y='int2', x='int1', default=0)
def test_regr_intercept_general(self): def test_regr_intercept_general(self):
values = StatTestModel.objects.aggregate(regrintercept=RegrIntercept(y='int2', x='int1')) values = StatTestModel.objects.aggregate(regrintercept=RegrIntercept(y='int2', x='int1'))
self.assertEqual(values, {'regrintercept': 4}) self.assertEqual(values, {'regrintercept': 4})