[1.8.x] Fixed #24154 -- Backends can now check support for expressions

Backport of 8196e4bdf4 from master
This commit is contained in:
Josh Smeaton 2015-01-17 16:03:46 +11:00
parent 5dff3513cc
commit e56810e839
11 changed files with 52 additions and 62 deletions

View File

@ -98,12 +98,12 @@ class BaseSpatialOperations(object):
"""
raise NotImplementedError('subclasses of BaseSpatialOperations must provide a geo_db_placeholder() method')
def check_aggregate_support(self, aggregate):
if isinstance(aggregate, self.disallowed_aggregates):
def check_expression_support(self, expression):
if isinstance(expression, self.disallowed_aggregates):
raise NotImplementedError(
"%s spatial aggregation is not supported by this database backend." % aggregate.name
"%s spatial aggregation is not supported by this database backend." % expression.name
)
super(BaseSpatialOperations, self).check_aggregate_support(aggregate)
super(BaseSpatialOperations, self).check_expression_support(expression)
def spatial_aggregate_name(self, agg_name):
raise NotImplementedError('Aggregate support not implemented for this spatial backend.')

View File

@ -9,6 +9,9 @@ class GeoAggregate(Aggregate):
is_extent = False
def as_sql(self, compiler, connection):
# this will be called again in parent, but it's needed now - before
# we get the spatial_aggregate_name
connection.ops.check_expression_support(self)
self.function = connection.ops.spatial_aggregate_name(self.name)
return super(GeoAggregate, self).as_sql(compiler, connection)

View File

@ -1,3 +1,5 @@
from django.db.models.aggregates import StdDev
from django.db.models.expressions import Value
from django.db.utils import ProgrammingError
from django.utils.functional import cached_property
@ -226,12 +228,8 @@ class BaseDatabaseFeatures(object):
@cached_property
def supports_stddev(self):
"""Confirm support for STDDEV and related stats functions."""
class StdDevPop(object):
contains_aggregate = True
sql_function = 'STDDEV_POP'
try:
self.connection.ops.check_aggregate_support(StdDevPop())
self.connection.ops.check_expression_support(StdDev(Value(1)))
return True
except NotImplementedError:
return False

View File

@ -526,12 +526,16 @@ class BaseDatabaseOperations(object):
return value
def check_aggregate_support(self, aggregate_func):
"""Check that the backend supports the provided aggregate
return self.check_expression_support(aggregate_func)
This is used on specific backends to rule out known aggregates
that are known to have faulty implementations. If the named
aggregate function has a known problem, the backend should
raise NotImplementedError.
def check_expression_support(self, expression):
"""
Check that the backend supports the provided expression.
This is used on specific backends to rule out known expressions
that have problematic or nonexistent implementations. If the
expression has a known problem, the backend should raise
NotImplementedError.
"""
pass

View File

@ -60,7 +60,7 @@ class DatabaseFeatures(BaseDatabaseFeatures):
"""Confirm support for STDDEV and related stats functions
SQLite supports STDDEV as an extension package; so
connection.ops.check_aggregate_support() can't unilaterally
connection.ops.check_expression_support() can't unilaterally
rule out support for STDDEV. We need to manually check
whether the call works.
"""

View File

@ -4,7 +4,7 @@ import datetime
import uuid
from django.conf import settings
from django.core.exceptions import ImproperlyConfigured
from django.core.exceptions import ImproperlyConfigured, FieldError
from django.db import utils
from django.db.backends import utils as backend_utils
from django.db.backends.base.operations import BaseDatabaseOperations
@ -33,15 +33,21 @@ class DatabaseOperations(BaseDatabaseOperations):
limit = 999 if len(fields) > 1 else 500
return (limit // len(fields)) if len(fields) > 0 else len(objs)
def check_aggregate_support(self, aggregate):
def check_expression_support(self, expression):
bad_fields = (fields.DateField, fields.DateTimeField, fields.TimeField)
bad_aggregates = (aggregates.Sum, aggregates.Avg,
aggregates.Variance, aggregates.StdDev)
if aggregate.refs_field(bad_aggregates, bad_fields):
raise NotImplementedError(
'You cannot use Sum, Avg, StdDev and Variance aggregations '
'on date/time fields in sqlite3 '
'since date/time is saved as text.')
bad_aggregates = (aggregates.Sum, aggregates.Avg, aggregates.Variance, aggregates.StdDev)
if isinstance(expression, bad_aggregates):
try:
output_field = expression.input_field.output_field
if isinstance(output_field, bad_fields):
raise NotImplementedError(
'You cannot use Sum, Avg, StdDev and Variance aggregations '
'on date/time fields in sqlite3 '
'since date/time is saved as text.')
except FieldError:
# not every sub-expression has an output_field which is fine to
# ignore
pass
def date_extract_sql(self, lookup_type, field_name):
# sqlite doesn't support extract, so we fake it with the user-defined

View File

@ -25,17 +25,6 @@ class Aggregate(Func):
c._patch_aggregate(query) # backward-compatibility support
return c
def refs_field(self, aggregate_types, field_types):
try:
return (isinstance(self, aggregate_types) and
isinstance(self.input_field._output_field_or_none, field_types))
except FieldError:
# Sometimes we don't know the input_field's output type (for example,
# doing Sum(F('datetimefield') + F('datefield'), output_type=DateTimeField())
# is OK, but the Expression(F('datetimefield') + F('datefield')) doesn't
# have any output field.
return False
@property
def input_field(self):
return self.source_expressions[0]

View File

@ -297,14 +297,6 @@ class BaseExpression(object):
return agg, lookup
return False, ()
def refs_field(self, aggregate_types, field_types):
"""
Helper method for check_aggregate_support on backends
"""
return any(
node.refs_field(aggregate_types, field_types)
for node in self.get_source_expressions())
def prepare_database_save(self, field):
return self
@ -401,6 +393,7 @@ class DurationExpression(Expression):
return compiler.compile(side)
def as_sql(self, compiler, connection):
connection.ops.check_expression_support(self)
expressions = []
expression_params = []
sql, params = self.compile(self.lhs, compiler, connection)
@ -473,6 +466,7 @@ class Func(ExpressionNode):
return c
def as_sql(self, compiler, connection, function=None, template=None):
connection.ops.check_expression_support(self)
sql_parts = []
params = []
for arg in self.source_expressions:
@ -511,6 +505,7 @@ class Value(ExpressionNode):
self.value = value
def as_sql(self, compiler, connection):
connection.ops.check_expression_support(self)
val = self.value
# check _output_field to avoid triggering an exception
if self._output_field is not None:
@ -536,6 +531,7 @@ class Value(ExpressionNode):
class DurationValue(Value):
def as_sql(self, compiler, connection):
connection.ops.check_expression_support(self)
if (connection.features.has_native_duration_field and
connection.features.driver_supports_timedelta_args):
return super(DurationValue, self).as_sql(compiler, connection)
@ -650,6 +646,7 @@ class When(ExpressionNode):
return c
def as_sql(self, compiler, connection, template=None):
connection.ops.check_expression_support(self)
template_params = {}
sql_params = []
condition_sql, condition_params = compiler.compile(self.condition)
@ -715,6 +712,7 @@ class Case(ExpressionNode):
return c
def as_sql(self, compiler, connection, template=None, extra=None):
connection.ops.check_expression_support(self)
if not self.cases:
return compiler.compile(self.default)
template_params = dict(extra) if extra else {}
@ -851,6 +849,7 @@ class OrderBy(BaseExpression):
return [self.expression]
def as_sql(self, compiler, connection):
connection.ops.check_expression_support(self)
expression_sql, params = compiler.compile(self.expression)
placeholders = {'expression': expression_sql}
placeholders['ordering'] = 'DESC' if self.descending else 'ASC'

View File

@ -230,11 +230,6 @@ class Query(object):
raise ValueError("Need either using or connection")
if using:
connection = connections[using]
# Check that the compiler will be able to execute the query
for alias, annotation in self.annotation_select.items():
connection.ops.check_aggregate_support(annotation)
return connection.ops.compiler(self.compiler)(self, connection, using)
def get_meta(self):

View File

@ -325,17 +325,6 @@ class WhereNode(tree.Node):
def contains_aggregate(self):
return self._contains_aggregate(self)
@classmethod
def _refs_field(cls, obj, aggregate_types, field_types):
if not isinstance(obj, tree.Node):
if hasattr(obj.rhs, 'refs_field'):
return obj.rhs.refs_field(aggregate_types, field_types)
return False
return any(cls._refs_field(c, aggregate_types, field_types) for c in obj.children)
def refs_field(self, aggregate_types, field_types):
return self._refs_field(self, aggregate_types, field_types)
class EmptyWhere(WhereNode):
def add(self, data, connector):

View File

@ -128,12 +128,19 @@ class SQLiteTests(TestCase):
#19360: Raise NotImplementedError when aggregating on date/time fields.
"""
for aggregate in (Sum, Avg, Variance, StdDev):
self.assertRaises(NotImplementedError,
self.assertRaises(
NotImplementedError,
models.Item.objects.all().aggregate, aggregate('time'))
self.assertRaises(NotImplementedError,
self.assertRaises(
NotImplementedError,
models.Item.objects.all().aggregate, aggregate('date'))
self.assertRaises(NotImplementedError,
self.assertRaises(
NotImplementedError,
models.Item.objects.all().aggregate, aggregate('last_modified'))
self.assertRaises(
NotImplementedError,
models.Item.objects.all().aggregate,
**{'complex': aggregate('last_modified') + aggregate('last_modified')})
@unittest.skipUnless(connection.vendor == 'postgresql', "Test only for PostgreSQL")