Fixed #10929 -- Added default argument to aggregates.

Thanks to Simon Charette and Adam Johnson for the reviews.
This commit is contained in:
Nick Pope 2021-02-21 01:38:55 +00:00 committed by Mariusz Felisiak
parent 59942a66ce
commit 501a8db465
11 changed files with 393 additions and 64 deletions

View File

@ -18,7 +18,7 @@ class ArrayAgg(OrderableAggMixin, Aggregate):
return ArrayField(self.source_expressions[0].output_field)
def convert_value(self, value, expression, connection):
if not value:
if value is None and self.default is None:
return []
return value
@ -48,7 +48,7 @@ class JSONBAgg(OrderableAggMixin, Aggregate):
output_field = JSONField()
def convert_value(self, value, expression, connection):
if not value:
if value is None and self.default is None:
return '[]'
return value
@ -63,6 +63,6 @@ class StringAgg(OrderableAggMixin, Aggregate):
super().__init__(expression, delimiter_expr, **extra)
def convert_value(self, value, expression, connection):
if not value:
if value is None and self.default is None:
return ''
return value

View File

@ -9,10 +9,10 @@ __all__ = [
class StatAggregate(Aggregate):
output_field = FloatField()
def __init__(self, y, x, output_field=None, filter=None):
def __init__(self, y, x, output_field=None, filter=None, default=None):
if not x or not y:
raise ValueError('Both y and x must be provided.')
super().__init__(y, x, output_field=output_field, filter=filter)
super().__init__(y, x, output_field=output_field, filter=filter, default=default)
class Corr(StatAggregate):
@ -20,9 +20,9 @@ class Corr(StatAggregate):
class CovarPop(StatAggregate):
def __init__(self, y, x, sample=False, filter=None):
def __init__(self, y, x, sample=False, filter=None, default=None):
self.function = 'COVAR_SAMP' if sample else 'COVAR_POP'
super().__init__(y, x, filter=filter)
super().__init__(y, x, filter=filter, default=default)
class RegrAvgX(StatAggregate):

View File

@ -88,6 +88,17 @@ class DatabaseFeatures(BaseDatabaseFeatures):
'annotations.tests.NonAggregateAnnotationTestCase.test_annotation_aggregate_with_m2o',
},
})
if not self.connection.mysql_is_mariadb and self.connection.mysql_version < (8,):
skips.update({
'Casting to datetime/time is not supported by MySQL < 8.0. (#30224)': {
'aggregation.tests.AggregateTestCase.test_aggregation_default_using_time_from_python',
'aggregation.tests.AggregateTestCase.test_aggregation_default_using_datetime_from_python',
},
'MySQL < 8.0 returns string type instead of datetime/time. (#30224)': {
'aggregation.tests.AggregateTestCase.test_aggregation_default_using_time_from_database',
'aggregation.tests.AggregateTestCase.test_aggregation_default_using_datetime_from_database',
},
})
if (
self.connection.mysql_is_mariadb and
(10, 4, 3) < self.connection.mysql_version < (10, 5, 2)

View File

@ -4,6 +4,7 @@ Classes to represent the definitions of aggregate functions.
from django.core.exceptions import FieldError
from django.db.models.expressions import Case, Func, Star, When
from django.db.models.fields import IntegerField
from django.db.models.functions.comparison import Coalesce
from django.db.models.functions.mixins import (
FixDurationInputMixin, NumericOutputFieldMixin,
)
@ -22,11 +23,14 @@ class Aggregate(Func):
allow_distinct = False
empty_aggregate_value = None
def __init__(self, *expressions, distinct=False, filter=None, **extra):
def __init__(self, *expressions, distinct=False, filter=None, default=None, **extra):
if distinct and not self.allow_distinct:
raise TypeError("%s does not allow distinct." % self.__class__.__name__)
if default is not None and self.empty_aggregate_value is not None:
raise TypeError(f'{self.__class__.__name__} does not allow default.')
self.distinct = distinct
self.filter = filter
self.default = default
super().__init__(*expressions, **extra)
def get_source_fields(self):
@ -56,7 +60,12 @@ class Aggregate(Func):
before_resolved = self.get_source_expressions()[index]
name = before_resolved.name if hasattr(before_resolved, 'name') else repr(before_resolved)
raise FieldError("Cannot compute %s('%s'): '%s' is an aggregate" % (c.name, name, name))
return c
if (default := c.default) is None:
return c
if hasattr(default, 'resolve_expression'):
default = default.resolve_expression(query, allow_joins, reuse, summarize)
c.default = None # Reset the default argument before wrapping.
return Coalesce(c, default, output_field=c._output_field_or_none)
@property
def default_alias(self):

View File

@ -19,8 +19,8 @@ module. They are described in more detail in the `PostgreSQL docs
.. admonition:: Common aggregate options
All aggregates have the :ref:`filter <aggregate-filter>` keyword
argument.
All aggregates have the :ref:`filter <aggregate-filter>` keyword argument
and most also have the :ref:`default <aggregate-default>` keyword argument.
General-purpose aggregation functions
=====================================
@ -28,9 +28,10 @@ General-purpose aggregation functions
``ArrayAgg``
------------
.. class:: ArrayAgg(expression, distinct=False, filter=None, ordering=(), **extra)
.. class:: ArrayAgg(expression, distinct=False, filter=None, default=None, ordering=(), **extra)
Returns a list of values, including nulls, concatenated into an array.
Returns a list of values, including nulls, concatenated into an array, or
``default`` if there are no values.
.. attribute:: distinct
@ -54,26 +55,26 @@ General-purpose aggregation functions
``BitAnd``
----------
.. class:: BitAnd(expression, filter=None, **extra)
.. class:: BitAnd(expression, filter=None, default=None, **extra)
Returns an ``int`` of the bitwise ``AND`` of all non-null input values, or
``None`` if all values are null.
``default`` if all values are null.
``BitOr``
---------
.. class:: BitOr(expression, filter=None, **extra)
.. class:: BitOr(expression, filter=None, default=None, **extra)
Returns an ``int`` of the bitwise ``OR`` of all non-null input values, or
``None`` if all values are null.
``default`` if all values are null.
``BoolAnd``
-----------
.. class:: BoolAnd(expression, filter=None, **extra)
.. class:: BoolAnd(expression, filter=None, default=None, **extra)
Returns ``True``, if all input values are true, ``None`` if all values are
null or if there are no values, otherwise ``False`` .
Returns ``True``, if all input values are true, ``default`` if all values
are null or if there are no values, otherwise ``False``.
Usage example::
@ -92,9 +93,9 @@ General-purpose aggregation functions
``BoolOr``
----------
.. class:: BoolOr(expression, filter=None, **extra)
.. class:: BoolOr(expression, filter=None, default=None, **extra)
Returns ``True`` if at least one input value is true, ``None`` if all
Returns ``True`` if at least one input value is true, ``default`` if all
values are null or if there are no values, otherwise ``False``.
Usage example::
@ -114,9 +115,10 @@ General-purpose aggregation functions
``JSONBAgg``
------------
.. class:: JSONBAgg(expressions, distinct=False, filter=None, ordering=(), **extra)
.. class:: JSONBAgg(expressions, distinct=False, filter=None, default=None, ordering=(), **extra)
Returns the input values as a ``JSON`` array.
Returns the input values as a ``JSON`` array, or ``default`` if there are
no values.
.. attribute:: distinct
@ -139,10 +141,10 @@ General-purpose aggregation functions
``StringAgg``
-------------
.. class:: StringAgg(expression, delimiter, distinct=False, filter=None, ordering=())
.. class:: StringAgg(expression, delimiter, distinct=False, filter=None, default=None, ordering=())
Returns the input values concatenated into a string, separated by
the ``delimiter`` string.
the ``delimiter`` string, or ``default`` if there are no values.
.. attribute:: delimiter
@ -174,17 +176,17 @@ field or an expression returning a numeric data. Both are required.
``Corr``
--------
.. class:: Corr(y, x, filter=None)
.. class:: Corr(y, x, filter=None, default=None)
Returns the correlation coefficient as a ``float``, or ``None`` if there
Returns the correlation coefficient as a ``float``, or ``default`` if there
aren't any matching rows.
``CovarPop``
------------
.. class:: CovarPop(y, x, sample=False, filter=None)
.. class:: CovarPop(y, x, sample=False, filter=None, default=None)
Returns the population covariance as a ``float``, or ``None`` if there
Returns the population covariance as a ``float``, or ``default`` if there
aren't any matching rows.
Has one optional argument:
@ -198,18 +200,18 @@ field or an expression returning a numeric data. Both are required.
``RegrAvgX``
------------
.. class:: RegrAvgX(y, x, filter=None)
.. class:: RegrAvgX(y, x, filter=None, default=None)
Returns the average of the independent variable (``sum(x)/N``) as a
``float``, or ``None`` if there aren't any matching rows.
``float``, or ``default`` if there aren't any matching rows.
``RegrAvgY``
------------
.. class:: RegrAvgY(y, x, filter=None)
.. class:: RegrAvgY(y, x, filter=None, default=None)
Returns the average of the dependent variable (``sum(y)/N``) as a
``float``, or ``None`` if there aren't any matching rows.
``float``, or ``default`` if there aren't any matching rows.
``RegrCount``
-------------
@ -219,56 +221,60 @@ field or an expression returning a numeric data. Both are required.
Returns an ``int`` of the number of input rows in which both expressions
are not null.
.. note::
The ``default`` argument is not supported.
``RegrIntercept``
-----------------
.. class:: RegrIntercept(y, x, filter=None)
.. class:: RegrIntercept(y, x, filter=None, default=None)
Returns the y-intercept of the least-squares-fit linear equation determined
by the ``(x, y)`` pairs as a ``float``, or ``None`` if there aren't any
by the ``(x, y)`` pairs as a ``float``, or ``default`` if there aren't any
matching rows.
``RegrR2``
----------
.. class:: RegrR2(y, x, filter=None)
.. class:: RegrR2(y, x, filter=None, default=None)
Returns the square of the correlation coefficient as a ``float``, or
``None`` if there aren't any matching rows.
``default`` if there aren't any matching rows.
``RegrSlope``
-------------
.. class:: RegrSlope(y, x, filter=None)
.. class:: RegrSlope(y, x, filter=None, default=None)
Returns the slope of the least-squares-fit linear equation determined
by the ``(x, y)`` pairs as a ``float``, or ``None`` if there aren't any
by the ``(x, y)`` pairs as a ``float``, or ``default`` if there aren't any
matching rows.
``RegrSXX``
-----------
.. class:: RegrSXX(y, x, filter=None)
.. class:: RegrSXX(y, x, filter=None, default=None)
Returns ``sum(x^2) - sum(x)^2/N`` ("sum of squares" of the independent
variable) as a ``float``, or ``None`` if there aren't any matching rows.
variable) as a ``float``, or ``default`` if there aren't any matching rows.
``RegrSXY``
-----------
.. class:: RegrSXY(y, x, filter=None)
.. class:: RegrSXY(y, x, filter=None, default=None)
Returns ``sum(x*y) - sum(x) * sum(y)/N`` ("sum of products" of independent
times dependent variable) as a ``float``, or ``None`` if there aren't any
matching rows.
times dependent variable) as a ``float``, or ``default`` if there aren't
any matching rows.
``RegrSYY``
-----------
.. class:: RegrSYY(y, x, filter=None)
.. class:: RegrSYY(y, x, filter=None, default=None)
Returns ``sum(y^2) - sum(y)^2/N`` ("sum of squares" of the dependent
variable) as a ``float``, or ``None`` if there aren't any matching rows.
variable) as a ``float``, or ``default`` if there aren't any matching rows.
Usage examples
==============

View File

@ -59,7 +59,7 @@ will result in a database error.
Usage examples::
>>> # Get a screen name from least to most public
>>> from django.db.models import Sum, Value as V
>>> from django.db.models import Sum
>>> from django.db.models.functions import Coalesce
>>> Author.objects.create(name='Margaret Smith', goes_by='Maggie')
>>> author = Author.objects.annotate(
@ -68,13 +68,18 @@ Usage examples::
Maggie
>>> # Prevent an aggregate Sum() from returning None
>>> # The aggregate default argument uses Coalesce() under the hood.
>>> aggregated = Author.objects.aggregate(
... combined_age=Coalesce(Sum('age'), V(0)),
... combined_age_default=Sum('age'))
... combined_age=Sum('age'),
... combined_age_default=Sum('age', default=0),
... combined_age_coalesce=Coalesce(Sum('age'), 0),
... )
>>> print(aggregated['combined_age'])
0
>>> print(aggregated['combined_age_default'])
None
>>> print(aggregated['combined_age_default'])
0
>>> print(aggregated['combined_age_coalesce'])
0
.. warning::

View File

@ -393,7 +393,7 @@ some complex computations::
The ``Aggregate`` API is as follows:
.. class:: Aggregate(*expressions, output_field=None, distinct=False, filter=None, **extra)
.. class:: Aggregate(*expressions, output_field=None, distinct=False, filter=None, default=None, **extra)
.. attribute:: template
@ -452,6 +452,11 @@ The ``filter`` argument takes a :class:`Q object <django.db.models.Q>` that's
used to filter the rows that are aggregated. See :ref:`conditional-aggregation`
and :ref:`filtering-on-annotations` for example usage.
The ``default`` argument takes a value that will be passed along with the
aggregate to :class:`~django.db.models.functions.Coalesce`. This is useful for
specifying a value to be returned other than ``None`` when the queryset (or
grouping) contains no entries.
The ``**extra`` kwargs are ``key=value`` pairs that can be interpolated
into the ``template`` attribute.
@ -459,6 +464,10 @@ into the ``template`` attribute.
Support for transforms of the field was added.
.. versionchanged:: 4.0
The ``default`` argument was added.
Creating your own Aggregate Functions
-------------------------------------

View File

@ -3540,8 +3540,10 @@ documentation to learn how to create your aggregates.
Aggregation functions return ``None`` when used with an empty
``QuerySet``. For example, the ``Sum`` aggregation function returns ``None``
instead of ``0`` if the ``QuerySet`` contains no entries. An exception is
``Count``, which does return ``0`` if the ``QuerySet`` is empty.
instead of ``0`` if the ``QuerySet`` contains no entries. To return another
value instead, pass a value to the ``default`` argument. An exception is
``Count``, which does return ``0`` if the ``QuerySet`` is empty. ``Count``
does not support the ``default`` argument.
All aggregates have the following parameters in common:
@ -3578,6 +3580,16 @@ rows that are aggregated.
See :ref:`conditional-aggregation` and :ref:`filtering-on-annotations` for
example usage.
.. _aggregate-default:
``default``
~~~~~~~~~~~
.. versionadded:: 4.0
An optional argument that allows specifying a value to use as a default value
when the queryset (or grouping) contains no entries.
``**extra``
~~~~~~~~~~~
@ -3587,7 +3599,7 @@ by the aggregate.
``Avg``
~~~~~~~
.. class:: Avg(expression, output_field=None, distinct=False, filter=None, **extra)
.. class:: Avg(expression, output_field=None, distinct=False, filter=None, default=None, **extra)
Returns the mean value of the given expression, which must be numeric
unless you specify a different ``output_field``.
@ -3623,10 +3635,14 @@ by the aggregate.
This is the SQL equivalent of ``COUNT(DISTINCT <field>)``. The default
value is ``False``.
.. note::
The ``default`` argument is not supported.
``Max``
~~~~~~~
.. class:: Max(expression, output_field=None, filter=None, **extra)
.. class:: Max(expression, output_field=None, filter=None, default=None, **extra)
Returns the maximum value of the given expression.
@ -3636,7 +3652,7 @@ by the aggregate.
``Min``
~~~~~~~
.. class:: Min(expression, output_field=None, filter=None, **extra)
.. class:: Min(expression, output_field=None, filter=None, default=None, **extra)
Returns the minimum value of the given expression.
@ -3646,7 +3662,7 @@ by the aggregate.
``StdDev``
~~~~~~~~~~
.. class:: StdDev(expression, output_field=None, sample=False, filter=None, **extra)
.. class:: StdDev(expression, output_field=None, sample=False, filter=None, default=None, **extra)
Returns the standard deviation of the data in the provided expression.
@ -3664,7 +3680,7 @@ by the aggregate.
``Sum``
~~~~~~~
.. class:: Sum(expression, output_field=None, distinct=False, filter=None, **extra)
.. class:: Sum(expression, output_field=None, distinct=False, filter=None, default=None, **extra)
Computes the sum of all values of the given expression.
@ -3682,7 +3698,7 @@ by the aggregate.
``Variance``
~~~~~~~~~~~~
.. class:: Variance(expression, output_field=None, sample=False, filter=None, **extra)
.. class:: Variance(expression, output_field=None, sample=False, filter=None, default=None, **extra)
Returns the variance of the data in the provided expression.

View File

@ -288,6 +288,10 @@ Models
* :class:`~django.db.models.Lookup` expressions may now be used in ``QuerySet``
annotations, aggregations, and directly in filters.
* The new :ref:`default <aggregate-default>` argument for built-in aggregates
allows specifying a value to be returned when the queryset (or grouping)
contains no entries, rather than ``None``.
Requests and Responses
~~~~~~~~~~~~~~~~~~~~~~

View File

@ -1,15 +1,19 @@
import datetime
import math
import re
from decimal import Decimal
from django.core.exceptions import FieldError
from django.db import connection
from django.db.models import (
Avg, Case, Count, DecimalField, DurationField, Exists, F, FloatField,
IntegerField, Max, Min, OuterRef, Subquery, Sum, Value, When,
Avg, Case, Count, DateField, DateTimeField, DecimalField, DurationField,
Exists, F, FloatField, IntegerField, Max, Min, OuterRef, Q, StdDev,
Subquery, Sum, TimeField, Value, Variance, When,
)
from django.db.models.expressions import Func, RawSQL
from django.db.models.functions import Coalesce, Greatest
from django.db.models.functions import (
Cast, Coalesce, Greatest, Now, Pi, TruncDate, TruncHour,
)
from django.test import TestCase
from django.test.testcases import skipUnlessDBFeature
from django.test.utils import Approximate, CaptureQueriesContext
@ -18,6 +22,20 @@ from django.utils import timezone
from .models import Author, Book, Publisher, Store
class NowUTC(Now):
template = 'CURRENT_TIMESTAMP'
output_field = DateTimeField()
def as_mysql(self, compiler, connection, **extra_context):
return self.as_sql(compiler, connection, template='UTC_TIMESTAMP', **extra_context)
def as_oracle(self, compiler, connection, **extra_context):
return self.as_sql(compiler, connection, template="CURRENT_TIMESTAMP AT TIME ZONE 'UTC'", **extra_context)
def as_postgresql(self, compiler, connection, **extra_context):
return self.as_sql(compiler, connection, template="STATEMENT_TIMESTAMP() AT TIME ZONE 'UTC'", **extra_context)
class AggregateTestCase(TestCase):
@classmethod
@ -1402,3 +1420,190 @@ class AggregateTestCase(TestCase):
)['latest_opening'],
datetime.datetime,
)
def test_aggregation_default_unsupported_by_count(self):
msg = 'Count does not allow default.'
with self.assertRaisesMessage(TypeError, msg):
Count('age', default=0)
def test_aggregation_default_unset(self):
for Aggregate in [Avg, Max, Min, StdDev, Sum, Variance]:
with self.subTest(Aggregate):
result = Author.objects.filter(age__gt=100).aggregate(
value=Aggregate('age'),
)
self.assertIsNone(result['value'])
def test_aggregation_default_zero(self):
for Aggregate in [Avg, Max, Min, StdDev, Sum, Variance]:
with self.subTest(Aggregate):
result = Author.objects.filter(age__gt=100).aggregate(
value=Aggregate('age', default=0),
)
self.assertEqual(result['value'], 0)
def test_aggregation_default_integer(self):
for Aggregate in [Avg, Max, Min, StdDev, Sum, Variance]:
with self.subTest(Aggregate):
result = Author.objects.filter(age__gt=100).aggregate(
value=Aggregate('age', default=21),
)
self.assertEqual(result['value'], 21)
def test_aggregation_default_expression(self):
for Aggregate in [Avg, Max, Min, StdDev, Sum, Variance]:
with self.subTest(Aggregate):
result = Author.objects.filter(age__gt=100).aggregate(
value=Aggregate('age', default=Value(5) * Value(7)),
)
self.assertEqual(result['value'], 35)
def test_aggregation_default_group_by(self):
qs = Publisher.objects.values('name').annotate(
books=Count('book'),
pages=Sum('book__pages', default=0),
).filter(books=0)
self.assertSequenceEqual(
qs,
[{'name': "Jonno's House of Books", 'books': 0, 'pages': 0}],
)
def test_aggregation_default_compound_expression(self):
# Scale rating to a percentage; default to 50% if no books published.
formula = Avg('book__rating', default=2.5) * 20.0
queryset = Publisher.objects.annotate(rating=formula).order_by('name')
self.assertSequenceEqual(queryset.values('name', 'rating'), [
{'name': 'Apress', 'rating': 85.0},
{'name': "Jonno's House of Books", 'rating': 50.0},
{'name': 'Morgan Kaufmann', 'rating': 100.0},
{'name': 'Prentice Hall', 'rating': 80.0},
{'name': 'Sams', 'rating': 60.0},
])
def test_aggregation_default_using_time_from_python(self):
expr = Min(
'store__friday_night_closing',
filter=~Q(store__name='Amazon.com'),
default=datetime.time(17),
)
if connection.vendor == 'mysql':
# Workaround for #30224 for MySQL 8.0+ & MariaDB.
expr.default = Cast(expr.default, TimeField())
queryset = Book.objects.annotate(oldest_store_opening=expr).order_by('isbn')
self.assertSequenceEqual(queryset.values('isbn', 'oldest_store_opening'), [
{'isbn': '013235613', 'oldest_store_opening': datetime.time(21, 30)},
{'isbn': '013790395', 'oldest_store_opening': datetime.time(23, 59, 59)},
{'isbn': '067232959', 'oldest_store_opening': datetime.time(17)},
{'isbn': '155860191', 'oldest_store_opening': datetime.time(21, 30)},
{'isbn': '159059725', 'oldest_store_opening': datetime.time(23, 59, 59)},
{'isbn': '159059996', 'oldest_store_opening': datetime.time(21, 30)},
])
def test_aggregation_default_using_time_from_database(self):
now = timezone.now().astimezone(timezone.utc)
expr = Min(
'store__friday_night_closing',
filter=~Q(store__name='Amazon.com'),
default=TruncHour(NowUTC(), output_field=TimeField()),
)
queryset = Book.objects.annotate(oldest_store_opening=expr).order_by('isbn')
self.assertSequenceEqual(queryset.values('isbn', 'oldest_store_opening'), [
{'isbn': '013235613', 'oldest_store_opening': datetime.time(21, 30)},
{'isbn': '013790395', 'oldest_store_opening': datetime.time(23, 59, 59)},
{'isbn': '067232959', 'oldest_store_opening': datetime.time(now.hour)},
{'isbn': '155860191', 'oldest_store_opening': datetime.time(21, 30)},
{'isbn': '159059725', 'oldest_store_opening': datetime.time(23, 59, 59)},
{'isbn': '159059996', 'oldest_store_opening': datetime.time(21, 30)},
])
def test_aggregation_default_using_date_from_python(self):
expr = Min('book__pubdate', default=datetime.date(1970, 1, 1))
if connection.vendor == 'mysql':
# Workaround for #30224 for MySQL 5.7+ & MariaDB.
expr.default = Cast(expr.default, DateField())
queryset = Publisher.objects.annotate(earliest_pubdate=expr).order_by('name')
self.assertSequenceEqual(queryset.values('name', 'earliest_pubdate'), [
{'name': 'Apress', 'earliest_pubdate': datetime.date(2007, 12, 6)},
{'name': "Jonno's House of Books", 'earliest_pubdate': datetime.date(1970, 1, 1)},
{'name': 'Morgan Kaufmann', 'earliest_pubdate': datetime.date(1991, 10, 15)},
{'name': 'Prentice Hall', 'earliest_pubdate': datetime.date(1995, 1, 15)},
{'name': 'Sams', 'earliest_pubdate': datetime.date(2008, 3, 3)},
])
def test_aggregation_default_using_date_from_database(self):
now = timezone.now().astimezone(timezone.utc)
expr = Min('book__pubdate', default=TruncDate(NowUTC()))
queryset = Publisher.objects.annotate(earliest_pubdate=expr).order_by('name')
self.assertSequenceEqual(queryset.values('name', 'earliest_pubdate'), [
{'name': 'Apress', 'earliest_pubdate': datetime.date(2007, 12, 6)},
{'name': "Jonno's House of Books", 'earliest_pubdate': now.date()},
{'name': 'Morgan Kaufmann', 'earliest_pubdate': datetime.date(1991, 10, 15)},
{'name': 'Prentice Hall', 'earliest_pubdate': datetime.date(1995, 1, 15)},
{'name': 'Sams', 'earliest_pubdate': datetime.date(2008, 3, 3)},
])
def test_aggregation_default_using_datetime_from_python(self):
expr = Min(
'store__original_opening',
filter=~Q(store__name='Amazon.com'),
default=datetime.datetime(1970, 1, 1),
)
if connection.vendor == 'mysql':
# Workaround for #30224 for MySQL 8.0+ & MariaDB.
expr.default = Cast(expr.default, DateTimeField())
queryset = Book.objects.annotate(oldest_store_opening=expr).order_by('isbn')
self.assertSequenceEqual(queryset.values('isbn', 'oldest_store_opening'), [
{'isbn': '013235613', 'oldest_store_opening': datetime.datetime(1945, 4, 25, 16, 24, 14)},
{'isbn': '013790395', 'oldest_store_opening': datetime.datetime(2001, 3, 15, 11, 23, 37)},
{'isbn': '067232959', 'oldest_store_opening': datetime.datetime(1970, 1, 1)},
{'isbn': '155860191', 'oldest_store_opening': datetime.datetime(1945, 4, 25, 16, 24, 14)},
{'isbn': '159059725', 'oldest_store_opening': datetime.datetime(2001, 3, 15, 11, 23, 37)},
{'isbn': '159059996', 'oldest_store_opening': datetime.datetime(1945, 4, 25, 16, 24, 14)},
])
def test_aggregation_default_using_datetime_from_database(self):
now = timezone.now().astimezone(timezone.utc)
expr = Min(
'store__original_opening',
filter=~Q(store__name='Amazon.com'),
default=TruncHour(NowUTC(), output_field=DateTimeField()),
)
queryset = Book.objects.annotate(oldest_store_opening=expr).order_by('isbn')
self.assertSequenceEqual(queryset.values('isbn', 'oldest_store_opening'), [
{'isbn': '013235613', 'oldest_store_opening': datetime.datetime(1945, 4, 25, 16, 24, 14)},
{'isbn': '013790395', 'oldest_store_opening': datetime.datetime(2001, 3, 15, 11, 23, 37)},
{'isbn': '067232959', 'oldest_store_opening': now.replace(minute=0, second=0, microsecond=0, tzinfo=None)},
{'isbn': '155860191', 'oldest_store_opening': datetime.datetime(1945, 4, 25, 16, 24, 14)},
{'isbn': '159059725', 'oldest_store_opening': datetime.datetime(2001, 3, 15, 11, 23, 37)},
{'isbn': '159059996', 'oldest_store_opening': datetime.datetime(1945, 4, 25, 16, 24, 14)},
])
def test_aggregation_default_using_duration_from_python(self):
result = Publisher.objects.filter(num_awards__gt=3).aggregate(
value=Sum('duration', default=datetime.timedelta(0)),
)
self.assertEqual(result['value'], datetime.timedelta(0))
def test_aggregation_default_using_duration_from_database(self):
result = Publisher.objects.filter(num_awards__gt=3).aggregate(
value=Sum('duration', default=Now() - Now()),
)
self.assertEqual(result['value'], datetime.timedelta(0))
def test_aggregation_default_using_decimal_from_python(self):
result = Book.objects.filter(rating__lt=3.0).aggregate(
value=Sum('price', default=Decimal('0.00')),
)
self.assertEqual(result['value'], Decimal('0.00'))
def test_aggregation_default_using_decimal_from_database(self):
result = Book.objects.filter(rating__lt=3.0).aggregate(
value=Sum('price', default=Pi()),
)
self.assertAlmostEqual(result['value'], Decimal.from_float(math.pi), places=6)
def test_aggregation_default_passed_another_aggregate(self):
result = Book.objects.aggregate(
value=Sum('price', filter=Q(rating__lt=3.0), default=Avg('pages') / 10.0),
)
self.assertAlmostEqual(result['value'], Decimal('61.72'), places=2)

View File

@ -72,6 +72,34 @@ class TestGeneralAggregate(PostgreSQLTestCase):
)
self.assertEqual(values, {'aggregation': expected_result})
def test_default_argument(self):
AggregateTestModel.objects.all().delete()
tests = [
(ArrayAgg('char_field', default=['<empty>']), ['<empty>']),
(ArrayAgg('integer_field', default=[0]), [0]),
(ArrayAgg('boolean_field', default=[False]), [False]),
(BitAnd('integer_field', default=0), 0),
(BitOr('integer_field', default=0), 0),
(BoolAnd('boolean_field', default=False), False),
(BoolOr('boolean_field', default=False), False),
(JSONBAgg('integer_field', default=Value('["<empty>"]')), ['<empty>']),
(StringAgg('char_field', delimiter=';', default=Value('<empty>')), '<empty>'),
]
for aggregation, expected_result in tests:
with self.subTest(aggregation=aggregation):
# Empty result with non-execution optimization.
with self.assertNumQueries(0):
values = AggregateTestModel.objects.none().aggregate(
aggregation=aggregation,
)
self.assertEqual(values, {'aggregation': expected_result})
# Empty result when query must be executed.
with self.assertNumQueries(1):
values = AggregateTestModel.objects.aggregate(
aggregation=aggregation,
)
self.assertEqual(values, {'aggregation': expected_result})
def test_array_agg_charfield(self):
values = AggregateTestModel.objects.aggregate(arrayagg=ArrayAgg('char_field'))
self.assertEqual(values, {'arrayagg': ['Foo1', 'Foo2', 'Foo4', 'Foo3']})
@ -515,6 +543,37 @@ class TestStatisticsAggregate(PostgreSQLTestCase):
)
self.assertEqual(values, {'aggregation': expected_result})
def test_default_argument(self):
StatTestModel.objects.all().delete()
tests = [
(Corr(y='int2', x='int1', default=0), 0),
(CovarPop(y='int2', x='int1', default=0), 0),
(CovarPop(y='int2', x='int1', sample=True, default=0), 0),
(RegrAvgX(y='int2', x='int1', default=0), 0),
(RegrAvgY(y='int2', x='int1', default=0), 0),
# RegrCount() doesn't support the default argument.
(RegrIntercept(y='int2', x='int1', default=0), 0),
(RegrR2(y='int2', x='int1', default=0), 0),
(RegrSlope(y='int2', x='int1', default=0), 0),
(RegrSXX(y='int2', x='int1', default=0), 0),
(RegrSXY(y='int2', x='int1', default=0), 0),
(RegrSYY(y='int2', x='int1', default=0), 0),
]
for aggregation, expected_result in tests:
with self.subTest(aggregation=aggregation):
# Empty result with non-execution optimization.
with self.assertNumQueries(0):
values = StatTestModel.objects.none().aggregate(
aggregation=aggregation,
)
self.assertEqual(values, {'aggregation': expected_result})
# Empty result when query must be executed.
with self.assertNumQueries(1):
values = StatTestModel.objects.aggregate(
aggregation=aggregation,
)
self.assertEqual(values, {'aggregation': expected_result})
def test_corr_general(self):
values = StatTestModel.objects.aggregate(corr=Corr(y='int2', x='int1'))
self.assertEqual(values, {'corr': -1.0})
@ -539,6 +598,11 @@ class TestStatisticsAggregate(PostgreSQLTestCase):
values = StatTestModel.objects.aggregate(regrcount=RegrCount(y='int2', x='int1'))
self.assertEqual(values, {'regrcount': 3})
def test_regr_count_default(self):
msg = 'RegrCount does not allow default.'
with self.assertRaisesMessage(TypeError, msg):
RegrCount(y='int2', x='int1', default=0)
def test_regr_intercept_general(self):
values = StatTestModel.objects.aggregate(regrintercept=RegrIntercept(y='int2', x='int1'))
self.assertEqual(values, {'regrintercept': 4})