Fixed #24301 -- Added PostgreSQL-specific aggregate functions

This commit is contained in:
Andriy Sokolovskiy 2015-02-08 17:21:48 +02:00 committed by Tim Graham
parent 931a340f1f
commit e4cf8c8420
10 changed files with 646 additions and 1 deletions

View File

@ -49,7 +49,7 @@ answer newbie questions, and generally made Django that much better:
Andrew Godwin <andrew@aeracode.org>
Andrew Pinkham <http://AndrewsForge.com>
Andrews Medina <andrewsmedina@gmail.com>
Andriy Sokolovskiy <sokandpal@yandex.ru>
Andriy Sokolovskiy <me@asokolovskiy.com>
Andy Dustman <farcepest@gmail.com>
Andy Gayton <andy-django@thecablelounge.com>
andy@jadedplanet.net

View File

@ -0,0 +1,2 @@
from .general import * # NOQA
from .statistics import * # NOQA

View File

@ -0,0 +1,43 @@
from django.db.models.aggregates import Aggregate
__all__ = [
'ArrayAgg', 'BitAnd', 'BitOr', 'BoolAnd', 'BoolOr', 'StringAgg',
]
class ArrayAgg(Aggregate):
    """Aggregate whose SQL function is ``ARRAY_AGG``."""

    function = 'ARRAY_AGG'

    def convert_value(self, value, expression, connection, context):
        """Return ``value`` when it is truthy, otherwise an empty list."""
        # Any falsy database value (for example None) becomes [].
        return value if value else []
class BitAnd(Aggregate):
    """Aggregate whose SQL function is ``BIT_AND``."""

    function = 'BIT_AND'
class BitOr(Aggregate):
    """Aggregate whose SQL function is ``BIT_OR``."""

    function = 'BIT_OR'
class BoolAnd(Aggregate):
    """Aggregate whose SQL function is ``BOOL_AND``."""

    function = 'BOOL_AND'
class BoolOr(Aggregate):
    """Aggregate whose SQL function is ``BOOL_OR``."""

    function = 'BOOL_OR'
class StringAgg(Aggregate):
    """
    Aggregate rendered as ``STRING_AGG(<expressions>, '<delimiter>')``.

    The delimiter is written into the SQL text as a single-quoted literal
    (see ``template``); it is not passed as a query parameter.
    """

    function = 'STRING_AGG'
    template = "%(function)s(%(expressions)s, '%(delimiter)s')"

    def __init__(self, expression, delimiter, **extra):
        """
        ``expression`` and ``extra`` are forwarded to the parent constructor.
        ``delimiter`` is required and is the separator placed in the template.
        """
        # The template wraps the delimiter in single quotes, so a quote inside
        # the delimiter would end the literal early and produce broken (or
        # attacker-shaped) SQL. Doubling the quote is the standard SQL way of
        # embedding it in a string constant.
        # '%s' % (...) renders the value the same way the template did, so
        # non-string delimiters keep working.
        # NOTE(review): a '%' inside the delimiter is not escaped here; whether
        # it conflicts with DB-API parameter formatting should be confirmed.
        escaped_delimiter = ('%s' % (delimiter,)).replace("'", "''")
        super(StringAgg, self).__init__(expression, delimiter=escaped_delimiter, **extra)

    def convert_value(self, value, expression, connection, context):
        """Return ``value`` when it is truthy, otherwise an empty string."""
        if not value:
            return ''
        return value

View File

@ -0,0 +1,80 @@
from django.db.models import FloatField, IntegerField
from django.db.models.aggregates import Aggregate
__all__ = [
'CovarPop', 'Corr', 'RegrAvgX', 'RegrAvgY', 'RegrCount', 'RegrIntercept',
'RegrR2', 'RegrSlope', 'RegrSXX', 'RegrSXY', 'RegrSYY', 'StatAggregate',
]
class StatAggregate(Aggregate):
    """Base class for aggregates that take two arguments, ``y`` and ``x``."""

    # NOTE: the default FloatField() below is created once, when the method is
    # defined, and is shared by every call that omits ``output_field``.
    def __init__(self, y, x, output_field=FloatField()):
        """
        Store ``y`` and ``x`` and build ``source_expressions`` from them.

        Raises ``ValueError`` when either argument is ``None`` or an empty
        string.
        """
        # Reject only arguments that are really missing. A plain truthiness
        # test would also refuse the numeric constant 0, which is a legitimate
        # value to aggregate over.
        for argument in (y, x):
            if argument is None or argument == '':
                raise ValueError('Both y and x must be provided.')
        super(StatAggregate, self).__init__(y=y, x=x, output_field=output_field)
        self.x = x
        self.y = y
        self.source_expressions = self._parse_expressions(self.y, self.x)

    def get_source_expressions(self):
        """Return the ``(y, x)`` pair currently stored on the instance."""
        # NOTE(review): this returns self.y / self.x as stored, not the parsed
        # self.source_expressions built in __init__ -- confirm callers expect that.
        return self.y, self.x

    def set_source_expressions(self, exprs):
        """Unpack ``exprs`` (exactly two items) into ``self.y`` and ``self.x``."""
        self.y, self.x = exprs

    def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False):
        """Resolve using the implementation that follows ``Aggregate`` in the MRO."""
        # super(Aggregate, self) starts the lookup *after* Aggregate, so any
        # resolve_expression defined on Aggregate itself is skipped -- presumably
        # on purpose; confirm. ``for_save`` is accepted but not forwarded.
        return super(Aggregate, self).resolve_expression(query, allow_joins, reuse, summarize)
class Corr(StatAggregate):
    """Two-argument aggregate whose SQL function is ``CORR``."""

    function = 'CORR'
class CovarPop(StatAggregate):
    """Two-argument aggregate using ``COVAR_POP`` or, on request, ``COVAR_SAMP``."""

    def __init__(self, y, x, sample=False):
        """Pick the SQL function from ``sample``, then defer to ``StatAggregate``."""
        # The function name is stored on the instance before the parent
        # constructor runs, exactly as before.
        if sample:
            self.function = 'COVAR_SAMP'
        else:
            self.function = 'COVAR_POP'
        super(CovarPop, self).__init__(y, x)
class RegrAvgX(StatAggregate):
    """Two-argument aggregate whose SQL function is ``REGR_AVGX``."""

    function = 'REGR_AVGX'
class RegrAvgY(StatAggregate):
    """Two-argument aggregate whose SQL function is ``REGR_AVGY``."""

    function = 'REGR_AVGY'
class RegrCount(StatAggregate):
    """Two-argument aggregate whose SQL function is ``REGR_COUNT``."""

    function = 'REGR_COUNT'

    def __init__(self, y, x):
        """Build the aggregate with an ``IntegerField`` as its output field."""
        super(RegrCount, self).__init__(y=y, x=x, output_field=IntegerField())

    def convert_value(self, value, expression, connection, context):
        """Return ``0`` for ``None``; otherwise ``value`` converted with ``int``."""
        return 0 if value is None else int(value)
class RegrIntercept(StatAggregate):
    """Two-argument aggregate whose SQL function is ``REGR_INTERCEPT``."""

    function = 'REGR_INTERCEPT'
class RegrR2(StatAggregate):
    """Two-argument aggregate whose SQL function is ``REGR_R2``."""

    function = 'REGR_R2'
class RegrSlope(StatAggregate):
    """Two-argument aggregate whose SQL function is ``REGR_SLOPE``."""

    function = 'REGR_SLOPE'
class RegrSXX(StatAggregate):
    """Two-argument aggregate whose SQL function is ``REGR_SXX``."""

    function = 'REGR_SXX'
class RegrSXY(StatAggregate):
    """Two-argument aggregate whose SQL function is ``REGR_SXY``."""

    function = 'REGR_SXY'
class RegrSYY(StatAggregate):
    """Two-argument aggregate whose SQL function is ``REGR_SYY``."""

    function = 'REGR_SYY'

View File

@ -0,0 +1,212 @@
=========================================
PostgreSQL specific aggregation functions
=========================================
.. module:: django.contrib.postgres.aggregates
:synopsis: PostgreSQL specific aggregation functions
.. versionadded:: 1.9
These functions are described in more detail in the `PostgreSQL docs
<http://www.postgresql.org/docs/current/static/functions-aggregate.html>`_.
.. note::
All functions come without default aliases, so you must explicitly provide
one. For example::
>>> SomeModel.objects.aggregate(arr=ArrayAgg('somefield'))
{'arr': [0, 1, 2]}
General-purpose aggregation functions
-------------------------------------
ArrayAgg
~~~~~~~~
.. class:: ArrayAgg(expression, **extra)
Returns a list of values, including nulls, concatenated into an array.
BitAnd
~~~~~~
.. class:: BitAnd(expression, **extra)
Returns an ``int`` of the bitwise ``AND`` of all non-null input values, or
``None`` if all values are null.
BitOr
~~~~~
.. class:: BitOr(expression, **extra)
Returns an ``int`` of the bitwise ``OR`` of all non-null input values, or
``None`` if all values are null.
BoolAnd
~~~~~~~~
.. class:: BoolAnd(expression, **extra)
Returns ``True`` if all input values are true, ``None`` if all values are
null or if there are no values, otherwise ``False``.
BoolOr
~~~~~~
.. class:: BoolOr(expression, **extra)
Returns ``True`` if at least one input value is true, ``None`` if all
values are null or if there are no values, otherwise ``False``.
StringAgg
~~~~~~~~~
.. class:: StringAgg(expression, delimiter)
Returns the input values concatenated into a string, separated by
the ``delimiter`` string.
.. attribute:: delimiter
Required argument. Needs to be a string.
Aggregate functions for statistics
----------------------------------
``y`` and ``x``
~~~~~~~~~~~~~~~
The arguments ``y`` and ``x`` for all these functions can be the name of a
field or an expression returning numeric data. Both are required.
Corr
~~~~
.. class:: Corr(y, x)
Returns the correlation coefficient as a ``float``, or ``None`` if there
aren't any matching rows.
CovarPop
~~~~~~~~
.. class:: CovarPop(y, x, sample=False)
Returns the population covariance as a ``float``, or ``None`` if there
aren't any matching rows.
Has one optional argument:
.. attribute:: sample
By default ``CovarPop`` returns the general population covariance.
However, if ``sample=True``, the return value will be the sample
covariance.
RegrAvgX
~~~~~~~~
.. class:: RegrAvgX(y, x)
Returns the average of the independent variable (``sum(x)/N``) as a
``float``, or ``None`` if there aren't any matching rows.
RegrAvgY
~~~~~~~~
.. class:: RegrAvgY(y, x)
Returns the average of the dependent variable (``sum(y)/N``) as a
``float``, or ``None`` if there aren't any matching rows.
RegrCount
~~~~~~~~~
.. class:: RegrCount(y, x)
Returns an ``int`` of the number of input rows in which both expressions
are not null.
RegrIntercept
~~~~~~~~~~~~~
.. class:: RegrIntercept(y, x)
Returns the y-intercept of the least-squares-fit linear equation determined
by the ``(x, y)`` pairs as a ``float``, or ``None`` if there aren't any
matching rows.
RegrR2
~~~~~~
.. class:: RegrR2(y, x)
Returns the square of the correlation coefficient as a ``float``, or
``None`` if there aren't any matching rows.
RegrSlope
~~~~~~~~~
.. class:: RegrSlope(y, x)
Returns the slope of the least-squares-fit linear equation determined
by the ``(x, y)`` pairs as a ``float``, or ``None`` if there aren't any
matching rows.
RegrSXX
~~~~~~~
.. class:: RegrSXX(y, x)
Returns ``sum(x^2) - sum(x)^2/N`` ("sum of squares" of the independent
variable) as a ``float``, or ``None`` if there aren't any matching rows.
RegrSXY
~~~~~~~
.. class:: RegrSXY(y, x)
Returns ``sum(x*y) - sum(x) * sum(y)/N`` ("sum of products" of independent
times dependent variable) as a ``float``, or ``None`` if there aren't any
matching rows.
RegrSYY
~~~~~~~
.. class:: RegrSYY(y, x)
Returns ``sum(y^2) - sum(y)^2/N`` ("sum of squares" of the dependent
variable) as a ``float``, or ``None`` if there aren't any matching rows.
Usage examples
--------------
We will use this example table::
| FIELD1 | FIELD2 | FIELD3 |
|--------|--------|--------|
| foo | 1 | 13 |
| bar | 2 | (null) |
| test | 3 | 13 |
Here are some examples of the general-purpose aggregation functions::
>>> TestModel.objects.aggregate(result=StringAgg('field1', delimiter=';'))
{'result': 'foo;bar;test'}
>>> TestModel.objects.aggregate(result=ArrayAgg('field2'))
{'result': [1, 2, 3]}
>>> TestModel.objects.aggregate(result=ArrayAgg('field1'))
{'result': ['foo', 'bar', 'test']}
The next example shows the usage of statistical aggregate functions. The
underlying math will not be described (you can read about this, for example, at
`Wikipedia <http://en.wikipedia.org/wiki/Regression_analysis>`_)::
>>> TestModel.objects.aggregate(count=RegrCount(y='field3', x='field2'))
{'count': 2}
>>> TestModel.objects.aggregate(avgx=RegrAvgX(y='field3', x='field2'),
... avgy=RegrAvgY(y='field3', x='field2'))
{'avgx': 2.0, 'avgy': 13.0}

View File

@ -31,6 +31,7 @@ Psycopg2 2.5 or higher is required.
.. toctree::
:maxdepth: 2
aggregates
fields
forms
lookups

View File

@ -69,6 +69,11 @@ Minor features
* ...
:mod:`django.contrib.postgres`
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
* Added :doc:`/ref/contrib/postgres/aggregates`.
:mod:`django.contrib.redirects`
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

View File

@ -105,6 +105,24 @@ class Migration(migrations.Migration):
options=None,
bases=None,
),
migrations.CreateModel(
name='AggregateTestModel',
fields=[
('id', models.AutoField(verbose_name='ID', serialize=False, auto_created=True, primary_key=True)),
('boolean_field', models.NullBooleanField()),
('char_field', models.CharField(max_length=30, blank=True)),
('integer_field', models.IntegerField(null=True)),
]
),
migrations.CreateModel(
name='StatTestModel',
fields=[
('id', models.AutoField(verbose_name='ID', serialize=False, auto_created=True, primary_key=True)),
('int1', models.IntegerField()),
('int2', models.IntegerField()),
('related_field', models.ForeignKey('postgres_tests.AggregateTestModel', null=True)),
]
),
]
pg_92_operations = [

View File

@ -62,3 +62,21 @@ else:
class ArrayFieldSubclass(ArrayField):
    # Positional and keyword arguments are accepted but not forwarded: the
    # parent constructor always receives a fresh IntegerField.
    def __init__(self, *args, **kwargs):
        base_field = models.IntegerField()
        super(ArrayFieldSubclass, self).__init__(base_field)
class AggregateTestModel(models.Model):
    """
    Model used by the tests of the postgres-specific general-purpose
    aggregation functions.
    """

    # Field declaration order is kept as-is.
    char_field = models.CharField(blank=True, max_length=30)
    integer_field = models.IntegerField(null=True)
    boolean_field = models.NullBooleanField()
class StatTestModel(models.Model):
    """
    Model used by the tests of the postgres-specific statistical
    aggregation functions.
    """

    # Field declaration order is kept as-is.
    int1 = models.IntegerField()
    int2 = models.IntegerField()
    # Optional link to AggregateTestModel.
    related_field = models.ForeignKey(AggregateTestModel, null=True)

View File

@ -0,0 +1,266 @@
from django.contrib.postgres.aggregates import (
ArrayAgg, BitAnd, BitOr, BoolAnd, BoolOr, Corr, CovarPop, RegrAvgX,
RegrAvgY, RegrCount, RegrIntercept, RegrR2, RegrSlope, RegrSXX, RegrSXY,
RegrSYY, StatAggregate, StringAgg,
)
from django.db.models.expressions import F, Value
from django.test import TestCase
from django.test.utils import Approximate
from .models import AggregateTestModel, StatTestModel
class TestGeneralAggregate(TestCase):
    # Exercises ArrayAgg, BitAnd, BitOr, BoolAnd, BoolOr and StringAgg against
    # the four AggregateTestModel rows created in setUpTestData.

    @classmethod
    def setUpTestData(cls):
        # (boolean_field, char_field, integer_field), inserted in this order.
        rows = (
            (True, 'Foo1', 0),
            (False, 'Foo2', 1),
            (False, 'Foo3', 2),
            (True, 'Foo4', 0),
        )
        for boolean_value, char_value, integer_value in rows:
            AggregateTestModel.objects.create(
                boolean_field=boolean_value,
                char_field=char_value,
                integer_field=integer_value,
            )

    def test_array_agg_charfield(self):
        result = AggregateTestModel.objects.aggregate(arrayagg=ArrayAgg('char_field'))
        self.assertEqual(result, {'arrayagg': ['Foo1', 'Foo2', 'Foo3', 'Foo4']})

    def test_array_agg_integerfield(self):
        result = AggregateTestModel.objects.aggregate(arrayagg=ArrayAgg('integer_field'))
        self.assertEqual(result, {'arrayagg': [0, 1, 2, 0]})

    def test_array_agg_booleanfield(self):
        result = AggregateTestModel.objects.aggregate(arrayagg=ArrayAgg('boolean_field'))
        self.assertEqual(result, {'arrayagg': [True, False, False, True]})

    def test_array_agg_empty_result(self):
        AggregateTestModel.objects.all().delete()
        # With no rows left, every field must aggregate to an empty list.
        for field_name in ('char_field', 'integer_field', 'boolean_field'):
            result = AggregateTestModel.objects.aggregate(arrayagg=ArrayAgg(field_name))
            self.assertEqual(result, {'arrayagg': []})

    def test_bit_and_general(self):
        rows = AggregateTestModel.objects.filter(integer_field__in=[0, 1])
        result = rows.aggregate(bitand=BitAnd('integer_field'))
        self.assertEqual(result, {'bitand': 0})

    def test_bit_and_on_only_true_values(self):
        rows = AggregateTestModel.objects.filter(integer_field=1)
        result = rows.aggregate(bitand=BitAnd('integer_field'))
        self.assertEqual(result, {'bitand': 1})

    def test_bit_and_on_only_false_values(self):
        rows = AggregateTestModel.objects.filter(integer_field=0)
        result = rows.aggregate(bitand=BitAnd('integer_field'))
        self.assertEqual(result, {'bitand': 0})

    def test_bit_and_empty_result(self):
        AggregateTestModel.objects.all().delete()
        result = AggregateTestModel.objects.aggregate(bitand=BitAnd('integer_field'))
        self.assertEqual(result, {'bitand': None})

    def test_bit_or_general(self):
        rows = AggregateTestModel.objects.filter(integer_field__in=[0, 1])
        result = rows.aggregate(bitor=BitOr('integer_field'))
        self.assertEqual(result, {'bitor': 1})

    def test_bit_or_on_only_true_values(self):
        rows = AggregateTestModel.objects.filter(integer_field=1)
        result = rows.aggregate(bitor=BitOr('integer_field'))
        self.assertEqual(result, {'bitor': 1})

    def test_bit_or_on_only_false_values(self):
        rows = AggregateTestModel.objects.filter(integer_field=0)
        result = rows.aggregate(bitor=BitOr('integer_field'))
        self.assertEqual(result, {'bitor': 0})

    def test_bit_or_empty_result(self):
        AggregateTestModel.objects.all().delete()
        result = AggregateTestModel.objects.aggregate(bitor=BitOr('integer_field'))
        self.assertEqual(result, {'bitor': None})

    def test_bool_and_general(self):
        result = AggregateTestModel.objects.aggregate(booland=BoolAnd('boolean_field'))
        self.assertEqual(result, {'booland': False})

    def test_bool_and_empty_result(self):
        AggregateTestModel.objects.all().delete()
        result = AggregateTestModel.objects.aggregate(booland=BoolAnd('boolean_field'))
        self.assertEqual(result, {'booland': None})

    def test_bool_or_general(self):
        result = AggregateTestModel.objects.aggregate(boolor=BoolOr('boolean_field'))
        self.assertEqual(result, {'boolor': True})

    def test_bool_or_empty_result(self):
        AggregateTestModel.objects.all().delete()
        result = AggregateTestModel.objects.aggregate(boolor=BoolOr('boolean_field'))
        self.assertEqual(result, {'boolor': None})

    def test_string_agg_requires_delimiter(self):
        # StringAgg called without its delimiter argument must raise TypeError.
        with self.assertRaises(TypeError):
            AggregateTestModel.objects.aggregate(stringagg=StringAgg('char_field'))

    def test_string_agg_charfield(self):
        result = AggregateTestModel.objects.aggregate(
            stringagg=StringAgg('char_field', delimiter=';'),
        )
        self.assertEqual(result, {'stringagg': 'Foo1;Foo2;Foo3;Foo4'})

    def test_string_agg_empty_result(self):
        AggregateTestModel.objects.all().delete()
        result = AggregateTestModel.objects.aggregate(
            stringagg=StringAgg('char_field', delimiter=';'),
        )
        self.assertEqual(result, {'stringagg': ''})
class TestStatisticsAggregate(TestCase):
    # Exercises StatAggregate and its subclasses against three StatTestModel
    # rows, each linked to its own AggregateTestModel row.

    @classmethod
    def setUpTestData(cls):
        # (int1, int2, integer_field of the related row), inserted in this order.
        for int1, int2, related_integer in ((1, 3, 0), (2, 2, 1), (3, 1, 2)):
            related = AggregateTestModel.objects.create(integer_field=related_integer)
            StatTestModel.objects.create(int1=int1, int2=int2, related_field=related)

    # Tests for base class (StatAggregate)

    def test_missing_arguments_raises_exception(self):
        with self.assertRaisesMessage(ValueError, 'Both y and x must be provided.'):
            StatAggregate(x=None, y=None)

    def test_correct_source_expressions(self):
        # y is a number and x is a string: y must be wrapped in Value, x in F.
        aggregate = StatAggregate(x='test', y=13)
        y_expression, x_expression = aggregate.source_expressions[0], aggregate.source_expressions[1]
        self.assertIsInstance(y_expression, Value)
        self.assertIsInstance(x_expression, F)

    def test_alias_is_required(self):
        class SomeFunc(StatAggregate):
            function = 'TEST'

        with self.assertRaisesMessage(TypeError, 'Complex aggregates require an alias'):
            StatTestModel.objects.aggregate(SomeFunc(y='int2', x='int1'))

    # Test aggregates

    def test_corr_general(self):
        result = StatTestModel.objects.aggregate(corr=Corr(y='int2', x='int1'))
        self.assertEqual(result, {'corr': -1.0})

    def test_corr_empty_result(self):
        StatTestModel.objects.all().delete()
        result = StatTestModel.objects.aggregate(corr=Corr(y='int2', x='int1'))
        self.assertEqual(result, {'corr': None})

    def test_covar_pop_general(self):
        result = StatTestModel.objects.aggregate(covarpop=CovarPop(y='int2', x='int1'))
        self.assertEqual(result, {'covarpop': Approximate(-0.66, places=1)})

    def test_covar_pop_empty_result(self):
        StatTestModel.objects.all().delete()
        result = StatTestModel.objects.aggregate(covarpop=CovarPop(y='int2', x='int1'))
        self.assertEqual(result, {'covarpop': None})

    def test_covar_pop_sample(self):
        result = StatTestModel.objects.aggregate(
            covarpop=CovarPop(y='int2', x='int1', sample=True),
        )
        self.assertEqual(result, {'covarpop': -1.0})

    def test_covar_pop_sample_empty_result(self):
        StatTestModel.objects.all().delete()
        result = StatTestModel.objects.aggregate(
            covarpop=CovarPop(y='int2', x='int1', sample=True),
        )
        self.assertEqual(result, {'covarpop': None})

    def test_regr_avgx_general(self):
        result = StatTestModel.objects.aggregate(regravgx=RegrAvgX(y='int2', x='int1'))
        self.assertEqual(result, {'regravgx': 2.0})

    def test_regr_avgx_empty_result(self):
        StatTestModel.objects.all().delete()
        result = StatTestModel.objects.aggregate(regravgx=RegrAvgX(y='int2', x='int1'))
        self.assertEqual(result, {'regravgx': None})

    def test_regr_avgy_general(self):
        result = StatTestModel.objects.aggregate(regravgy=RegrAvgY(y='int2', x='int1'))
        self.assertEqual(result, {'regravgy': 2.0})

    def test_regr_avgy_empty_result(self):
        StatTestModel.objects.all().delete()
        result = StatTestModel.objects.aggregate(regravgy=RegrAvgY(y='int2', x='int1'))
        self.assertEqual(result, {'regravgy': None})

    def test_regr_count_general(self):
        result = StatTestModel.objects.aggregate(regrcount=RegrCount(y='int2', x='int1'))
        self.assertEqual(result, {'regrcount': 3})

    def test_regr_count_empty_result(self):
        StatTestModel.objects.all().delete()
        result = StatTestModel.objects.aggregate(regrcount=RegrCount(y='int2', x='int1'))
        self.assertEqual(result, {'regrcount': 0})

    def test_regr_intercept_general(self):
        result = StatTestModel.objects.aggregate(
            regrintercept=RegrIntercept(y='int2', x='int1'),
        )
        self.assertEqual(result, {'regrintercept': 4})

    def test_regr_intercept_empty_result(self):
        StatTestModel.objects.all().delete()
        result = StatTestModel.objects.aggregate(
            regrintercept=RegrIntercept(y='int2', x='int1'),
        )
        self.assertEqual(result, {'regrintercept': None})

    def test_regr_r2_general(self):
        result = StatTestModel.objects.aggregate(regrr2=RegrR2(y='int2', x='int1'))
        self.assertEqual(result, {'regrr2': 1})

    def test_regr_r2_empty_result(self):
        StatTestModel.objects.all().delete()
        result = StatTestModel.objects.aggregate(regrr2=RegrR2(y='int2', x='int1'))
        self.assertEqual(result, {'regrr2': None})

    def test_regr_slope_general(self):
        result = StatTestModel.objects.aggregate(regrslope=RegrSlope(y='int2', x='int1'))
        self.assertEqual(result, {'regrslope': -1})

    def test_regr_slope_empty_result(self):
        StatTestModel.objects.all().delete()
        result = StatTestModel.objects.aggregate(regrslope=RegrSlope(y='int2', x='int1'))
        self.assertEqual(result, {'regrslope': None})

    def test_regr_sxx_general(self):
        result = StatTestModel.objects.aggregate(regrsxx=RegrSXX(y='int2', x='int1'))
        self.assertEqual(result, {'regrsxx': 2.0})

    def test_regr_sxx_empty_result(self):
        StatTestModel.objects.all().delete()
        result = StatTestModel.objects.aggregate(regrsxx=RegrSXX(y='int2', x='int1'))
        self.assertEqual(result, {'regrsxx': None})

    def test_regr_sxy_general(self):
        result = StatTestModel.objects.aggregate(regrsxy=RegrSXY(y='int2', x='int1'))
        self.assertEqual(result, {'regrsxy': -2.0})

    def test_regr_sxy_empty_result(self):
        StatTestModel.objects.all().delete()
        result = StatTestModel.objects.aggregate(regrsxy=RegrSXY(y='int2', x='int1'))
        self.assertEqual(result, {'regrsxy': None})

    def test_regr_syy_general(self):
        result = StatTestModel.objects.aggregate(regrsyy=RegrSYY(y='int2', x='int1'))
        self.assertEqual(result, {'regrsyy': 2.0})

    def test_regr_syy_empty_result(self):
        StatTestModel.objects.all().delete()
        result = StatTestModel.objects.aggregate(regrsyy=RegrSYY(y='int2', x='int1'))
        self.assertEqual(result, {'regrsyy': None})

    def test_regr_avgx_with_related_obj_and_number_as_argument(self):
        """
        A more complex case: ``x`` follows a relation (so a JOIN is needed)
        while ``y`` is a plain number.
        """
        aggregate = RegrAvgX(y=5, x='related_field__integer_field')
        result = StatTestModel.objects.aggregate(complex_regravgx=aggregate)
        self.assertEqual(result, {'complex_regravgx': 1.0})