From 501a8db46595b2d5b99c1d3b1146a832f43cdf1c Mon Sep 17 00:00:00 2001 From: Nick Pope Date: Sun, 21 Feb 2021 01:38:55 +0000 Subject: [PATCH] Fixed #10929 -- Added default argument to aggregates. Thanks to Simon Charette and Adam Johnson for the reviews. --- django/contrib/postgres/aggregates/general.py | 6 +- .../contrib/postgres/aggregates/statistics.py | 8 +- django/db/backends/mysql/features.py | 11 + django/db/models/aggregates.py | 13 +- docs/ref/contrib/postgres/aggregates.txt | 82 +++---- docs/ref/models/database-functions.txt | 15 +- docs/ref/models/expressions.txt | 11 +- docs/ref/models/querysets.txt | 32 ++- docs/releases/4.0.txt | 4 + tests/aggregation/tests.py | 211 +++++++++++++++++- tests/postgres_tests/test_aggregates.py | 64 ++++++ 11 files changed, 393 insertions(+), 64 deletions(-) diff --git a/django/contrib/postgres/aggregates/general.py b/django/contrib/postgres/aggregates/general.py index 27cfe316e9..d36f71fddf 100644 --- a/django/contrib/postgres/aggregates/general.py +++ b/django/contrib/postgres/aggregates/general.py @@ -18,7 +18,7 @@ class ArrayAgg(OrderableAggMixin, Aggregate): return ArrayField(self.source_expressions[0].output_field) def convert_value(self, value, expression, connection): - if not value: + if value is None and self.default is None: return [] return value @@ -48,7 +48,7 @@ class JSONBAgg(OrderableAggMixin, Aggregate): output_field = JSONField() def convert_value(self, value, expression, connection): - if not value: + if value is None and self.default is None: return '[]' return value @@ -63,6 +63,6 @@ class StringAgg(OrderableAggMixin, Aggregate): super().__init__(expression, delimiter_expr, **extra) def convert_value(self, value, expression, connection): - if not value: + if value is None and self.default is None: return '' return value diff --git a/django/contrib/postgres/aggregates/statistics.py b/django/contrib/postgres/aggregates/statistics.py index f3e1450d17..c0aae93bd3 100644 --- a/django/contrib/postgres/aggregates/statistics.py +++ b/django/contrib/postgres/aggregates/statistics.py @@ -9,10 +9,10 @@ __all__ = [ class StatAggregate(Aggregate): output_field = FloatField() - def __init__(self, y, x, output_field=None, filter=None): + def __init__(self, y, x, output_field=None, filter=None, default=None): if not x or not y: raise ValueError('Both y and x must be provided.') - super().__init__(y, x, output_field=output_field, filter=filter) + super().__init__(y, x, output_field=output_field, filter=filter, default=default) class Corr(StatAggregate): @@ -20,9 +20,9 @@ class Corr(StatAggregate): class CovarPop(StatAggregate): - def __init__(self, y, x, sample=False, filter=None): + def __init__(self, y, x, sample=False, filter=None, default=None): self.function = 'COVAR_SAMP' if sample else 'COVAR_POP' - super().__init__(y, x, filter=filter) + super().__init__(y, x, filter=filter, default=default) class RegrAvgX(StatAggregate): diff --git a/django/db/backends/mysql/features.py b/django/db/backends/mysql/features.py index 6e628e80d1..21c063199f 100644 --- a/django/db/backends/mysql/features.py +++ b/django/db/backends/mysql/features.py @@ -88,6 +88,17 @@ class DatabaseFeatures(BaseDatabaseFeatures): 'annotations.tests.NonAggregateAnnotationTestCase.test_annotation_aggregate_with_m2o', }, }) + if not self.connection.mysql_is_mariadb and self.connection.mysql_version < (8,): + skips.update({ + 'Casting to datetime/time is not supported by MySQL < 8.0. (#30224)': { + 'aggregation.tests.AggregateTestCase.test_aggregation_default_using_time_from_python', + 'aggregation.tests.AggregateTestCase.test_aggregation_default_using_datetime_from_python', + }, + 'MySQL < 8.0 returns string type instead of datetime/time. (#30224)': { + 'aggregation.tests.AggregateTestCase.test_aggregation_default_using_time_from_database', + 'aggregation.tests.AggregateTestCase.test_aggregation_default_using_datetime_from_database', + }, + }) if ( self.connection.mysql_is_mariadb and (10, 4, 3) < self.connection.mysql_version < (10, 5, 2) diff --git a/django/db/models/aggregates.py b/django/db/models/aggregates.py index 8598ba9178..1ae4382784 100644 --- a/django/db/models/aggregates.py +++ b/django/db/models/aggregates.py @@ -4,6 +4,7 @@ Classes to represent the definitions of aggregate functions. from django.core.exceptions import FieldError from django.db.models.expressions import Case, Func, Star, When from django.db.models.fields import IntegerField +from django.db.models.functions.comparison import Coalesce from django.db.models.functions.mixins import ( FixDurationInputMixin, NumericOutputFieldMixin, ) @@ -22,11 +23,14 @@ class Aggregate(Func): allow_distinct = False empty_aggregate_value = None - def __init__(self, *expressions, distinct=False, filter=None, **extra): + def __init__(self, *expressions, distinct=False, filter=None, default=None, **extra): if distinct and not self.allow_distinct: raise TypeError("%s does not allow distinct." % self.__class__.__name__) + if default is not None and self.empty_aggregate_value is not None: + raise TypeError(f'{self.__class__.__name__} does not allow default.') self.distinct = distinct self.filter = filter + self.default = default super().__init__(*expressions, **extra) def get_source_fields(self): @@ -56,7 +60,12 @@ class Aggregate(Func): before_resolved = self.get_source_expressions()[index] name = before_resolved.name if hasattr(before_resolved, 'name') else repr(before_resolved) raise FieldError("Cannot compute %s('%s'): '%s' is an aggregate" % (c.name, name, name)) - return c + if (default := c.default) is None: + return c + if hasattr(default, 'resolve_expression'): + default = default.resolve_expression(query, allow_joins, reuse, summarize) + c.default = None # Reset the default argument before wrapping. + return Coalesce(c, default, output_field=c._output_field_or_none) @property def default_alias(self): diff --git a/docs/ref/contrib/postgres/aggregates.txt b/docs/ref/contrib/postgres/aggregates.txt index c309dcad14..61ec86fa44 100644 --- a/docs/ref/contrib/postgres/aggregates.txt +++ b/docs/ref/contrib/postgres/aggregates.txt @@ -19,8 +19,8 @@ module. They are described in more detail in the `PostgreSQL docs .. admonition:: Common aggregate options - All aggregates have the :ref:`filter ` keyword - argument. + All aggregates have the :ref:`filter ` keyword argument + and most also have the :ref:`default ` keyword argument. General-purpose aggregation functions ===================================== @@ -28,9 +28,10 @@ General-purpose aggregation functions ``ArrayAgg`` ------------ -.. class:: ArrayAgg(expression, distinct=False, filter=None, ordering=(), **extra) +.. class:: ArrayAgg(expression, distinct=False, filter=None, default=None, ordering=(), **extra) - Returns a list of values, including nulls, concatenated into an array. + Returns a list of values, including nulls, concatenated into an array, or + ``default`` if there are no values. .. attribute:: distinct @@ -54,26 +55,26 @@ General-purpose aggregation functions ``BitAnd`` ---------- -.. class:: BitAnd(expression, filter=None, **extra) +.. class:: BitAnd(expression, filter=None, default=None, **extra) Returns an ``int`` of the bitwise ``AND`` of all non-null input values, or - ``None`` if all values are null. + ``default`` if all values are null. ``BitOr`` --------- -.. class:: BitOr(expression, filter=None, **extra) +.. class:: BitOr(expression, filter=None, default=None, **extra) Returns an ``int`` of the bitwise ``OR`` of all non-null input values, or - ``None`` if all values are null. + ``default`` if all values are null. ``BoolAnd`` ----------- -.. class:: BoolAnd(expression, filter=None, **extra) +.. class:: BoolAnd(expression, filter=None, default=None, **extra) - Returns ``True``, if all input values are true, ``None`` if all values are - null or if there are no values, otherwise ``False`` . + Returns ``True``, if all input values are true, ``default`` if all values + are null or if there are no values, otherwise ``False``. Usage example:: @@ -92,9 +93,9 @@ General-purpose aggregation functions ``BoolOr`` ---------- -.. class:: BoolOr(expression, filter=None, **extra) +.. class:: BoolOr(expression, filter=None, default=None, **extra) - Returns ``True`` if at least one input value is true, ``None`` if all + Returns ``True`` if at least one input value is true, ``default`` if all values are null or if there are no values, otherwise ``False``. Usage example:: @@ -114,9 +115,10 @@ General-purpose aggregation functions ``JSONBAgg`` ------------ -.. class:: JSONBAgg(expressions, distinct=False, filter=None, ordering=(), **extra) +.. class:: JSONBAgg(expressions, distinct=False, filter=None, default=None, ordering=(), **extra) - Returns the input values as a ``JSON`` array. + Returns the input values as a ``JSON`` array, or ``default`` if there are + no values. .. attribute:: distinct @@ -139,10 +141,10 @@ General-purpose aggregation functions ``StringAgg`` ------------- -.. class:: StringAgg(expression, delimiter, distinct=False, filter=None, ordering=()) +.. class:: StringAgg(expression, delimiter, distinct=False, filter=None, default=None, ordering=()) Returns the input values concatenated into a string, separated by - the ``delimiter`` string. + the ``delimiter`` string, or ``default`` if there are no values. .. attribute:: delimiter @@ -174,17 +176,17 @@ field or an expression returning a numeric data. Both are required. ``Corr`` -------- -.. class:: Corr(y, x, filter=None) +.. class:: Corr(y, x, filter=None, default=None) - Returns the correlation coefficient as a ``float``, or ``None`` if there + Returns the correlation coefficient as a ``float``, or ``default`` if there aren't any matching rows. ``CovarPop`` ------------ -.. class:: CovarPop(y, x, sample=False, filter=None) +.. class:: CovarPop(y, x, sample=False, filter=None, default=None) - Returns the population covariance as a ``float``, or ``None`` if there + Returns the population covariance as a ``float``, or ``default`` if there aren't any matching rows. Has one optional argument: @@ -198,18 +200,18 @@ field or an expression returning a numeric data. Both are required. ``RegrAvgX`` ------------ -.. class:: RegrAvgX(y, x, filter=None) +.. class:: RegrAvgX(y, x, filter=None, default=None) Returns the average of the independent variable (``sum(x)/N``) as a - ``float``, or ``None`` if there aren't any matching rows. + ``float``, or ``default`` if there aren't any matching rows. ``RegrAvgY`` ------------ -.. class:: RegrAvgY(y, x, filter=None) +.. class:: RegrAvgY(y, x, filter=None, default=None) Returns the average of the dependent variable (``sum(y)/N``) as a - ``float``, or ``None`` if there aren't any matching rows. + ``float``, or ``default`` if there aren't any matching rows. ``RegrCount`` ------------- @@ -219,56 +221,60 @@ field or an expression returning a numeric data. Both are required. Returns an ``int`` of the number of input rows in which both expressions are not null. + .. note:: + + The ``default`` argument is not supported. + ``RegrIntercept`` ----------------- -.. class:: RegrIntercept(y, x, filter=None) +.. class:: RegrIntercept(y, x, filter=None, default=None) Returns the y-intercept of the least-squares-fit linear equation determined - by the ``(x, y)`` pairs as a ``float``, or ``None`` if there aren't any + by the ``(x, y)`` pairs as a ``float``, or ``default`` if there aren't any matching rows. ``RegrR2`` ---------- -.. class:: RegrR2(y, x, filter=None) +.. class:: RegrR2(y, x, filter=None, default=None) Returns the square of the correlation coefficient as a ``float``, or - ``None`` if there aren't any matching rows. + ``default`` if there aren't any matching rows. ``RegrSlope`` ------------- -.. class:: RegrSlope(y, x, filter=None) +.. class:: RegrSlope(y, x, filter=None, default=None) Returns the slope of the least-squares-fit linear equation determined - by the ``(x, y)`` pairs as a ``float``, or ``None`` if there aren't any + by the ``(x, y)`` pairs as a ``float``, or ``default`` if there aren't any matching rows. ``RegrSXX`` ----------- -.. class:: RegrSXX(y, x, filter=None) +.. class:: RegrSXX(y, x, filter=None, default=None) Returns ``sum(x^2) - sum(x)^2/N`` ("sum of squares" of the independent - variable) as a ``float``, or ``None`` if there aren't any matching rows. + variable) as a ``float``, or ``default`` if there aren't any matching rows. ``RegrSXY`` ----------- -.. class:: RegrSXY(y, x, filter=None) +.. class:: RegrSXY(y, x, filter=None, default=None) Returns ``sum(x*y) - sum(x) * sum(y)/N`` ("sum of products" of independent - times dependent variable) as a ``float``, or ``None`` if there aren't any - matching rows. + times dependent variable) as a ``float``, or ``default`` if there aren't + any matching rows. ``RegrSYY`` ----------- -.. class:: RegrSYY(y, x, filter=None) +.. class:: RegrSYY(y, x, filter=None, default=None) Returns ``sum(y^2) - sum(y)^2/N`` ("sum of squares" of the dependent - variable) as a ``float``, or ``None`` if there aren't any matching rows. + variable) as a ``float``, or ``default`` if there aren't any matching rows. Usage examples ============== diff --git a/docs/ref/models/database-functions.txt b/docs/ref/models/database-functions.txt index ac0c5ea4ec..18dfdae976 100644 --- a/docs/ref/models/database-functions.txt +++ b/docs/ref/models/database-functions.txt @@ -59,7 +59,7 @@ will result in a database error. Usage examples:: >>> # Get a screen name from least to most public - >>> from django.db.models import Sum, Value as V + >>> from django.db.models import Sum >>> from django.db.models.functions import Coalesce >>> Author.objects.create(name='Margaret Smith', goes_by='Maggie') >>> author = Author.objects.annotate( @@ -68,13 +68,18 @@ Usage examples:: Maggie >>> # Prevent an aggregate Sum() from returning None + >>> # The aggregate default argument uses Coalesce() under the hood. >>> aggregated = Author.objects.aggregate( - ... combined_age=Coalesce(Sum('age'), V(0)), - ... combined_age_default=Sum('age')) + ... combined_age=Sum('age'), + ... combined_age_default=Sum('age', default=0), + ... combined_age_coalesce=Coalesce(Sum('age'), 0), + ... ) >>> print(aggregated['combined_age']) - 0 - >>> print(aggregated['combined_age_default']) None + >>> print(aggregated['combined_age_default']) + 0 + >>> print(aggregated['combined_age_coalesce']) + 0 .. warning:: diff --git a/docs/ref/models/expressions.txt b/docs/ref/models/expressions.txt index 845e734dc1..482aad15d8 100644 --- a/docs/ref/models/expressions.txt +++ b/docs/ref/models/expressions.txt @@ -393,7 +393,7 @@ some complex computations:: The ``Aggregate`` API is as follows: -.. class:: Aggregate(*expressions, output_field=None, distinct=False, filter=None, **extra) +.. class:: Aggregate(*expressions, output_field=None, distinct=False, filter=None, default=None, **extra) .. attribute:: template @@ -452,6 +452,11 @@ The ``filter`` argument takes a :class:`Q object ` that's used to filter the rows that are aggregated. See :ref:`conditional-aggregation` and :ref:`filtering-on-annotations` for example usage. +The ``default`` argument takes a value that will be passed along with the +aggregate to :class:`~django.db.models.functions.Coalesce`. This is useful for +specifying a value to be returned other than ``None`` when the queryset (or +grouping) contains no entries. + The ``**extra`` kwargs are ``key=value`` pairs that can be interpolated into the ``template`` attribute. @@ -459,6 +464,10 @@ into the ``template`` attribute. Support for transforms of the field was added. +.. versionchanged:: 4.0 + + The ``default`` argument was added. + Creating your own Aggregate Functions ------------------------------------- diff --git a/docs/ref/models/querysets.txt b/docs/ref/models/querysets.txt index 408224aed1..efa28ee145 100644 --- a/docs/ref/models/querysets.txt +++ b/docs/ref/models/querysets.txt @@ -3540,8 +3540,10 @@ documentation to learn how to create your aggregates. Aggregation functions return ``None`` when used with an empty ``QuerySet``. For example, the ``Sum`` aggregation function returns ``None`` - instead of ``0`` if the ``QuerySet`` contains no entries. An exception is - ``Count``, which does return ``0`` if the ``QuerySet`` is empty. + instead of ``0`` if the ``QuerySet`` contains no entries. To return another + value instead, pass a value to the ``default`` argument. An exception is + ``Count``, which does return ``0`` if the ``QuerySet`` is empty. ``Count`` + does not support the ``default`` argument. All aggregates have the following parameters in common: @@ -3578,6 +3580,16 @@ rows that are aggregated. See :ref:`conditional-aggregation` and :ref:`filtering-on-annotations` for example usage. +.. _aggregate-default: + +``default`` +~~~~~~~~~~~ + +.. versionadded:: 4.0 + +An optional argument that allows specifying a value to use as a default value +when the queryset (or grouping) contains no entries. + ``**extra`` ~~~~~~~~~~~ @@ -3587,7 +3599,7 @@ by the aggregate. ``Avg`` ~~~~~~~ -.. class:: Avg(expression, output_field=None, distinct=False, filter=None, **extra) +.. class:: Avg(expression, output_field=None, distinct=False, filter=None, default=None, **extra) Returns the mean value of the given expression, which must be numeric unless you specify a different ``output_field``. @@ -3623,10 +3635,14 @@ by the aggregate. This is the SQL equivalent of ``COUNT(DISTINCT )``. The default value is ``False``. + .. note:: + + The ``default`` argument is not supported. + ``Max`` ~~~~~~~ -.. class:: Max(expression, output_field=None, filter=None, **extra) +.. class:: Max(expression, output_field=None, filter=None, default=None, **extra) Returns the maximum value of the given expression. @@ -3636,7 +3652,7 @@ by the aggregate. ``Min`` ~~~~~~~ -.. class:: Min(expression, output_field=None, filter=None, **extra) +.. class:: Min(expression, output_field=None, filter=None, default=None, **extra) Returns the minimum value of the given expression. @@ -3646,7 +3662,7 @@ by the aggregate. ``StdDev`` ~~~~~~~~~~ -.. class:: StdDev(expression, output_field=None, sample=False, filter=None, **extra) +.. class:: StdDev(expression, output_field=None, sample=False, filter=None, default=None, **extra) Returns the standard deviation of the data in the provided expression. @@ -3664,7 +3680,7 @@ by the aggregate. ``Sum`` ~~~~~~~ -.. class:: Sum(expression, output_field=None, distinct=False, filter=None, **extra) +.. class:: Sum(expression, output_field=None, distinct=False, filter=None, default=None, **extra) Computes the sum of all values of the given expression. @@ -3682,7 +3698,7 @@ by the aggregate. ``Variance`` ~~~~~~~~~~~~ -.. class:: Variance(expression, output_field=None, sample=False, filter=None, **extra) +.. class:: Variance(expression, output_field=None, sample=False, filter=None, default=None, **extra) Returns the variance of the data in the provided expression. diff --git a/docs/releases/4.0.txt b/docs/releases/4.0.txt index 9f63a3d3ed..4122c9b419 100644 --- a/docs/releases/4.0.txt +++ b/docs/releases/4.0.txt @@ -288,6 +288,10 @@ Models * :class:`~django.db.models.Lookup` expressions may now be used in ``QuerySet`` annotations, aggregations, and directly in filters. +* The new :ref:`default ` argument for built-in aggregates + allows specifying a value to be returned when the queryset (or grouping) + contains no entries, rather than ``None``. + Requests and Responses ~~~~~~~~~~~~~~~~~~~~~~ diff --git a/tests/aggregation/tests.py b/tests/aggregation/tests.py index f7b2331d34..2de80f81db 100644 --- a/tests/aggregation/tests.py +++ b/tests/aggregation/tests.py @@ -1,15 +1,19 @@ import datetime +import math import re from decimal import Decimal from django.core.exceptions import FieldError from django.db import connection from django.db.models import ( - Avg, Case, Count, DecimalField, DurationField, Exists, F, FloatField, - IntegerField, Max, Min, OuterRef, Subquery, Sum, Value, When, + Avg, Case, Count, DateField, DateTimeField, DecimalField, DurationField, + Exists, F, FloatField, IntegerField, Max, Min, OuterRef, Q, StdDev, + Subquery, Sum, TimeField, Value, Variance, When, ) from django.db.models.expressions import Func, RawSQL -from django.db.models.functions import Coalesce, Greatest +from django.db.models.functions import ( + Cast, Coalesce, Greatest, Now, Pi, TruncDate, TruncHour, +) from django.test import TestCase from django.test.testcases import skipUnlessDBFeature from django.test.utils import Approximate, CaptureQueriesContext @@ -18,6 +22,20 @@ from django.utils import timezone from .models import Author, Book, Publisher, Store +class NowUTC(Now): + template = 'CURRENT_TIMESTAMP' + output_field = DateTimeField() + + def as_mysql(self, compiler, connection, **extra_context): + return self.as_sql(compiler, connection, template='UTC_TIMESTAMP', **extra_context) + + def as_oracle(self, compiler, connection, **extra_context): + return self.as_sql(compiler, connection, template="CURRENT_TIMESTAMP AT TIME ZONE 'UTC'", **extra_context) + + def as_postgresql(self, compiler, connection, **extra_context): + return self.as_sql(compiler, connection, template="STATEMENT_TIMESTAMP() AT TIME ZONE 'UTC'", **extra_context) + + class AggregateTestCase(TestCase): @classmethod @@ -1402,3 +1420,190 @@ class AggregateTestCase(TestCase): )['latest_opening'], datetime.datetime, ) + + def test_aggregation_default_unsupported_by_count(self): + msg = 'Count does not allow default.' + with self.assertRaisesMessage(TypeError, msg): + Count('age', default=0) + + def test_aggregation_default_unset(self): + for Aggregate in [Avg, Max, Min, StdDev, Sum, Variance]: + with self.subTest(Aggregate): + result = Author.objects.filter(age__gt=100).aggregate( + value=Aggregate('age'), + ) + self.assertIsNone(result['value']) + + def test_aggregation_default_zero(self): + for Aggregate in [Avg, Max, Min, StdDev, Sum, Variance]: + with self.subTest(Aggregate): + result = Author.objects.filter(age__gt=100).aggregate( + value=Aggregate('age', default=0), + ) + self.assertEqual(result['value'], 0) + + def test_aggregation_default_integer(self): + for Aggregate in [Avg, Max, Min, StdDev, Sum, Variance]: + with self.subTest(Aggregate): + result = Author.objects.filter(age__gt=100).aggregate( + value=Aggregate('age', default=21), + ) + self.assertEqual(result['value'], 21) + + def test_aggregation_default_expression(self): + for Aggregate in [Avg, Max, Min, StdDev, Sum, Variance]: + with self.subTest(Aggregate): + result = Author.objects.filter(age__gt=100).aggregate( + value=Aggregate('age', default=Value(5) * Value(7)), + ) + self.assertEqual(result['value'], 35) + + def test_aggregation_default_group_by(self): + qs = Publisher.objects.values('name').annotate( + books=Count('book'), + pages=Sum('book__pages', default=0), + ).filter(books=0) + self.assertSequenceEqual( + qs, + [{'name': "Jonno's House of Books", 'books': 0, 'pages': 0}], + ) + + def test_aggregation_default_compound_expression(self): + # Scale rating to a percentage; default to 50% if no books published. + formula = Avg('book__rating', default=2.5) * 20.0 + queryset = Publisher.objects.annotate(rating=formula).order_by('name') + self.assertSequenceEqual(queryset.values('name', 'rating'), [ + {'name': 'Apress', 'rating': 85.0}, + {'name': "Jonno's House of Books", 'rating': 50.0}, + {'name': 'Morgan Kaufmann', 'rating': 100.0}, + {'name': 'Prentice Hall', 'rating': 80.0}, + {'name': 'Sams', 'rating': 60.0}, + ]) + + def test_aggregation_default_using_time_from_python(self): + expr = Min( + 'store__friday_night_closing', + filter=~Q(store__name='Amazon.com'), + default=datetime.time(17), + ) + if connection.vendor == 'mysql': + # Workaround for #30224 for MySQL 8.0+ & MariaDB. + expr.default = Cast(expr.default, TimeField()) + queryset = Book.objects.annotate(oldest_store_opening=expr).order_by('isbn') + self.assertSequenceEqual(queryset.values('isbn', 'oldest_store_opening'), [ + {'isbn': '013235613', 'oldest_store_opening': datetime.time(21, 30)}, + {'isbn': '013790395', 'oldest_store_opening': datetime.time(23, 59, 59)}, + {'isbn': '067232959', 'oldest_store_opening': datetime.time(17)}, + {'isbn': '155860191', 'oldest_store_opening': datetime.time(21, 30)}, + {'isbn': '159059725', 'oldest_store_opening': datetime.time(23, 59, 59)}, + {'isbn': '159059996', 'oldest_store_opening': datetime.time(21, 30)}, + ]) + + def test_aggregation_default_using_time_from_database(self): + now = timezone.now().astimezone(timezone.utc) + expr = Min( + 'store__friday_night_closing', + filter=~Q(store__name='Amazon.com'), + default=TruncHour(NowUTC(), output_field=TimeField()), + ) + queryset = Book.objects.annotate(oldest_store_opening=expr).order_by('isbn') + self.assertSequenceEqual(queryset.values('isbn', 'oldest_store_opening'), [ + {'isbn': '013235613', 'oldest_store_opening': datetime.time(21, 30)}, + {'isbn': '013790395', 'oldest_store_opening': datetime.time(23, 59, 59)}, + {'isbn': '067232959', 'oldest_store_opening': datetime.time(now.hour)}, + {'isbn': '155860191', 'oldest_store_opening': datetime.time(21, 30)}, + {'isbn': '159059725', 'oldest_store_opening': datetime.time(23, 59, 59)}, + {'isbn': '159059996', 'oldest_store_opening': datetime.time(21, 30)}, + ]) + + def test_aggregation_default_using_date_from_python(self): + expr = Min('book__pubdate', default=datetime.date(1970, 1, 1)) + if connection.vendor == 'mysql': + # Workaround for #30224 for MySQL 5.7+ & MariaDB. + expr.default = Cast(expr.default, DateField()) + queryset = Publisher.objects.annotate(earliest_pubdate=expr).order_by('name') + self.assertSequenceEqual(queryset.values('name', 'earliest_pubdate'), [ + {'name': 'Apress', 'earliest_pubdate': datetime.date(2007, 12, 6)}, + {'name': "Jonno's House of Books", 'earliest_pubdate': datetime.date(1970, 1, 1)}, + {'name': 'Morgan Kaufmann', 'earliest_pubdate': datetime.date(1991, 10, 15)}, + {'name': 'Prentice Hall', 'earliest_pubdate': datetime.date(1995, 1, 15)}, + {'name': 'Sams', 'earliest_pubdate': datetime.date(2008, 3, 3)}, + ]) + + def test_aggregation_default_using_date_from_database(self): + now = timezone.now().astimezone(timezone.utc) + expr = Min('book__pubdate', default=TruncDate(NowUTC())) + queryset = Publisher.objects.annotate(earliest_pubdate=expr).order_by('name') + self.assertSequenceEqual(queryset.values('name', 'earliest_pubdate'), [ + {'name': 'Apress', 'earliest_pubdate': datetime.date(2007, 12, 6)}, + {'name': "Jonno's House of Books", 'earliest_pubdate': now.date()}, + {'name': 'Morgan Kaufmann', 'earliest_pubdate': datetime.date(1991, 10, 15)}, + {'name': 'Prentice Hall', 'earliest_pubdate': datetime.date(1995, 1, 15)}, + {'name': 'Sams', 'earliest_pubdate': datetime.date(2008, 3, 3)}, + ]) + + def test_aggregation_default_using_datetime_from_python(self): + expr = Min( + 'store__original_opening', + filter=~Q(store__name='Amazon.com'), + default=datetime.datetime(1970, 1, 1), + ) + if connection.vendor == 'mysql': + # Workaround for #30224 for MySQL 8.0+ & MariaDB. + expr.default = Cast(expr.default, DateTimeField()) + queryset = Book.objects.annotate(oldest_store_opening=expr).order_by('isbn') + self.assertSequenceEqual(queryset.values('isbn', 'oldest_store_opening'), [ + {'isbn': '013235613', 'oldest_store_opening': datetime.datetime(1945, 4, 25, 16, 24, 14)}, + {'isbn': '013790395', 'oldest_store_opening': datetime.datetime(2001, 3, 15, 11, 23, 37)}, + {'isbn': '067232959', 'oldest_store_opening': datetime.datetime(1970, 1, 1)}, + {'isbn': '155860191', 'oldest_store_opening': datetime.datetime(1945, 4, 25, 16, 24, 14)}, + {'isbn': '159059725', 'oldest_store_opening': datetime.datetime(2001, 3, 15, 11, 23, 37)}, + {'isbn': '159059996', 'oldest_store_opening': datetime.datetime(1945, 4, 25, 16, 24, 14)}, + ]) + + def test_aggregation_default_using_datetime_from_database(self): + now = timezone.now().astimezone(timezone.utc) + expr = Min( + 'store__original_opening', + filter=~Q(store__name='Amazon.com'), + default=TruncHour(NowUTC(), output_field=DateTimeField()), + ) + queryset = Book.objects.annotate(oldest_store_opening=expr).order_by('isbn') + self.assertSequenceEqual(queryset.values('isbn', 'oldest_store_opening'), [ + {'isbn': '013235613', 'oldest_store_opening': datetime.datetime(1945, 4, 25, 16, 24, 14)}, + {'isbn': '013790395', 'oldest_store_opening': datetime.datetime(2001, 3, 15, 11, 23, 37)}, + {'isbn': '067232959', 'oldest_store_opening': now.replace(minute=0, second=0, microsecond=0, tzinfo=None)}, + {'isbn': '155860191', 'oldest_store_opening': datetime.datetime(1945, 4, 25, 16, 24, 14)}, + {'isbn': '159059725', 'oldest_store_opening': datetime.datetime(2001, 3, 15, 11, 23, 37)}, + {'isbn': '159059996', 'oldest_store_opening': datetime.datetime(1945, 4, 25, 16, 24, 14)}, + ]) + + def test_aggregation_default_using_duration_from_python(self): + result = Publisher.objects.filter(num_awards__gt=3).aggregate( + value=Sum('duration', default=datetime.timedelta(0)), + ) + self.assertEqual(result['value'], datetime.timedelta(0)) + + def test_aggregation_default_using_duration_from_database(self): + result = Publisher.objects.filter(num_awards__gt=3).aggregate( + value=Sum('duration', default=Now() - Now()), + ) + self.assertEqual(result['value'], datetime.timedelta(0)) + + def test_aggregation_default_using_decimal_from_python(self): + result = Book.objects.filter(rating__lt=3.0).aggregate( + value=Sum('price', default=Decimal('0.00')), + ) + self.assertEqual(result['value'], Decimal('0.00')) + + def test_aggregation_default_using_decimal_from_database(self): + result = Book.objects.filter(rating__lt=3.0).aggregate( + value=Sum('price', default=Pi()), + ) + self.assertAlmostEqual(result['value'], Decimal.from_float(math.pi), places=6) + + def test_aggregation_default_passed_another_aggregate(self): + result = Book.objects.aggregate( + value=Sum('price', filter=Q(rating__lt=3.0), default=Avg('pages') / 10.0), + ) + self.assertAlmostEqual(result['value'], Decimal('61.72'), places=2) diff --git a/tests/postgres_tests/test_aggregates.py b/tests/postgres_tests/test_aggregates.py index d47c24203b..393e3d38e8 100644 --- a/tests/postgres_tests/test_aggregates.py +++ b/tests/postgres_tests/test_aggregates.py @@ -72,6 +72,34 @@ class TestGeneralAggregate(PostgreSQLTestCase): ) self.assertEqual(values, {'aggregation': expected_result}) + def test_default_argument(self): + AggregateTestModel.objects.all().delete() + tests = [ + (ArrayAgg('char_field', default=['']), ['']), + (ArrayAgg('integer_field', default=[0]), [0]), + (ArrayAgg('boolean_field', default=[False]), [False]), + (BitAnd('integer_field', default=0), 0), + (BitOr('integer_field', default=0), 0), + (BoolAnd('boolean_field', default=False), False), + (BoolOr('boolean_field', default=False), False), + (JSONBAgg('integer_field', default=Value('[""]')), ['']), + (StringAgg('char_field', delimiter=';', default=Value('')), ''), + ] + for aggregation, expected_result in tests: + with self.subTest(aggregation=aggregation): + # Empty result with non-execution optimization. + with self.assertNumQueries(0): + values = AggregateTestModel.objects.none().aggregate( + aggregation=aggregation, + ) + self.assertEqual(values, {'aggregation': expected_result}) + # Empty result when query must be executed. + with self.assertNumQueries(1): + values = AggregateTestModel.objects.aggregate( + aggregation=aggregation, + ) + self.assertEqual(values, {'aggregation': expected_result}) + def test_array_agg_charfield(self): values = AggregateTestModel.objects.aggregate(arrayagg=ArrayAgg('char_field')) self.assertEqual(values, {'arrayagg': ['Foo1', 'Foo2', 'Foo4', 'Foo3']}) @@ -515,6 +543,37 @@ class TestStatisticsAggregate(PostgreSQLTestCase): ) self.assertEqual(values, {'aggregation': expected_result}) + def test_default_argument(self): + StatTestModel.objects.all().delete() + tests = [ + (Corr(y='int2', x='int1', default=0), 0), + (CovarPop(y='int2', x='int1', default=0), 0), + (CovarPop(y='int2', x='int1', sample=True, default=0), 0), + (RegrAvgX(y='int2', x='int1', default=0), 0), + (RegrAvgY(y='int2', x='int1', default=0), 0), + # RegrCount() doesn't support the default argument. + (RegrIntercept(y='int2', x='int1', default=0), 0), + (RegrR2(y='int2', x='int1', default=0), 0), + (RegrSlope(y='int2', x='int1', default=0), 0), + (RegrSXX(y='int2', x='int1', default=0), 0), + (RegrSXY(y='int2', x='int1', default=0), 0), + (RegrSYY(y='int2', x='int1', default=0), 0), + ] + for aggregation, expected_result in tests: + with self.subTest(aggregation=aggregation): + # Empty result with non-execution optimization. + with self.assertNumQueries(0): + values = StatTestModel.objects.none().aggregate( + aggregation=aggregation, + ) + self.assertEqual(values, {'aggregation': expected_result}) + # Empty result when query must be executed. + with self.assertNumQueries(1): + values = StatTestModel.objects.aggregate( + aggregation=aggregation, + ) + self.assertEqual(values, {'aggregation': expected_result}) + def test_corr_general(self): values = StatTestModel.objects.aggregate(corr=Corr(y='int2', x='int1')) self.assertEqual(values, {'corr': -1.0}) @@ -539,6 +598,11 @@ class TestStatisticsAggregate(PostgreSQLTestCase): values = StatTestModel.objects.aggregate(regrcount=RegrCount(y='int2', x='int1')) self.assertEqual(values, {'regrcount': 3}) + def test_regr_count_default(self): + msg = 'RegrCount does not allow default.' + with self.assertRaisesMessage(TypeError, msg): + RegrCount(y='int2', x='int1', default=0) + def test_regr_intercept_general(self): values = StatTestModel.objects.aggregate(regrintercept=RegrIntercept(y='int2', x='int1')) self.assertEqual(values, {'regrintercept': 4})