From e56810e839db2beddc8a7b6e917158855ef381dc Mon Sep 17 00:00:00 2001 From: Josh Smeaton Date: Sat, 17 Jan 2015 16:03:46 +1100 Subject: [PATCH] [1.8.x] Fixed #24154 -- Backends can now check support for expressions Backport of 8196e4bdf498acb05e6680c81f9d7bf700f4295c from master --- .../gis/db/backends/base/operations.py | 8 +++---- django/contrib/gis/db/models/aggregates.py | 3 +++ django/db/backends/base/features.py | 8 +++---- django/db/backends/base/operations.py | 14 +++++++---- django/db/backends/sqlite3/features.py | 2 +- django/db/backends/sqlite3/operations.py | 24 ++++++++++++------- django/db/models/aggregates.py | 11 --------- django/db/models/expressions.py | 15 ++++++------ django/db/models/sql/query.py | 5 ---- django/db/models/sql/where.py | 11 --------- tests/backends/tests.py | 13 +++++++--- 11 files changed, 52 insertions(+), 62 deletions(-) diff --git a/django/contrib/gis/db/backends/base/operations.py b/django/contrib/gis/db/backends/base/operations.py index dc2fad025b2..4759c86b89a 100644 --- a/django/contrib/gis/db/backends/base/operations.py +++ b/django/contrib/gis/db/backends/base/operations.py @@ -98,12 +98,12 @@ class BaseSpatialOperations(object): """ raise NotImplementedError('subclasses of BaseSpatialOperations must provide a geo_db_placeholder() method') - def check_aggregate_support(self, aggregate): - if isinstance(aggregate, self.disallowed_aggregates): + def check_expression_support(self, expression): + if isinstance(expression, self.disallowed_aggregates): raise NotImplementedError( - "%s spatial aggregation is not supported by this database backend." % aggregate.name + "%s spatial aggregation is not supported by this database backend." % expression.name ) - super(BaseSpatialOperations, self).check_aggregate_support(aggregate) + super(BaseSpatialOperations, self).check_expression_support(expression) def spatial_aggregate_name(self, agg_name): raise NotImplementedError('Aggregate support not implemented for this spatial backend.') diff --git a/django/contrib/gis/db/models/aggregates.py b/django/contrib/gis/db/models/aggregates.py index 0ec842de0fd..42198d9287c 100644 --- a/django/contrib/gis/db/models/aggregates.py +++ b/django/contrib/gis/db/models/aggregates.py @@ -9,6 +9,9 @@ class GeoAggregate(Aggregate): is_extent = False def as_sql(self, compiler, connection): + # this will be called again in parent, but it's needed now - before + # we get the spatial_aggregate_name + connection.ops.check_expression_support(self) self.function = connection.ops.spatial_aggregate_name(self.name) return super(GeoAggregate, self).as_sql(compiler, connection) diff --git a/django/db/backends/base/features.py b/django/db/backends/base/features.py index 0f6ee0efe36..4b4d5c6d759 100644 --- a/django/db/backends/base/features.py +++ b/django/db/backends/base/features.py @@ -1,3 +1,5 @@ +from django.db.models.aggregates import StdDev +from django.db.models.expressions import Value from django.db.utils import ProgrammingError from django.utils.functional import cached_property @@ -226,12 +228,8 @@ class BaseDatabaseFeatures(object): @cached_property def supports_stddev(self): """Confirm support for STDDEV and related stats functions.""" - class StdDevPop(object): - contains_aggregate = True - sql_function = 'STDDEV_POP' - try: - self.connection.ops.check_aggregate_support(StdDevPop()) + self.connection.ops.check_expression_support(StdDev(Value(1))) return True except NotImplementedError: return False diff --git a/django/db/backends/base/operations.py b/django/db/backends/base/operations.py index 24bcbb3d08e..f535a8792da 100644 --- a/django/db/backends/base/operations.py +++ b/django/db/backends/base/operations.py @@ -526,12 +526,16 @@ class BaseDatabaseOperations(object): return value def check_aggregate_support(self, aggregate_func): - """Check that the backend supports the provided aggregate + return self.check_expression_support(aggregate_func) - This is used on specific backends to rule out known aggregates - that are known to have faulty implementations. If the named - aggregate function has a known problem, the backend should - raise NotImplementedError. + def check_expression_support(self, expression): + """ + Check that the backend supports the provided expression. + + This is used on specific backends to rule out known expressions + that have problematic or nonexistent implementations. If the + expression has a known problem, the backend should raise + NotImplementedError. """ pass diff --git a/django/db/backends/sqlite3/features.py b/django/db/backends/sqlite3/features.py index fa5a002603a..ee864691779 100644 --- a/django/db/backends/sqlite3/features.py +++ b/django/db/backends/sqlite3/features.py @@ -60,7 +60,7 @@ class DatabaseFeatures(BaseDatabaseFeatures): """Confirm support for STDDEV and related stats functions SQLite supports STDDEV as an extension package; so - connection.ops.check_aggregate_support() can't unilaterally + connection.ops.check_expression_support() can't unilaterally rule out support for STDDEV. We need to manually check whether the call works. """ diff --git a/django/db/backends/sqlite3/operations.py b/django/db/backends/sqlite3/operations.py index ad05f88585f..cd29092cd70 100644 --- a/django/db/backends/sqlite3/operations.py +++ b/django/db/backends/sqlite3/operations.py @@ -4,7 +4,7 @@ import datetime import uuid from django.conf import settings -from django.core.exceptions import ImproperlyConfigured +from django.core.exceptions import ImproperlyConfigured, FieldError from django.db import utils from django.db.backends import utils as backend_utils from django.db.backends.base.operations import BaseDatabaseOperations @@ -33,15 +33,21 @@ class DatabaseOperations(BaseDatabaseOperations): limit = 999 if len(fields) > 1 else 500 return (limit // len(fields)) if len(fields) > 0 else len(objs) - def check_aggregate_support(self, aggregate): + def check_expression_support(self, expression): bad_fields = (fields.DateField, fields.DateTimeField, fields.TimeField) - bad_aggregates = (aggregates.Sum, aggregates.Avg, - aggregates.Variance, aggregates.StdDev) - if aggregate.refs_field(bad_aggregates, bad_fields): - raise NotImplementedError( - 'You cannot use Sum, Avg, StdDev and Variance aggregations ' - 'on date/time fields in sqlite3 ' - 'since date/time is saved as text.') + bad_aggregates = (aggregates.Sum, aggregates.Avg, aggregates.Variance, aggregates.StdDev) + if isinstance(expression, bad_aggregates): + try: + output_field = expression.input_field.output_field + if isinstance(output_field, bad_fields): + raise NotImplementedError( + 'You cannot use Sum, Avg, StdDev and Variance aggregations ' + 'on date/time fields in sqlite3 ' + 'since date/time is saved as text.') + except FieldError: + # not every sub-expression has an output_field which is fine to + # ignore + pass def date_extract_sql(self, lookup_type, field_name): # sqlite doesn't support extract, so we fake it with the user-defined diff --git a/django/db/models/aggregates.py b/django/db/models/aggregates.py index 06220123ca2..668f79f622f 100644 --- a/django/db/models/aggregates.py +++ b/django/db/models/aggregates.py @@ -25,17 +25,6 @@ class Aggregate(Func): c._patch_aggregate(query) # backward-compatibility support return c - def refs_field(self, aggregate_types, field_types): - try: - return (isinstance(self, aggregate_types) and - isinstance(self.input_field._output_field_or_none, field_types)) - except FieldError: - # Sometimes we don't know the input_field's output type (for example, - # doing Sum(F('datetimefield') + F('datefield'), output_type=DateTimeField()) - # is OK, but the Expression(F('datetimefield') + F('datefield')) doesn't - # have any output field. - return False - @property def input_field(self): return self.source_expressions[0] diff --git a/django/db/models/expressions.py b/django/db/models/expressions.py index 97a2a9071d0..fb094fd4ef7 100644 --- a/django/db/models/expressions.py +++ b/django/db/models/expressions.py @@ -297,14 +297,6 @@ class BaseExpression(object): return agg, lookup return False, () - def refs_field(self, aggregate_types, field_types): - """ - Helper method for check_aggregate_support on backends - """ - return any( - node.refs_field(aggregate_types, field_types) - for node in self.get_source_expressions()) - def prepare_database_save(self, field): return self @@ -401,6 +393,7 @@ class DurationExpression(Expression): return compiler.compile(side) def as_sql(self, compiler, connection): + connection.ops.check_expression_support(self) expressions = [] expression_params = [] sql, params = self.compile(self.lhs, compiler, connection) @@ -473,6 +466,7 @@ class Func(ExpressionNode): return c def as_sql(self, compiler, connection, function=None, template=None): + connection.ops.check_expression_support(self) sql_parts = [] params = [] for arg in self.source_expressions: @@ -511,6 +505,7 @@ class Value(ExpressionNode): self.value = value def as_sql(self, compiler, connection): + connection.ops.check_expression_support(self) val = self.value # check _output_field to avoid triggering an exception if self._output_field is not None: @@ -536,6 +531,7 @@ class Value(ExpressionNode): class DurationValue(Value): def as_sql(self, compiler, connection): + connection.ops.check_expression_support(self) if (connection.features.has_native_duration_field and connection.features.driver_supports_timedelta_args): return super(DurationValue, self).as_sql(compiler, connection) @@ -650,6 +646,7 @@ class When(ExpressionNode): return c def as_sql(self, compiler, connection, template=None): + connection.ops.check_expression_support(self) template_params = {} sql_params = [] condition_sql, condition_params = compiler.compile(self.condition) @@ -715,6 +712,7 @@ class Case(ExpressionNode): return c def as_sql(self, compiler, connection, template=None, extra=None): + connection.ops.check_expression_support(self) if not self.cases: return compiler.compile(self.default) template_params = dict(extra) if extra else {} @@ -851,6 +849,7 @@ class OrderBy(BaseExpression): return [self.expression] def as_sql(self, compiler, connection): + connection.ops.check_expression_support(self) expression_sql, params = compiler.compile(self.expression) placeholders = {'expression': expression_sql} placeholders['ordering'] = 'DESC' if self.descending else 'ASC' diff --git a/django/db/models/sql/query.py b/django/db/models/sql/query.py index 9693206b67a..0d89de2458b 100644 --- a/django/db/models/sql/query.py +++ b/django/db/models/sql/query.py @@ -230,11 +230,6 @@ class Query(object): raise ValueError("Need either using or connection") if using: connection = connections[using] - - # Check that the compiler will be able to execute the query - for alias, annotation in self.annotation_select.items(): - connection.ops.check_aggregate_support(annotation) - return connection.ops.compiler(self.compiler)(self, connection, using) def get_meta(self): diff --git a/django/db/models/sql/where.py b/django/db/models/sql/where.py index e3766c51d62..6a03210a93e 100644 --- a/django/db/models/sql/where.py +++ b/django/db/models/sql/where.py @@ -325,17 +325,6 @@ class WhereNode(tree.Node): def contains_aggregate(self): return self._contains_aggregate(self) - @classmethod - def _refs_field(cls, obj, aggregate_types, field_types): - if not isinstance(obj, tree.Node): - if hasattr(obj.rhs, 'refs_field'): - return obj.rhs.refs_field(aggregate_types, field_types) - return False - return any(cls._refs_field(c, aggregate_types, field_types) for c in obj.children) - - def refs_field(self, aggregate_types, field_types): - return self._refs_field(self, aggregate_types, field_types) - class EmptyWhere(WhereNode): def add(self, data, connector): diff --git a/tests/backends/tests.py b/tests/backends/tests.py index 4a07cd5f2f3..979ebefa01c 100644 --- a/tests/backends/tests.py +++ b/tests/backends/tests.py @@ -128,12 +128,19 @@ class SQLiteTests(TestCase): #19360: Raise NotImplementedError when aggregating on date/time fields. """ for aggregate in (Sum, Avg, Variance, StdDev): - self.assertRaises(NotImplementedError, + self.assertRaises( + NotImplementedError, models.Item.objects.all().aggregate, aggregate('time')) - self.assertRaises(NotImplementedError, + self.assertRaises( + NotImplementedError, models.Item.objects.all().aggregate, aggregate('date')) - self.assertRaises(NotImplementedError, + self.assertRaises( + NotImplementedError, models.Item.objects.all().aggregate, aggregate('last_modified')) + self.assertRaises( + NotImplementedError, + models.Item.objects.all().aggregate, + **{'complex': aggregate('last_modified') + aggregate('last_modified')}) @unittest.skipUnless(connection.vendor == 'postgresql', "Test only for PostgreSQL")