Fixed #30056 -- Added SQLite support for StdDev and Variance functions.

This commit is contained in:
Nick Pope 2018-12-19 23:01:44 +00:00 committed by Tim Graham
parent 5d25804eaf
commit 83677faf86
7 changed files with 21 additions and 46 deletions

View File

@ -1,5 +1,4 @@
from django.db.models.aggregates import StdDev from django.db.utils import ProgrammingError
from django.db.utils import NotSupportedError, ProgrammingError
from django.utils.functional import cached_property from django.utils.functional import cached_property
@ -298,12 +297,3 @@ class BaseDatabaseFeatures:
count, = cursor.fetchone() count, = cursor.fetchone()
cursor.execute('DROP TABLE ROLLBACK_TEST') cursor.execute('DROP TABLE ROLLBACK_TEST')
return count == 0 return count == 0
@cached_property
def supports_stddev(self):
"""Confirm support for STDDEV and related stats functions."""
try:
self.connection.ops.check_expression_support(StdDev(1))
except NotSupportedError:
return False
return True

View File

@ -7,6 +7,7 @@ import functools
import math import math
import operator import operator
import re import re
import statistics
import warnings import warnings
from itertools import chain from itertools import chain
from sqlite3 import dbapi2 as Database from sqlite3 import dbapi2 as Database
@ -49,6 +50,14 @@ def none_guard(func):
return wrapper return wrapper
def list_aggregate(function):
"""
Return an aggregate class that accumulates values in a list and applies
the provided function to the data.
"""
return type('ListAggregate', (list,), {'finalize': function, 'step': list.append})
Database.register_converter("bool", b'1'.__eq__) Database.register_converter("bool", b'1'.__eq__)
Database.register_converter("time", decoder(parse_time)) Database.register_converter("time", decoder(parse_time))
Database.register_converter("datetime", decoder(parse_datetime)) Database.register_converter("datetime", decoder(parse_datetime))
@ -210,6 +219,10 @@ class DatabaseWrapper(BaseDatabaseWrapper):
conn.create_function('SIN', 1, none_guard(math.sin)) conn.create_function('SIN', 1, none_guard(math.sin))
conn.create_function('SQRT', 1, none_guard(math.sqrt)) conn.create_function('SQRT', 1, none_guard(math.sqrt))
conn.create_function('TAN', 1, none_guard(math.tan)) conn.create_function('TAN', 1, none_guard(math.tan))
conn.create_aggregate('STDDEV_POP', 1, list_aggregate(statistics.pstdev))
conn.create_aggregate('STDDEV_SAMP', 1, list_aggregate(statistics.stdev))
conn.create_aggregate('VAR_POP', 1, list_aggregate(statistics.pvariance))
conn.create_aggregate('VAR_SAMP', 1, list_aggregate(statistics.variance))
conn.execute('PRAGMA foreign_keys = ON') conn.execute('PRAGMA foreign_keys = ON')
return conn return conn

View File

@ -1,8 +1,6 @@
import sys import sys
from django.db import utils
from django.db.backends.base.features import BaseDatabaseFeatures from django.db.backends.base.features import BaseDatabaseFeatures
from django.utils.functional import cached_property
from .base import Database from .base import Database
@ -41,22 +39,3 @@ class DatabaseFeatures(BaseDatabaseFeatures):
# reasonably performant way. # reasonably performant way.
supports_pragma_foreign_key_check = Database.sqlite_version_info >= (3, 20, 0) supports_pragma_foreign_key_check = Database.sqlite_version_info >= (3, 20, 0)
can_defer_constraint_checks = supports_pragma_foreign_key_check can_defer_constraint_checks = supports_pragma_foreign_key_check
@cached_property
def supports_stddev(self):
"""
Confirm support for STDDEV and related stats functions.
SQLite supports STDDEV as an extension package; so
connection.ops.check_expression_support() can't unilaterally
rule out support for STDDEV. Manually check whether the call works.
"""
with self.connection.cursor() as cursor:
cursor.execute('CREATE TABLE STDDEV_TEST (X INT)')
try:
cursor.execute('SELECT STDDEV(*) FROM STDDEV_TEST')
has_support = True
except utils.DatabaseError:
has_support = False
cursor.execute('DROP TABLE STDDEV_TEST')
return has_support

View File

@ -3400,12 +3400,9 @@ by the aggregate.
By default, ``StdDev`` returns the population standard deviation. However, By default, ``StdDev`` returns the population standard deviation. However,
if ``sample=True``, the return value will be the sample standard deviation. if ``sample=True``, the return value will be the sample standard deviation.
.. admonition:: SQLite .. versionchanged:: 2.2
SQLite doesn't provide ``StdDev`` out of the box. An implementation SQLite support was added.
is available as an extension module for SQLite. Consult the `SQLite
documentation`_ for instructions on obtaining and installing this
extension.
``Sum`` ``Sum``
~~~~~~~ ~~~~~~~
@ -3434,14 +3431,9 @@ by the aggregate.
By default, ``Variance`` returns the population variance. However, By default, ``Variance`` returns the population variance. However,
if ``sample=True``, the return value will be the sample variance. if ``sample=True``, the return value will be the sample variance.
.. admonition:: SQLite .. versionchanged:: 2.2
SQLite doesn't provide ``Variance`` out of the box. An implementation SQLite support was added.
is available as an extension module for SQLite. Consult the `SQLite
documentation`_ for instructions on obtaining and installing this
extension.
.. _SQLite documentation: https://www.sqlite.org/contrib
Query-related tools Query-related tools
=================== ===================

View File

@ -235,6 +235,9 @@ Models
``Model.delete()``. This improves the performance of autocommit by reducing ``Model.delete()``. This improves the performance of autocommit by reducing
the number of database round trips. the number of database round trips.
* Added SQLite support for the :class:`~django.db.models.StdDev` and
:class:`~django.db.models.Variance` functions.
Requests and Responses Requests and Responses
~~~~~~~~~~~~~~~~~~~~~~ ~~~~~~~~~~~~~~~~~~~~~~

View File

@ -1116,7 +1116,6 @@ class AggregationTests(TestCase):
lambda b: (b.name, b.authorCount) lambda b: (b.name, b.authorCount)
) )
@skipUnlessDBFeature('supports_stddev')
def test_stddev(self): def test_stddev(self):
self.assertEqual( self.assertEqual(
Book.objects.aggregate(StdDev('pages')), Book.objects.aggregate(StdDev('pages')),

View File

@ -340,7 +340,6 @@ class BackendTestCase(TransactionTestCase):
def test_cached_db_features(self): def test_cached_db_features(self):
self.assertIn(connection.features.supports_transactions, (True, False)) self.assertIn(connection.features.supports_transactions, (True, False))
self.assertIn(connection.features.supports_stddev, (True, False))
self.assertIn(connection.features.can_introspect_foreign_keys, (True, False)) self.assertIn(connection.features.can_introspect_foreign_keys, (True, False))
def test_duplicate_table_error(self): def test_duplicate_table_error(self):