[1.8.x] Fixed #24154 -- Backends can now check support for expressions

Backport of 8196e4bdf4 from master
This commit is contained in:
Josh Smeaton 2015-01-17 16:03:46 +11:00
parent 5dff3513cc
commit e56810e839
11 changed files with 52 additions and 62 deletions

View File

@ -98,12 +98,12 @@ class BaseSpatialOperations(object):
""" """
raise NotImplementedError('subclasses of BaseSpatialOperations must provide a geo_db_placeholder() method') raise NotImplementedError('subclasses of BaseSpatialOperations must provide a geo_db_placeholder() method')
def check_aggregate_support(self, aggregate): def check_expression_support(self, expression):
if isinstance(aggregate, self.disallowed_aggregates): if isinstance(expression, self.disallowed_aggregates):
raise NotImplementedError( raise NotImplementedError(
"%s spatial aggregation is not supported by this database backend." % aggregate.name "%s spatial aggregation is not supported by this database backend." % expression.name
) )
super(BaseSpatialOperations, self).check_aggregate_support(aggregate) super(BaseSpatialOperations, self).check_expression_support(expression)
def spatial_aggregate_name(self, agg_name): def spatial_aggregate_name(self, agg_name):
raise NotImplementedError('Aggregate support not implemented for this spatial backend.') raise NotImplementedError('Aggregate support not implemented for this spatial backend.')

View File

@ -9,6 +9,9 @@ class GeoAggregate(Aggregate):
is_extent = False is_extent = False
def as_sql(self, compiler, connection): def as_sql(self, compiler, connection):
# this will be called again in parent, but it's needed now - before
# we get the spatial_aggregate_name
connection.ops.check_expression_support(self)
self.function = connection.ops.spatial_aggregate_name(self.name) self.function = connection.ops.spatial_aggregate_name(self.name)
return super(GeoAggregate, self).as_sql(compiler, connection) return super(GeoAggregate, self).as_sql(compiler, connection)

View File

@ -1,3 +1,5 @@
from django.db.models.aggregates import StdDev
from django.db.models.expressions import Value
from django.db.utils import ProgrammingError from django.db.utils import ProgrammingError
from django.utils.functional import cached_property from django.utils.functional import cached_property
@ -226,12 +228,8 @@ class BaseDatabaseFeatures(object):
@cached_property @cached_property
def supports_stddev(self): def supports_stddev(self):
"""Confirm support for STDDEV and related stats functions.""" """Confirm support for STDDEV and related stats functions."""
class StdDevPop(object):
contains_aggregate = True
sql_function = 'STDDEV_POP'
try: try:
self.connection.ops.check_aggregate_support(StdDevPop()) self.connection.ops.check_expression_support(StdDev(Value(1)))
return True return True
except NotImplementedError: except NotImplementedError:
return False return False

View File

@ -526,12 +526,16 @@ class BaseDatabaseOperations(object):
return value return value
def check_aggregate_support(self, aggregate_func): def check_aggregate_support(self, aggregate_func):
"""Check that the backend supports the provided aggregate return self.check_expression_support(aggregate_func)
This is used on specific backends to rule out known aggregates def check_expression_support(self, expression):
that are known to have faulty implementations. If the named """
aggregate function has a known problem, the backend should Check that the backend supports the provided expression.
raise NotImplementedError.
This is used on specific backends to rule out known expressions
that have problematic or nonexistent implementations. If the
expression has a known problem, the backend should raise
NotImplementedError.
""" """
pass pass

View File

@ -60,7 +60,7 @@ class DatabaseFeatures(BaseDatabaseFeatures):
"""Confirm support for STDDEV and related stats functions """Confirm support for STDDEV and related stats functions
SQLite supports STDDEV as an extension package; so SQLite supports STDDEV as an extension package; so
connection.ops.check_aggregate_support() can't unilaterally connection.ops.check_expression_support() can't unilaterally
rule out support for STDDEV. We need to manually check rule out support for STDDEV. We need to manually check
whether the call works. whether the call works.
""" """

View File

@ -4,7 +4,7 @@ import datetime
import uuid import uuid
from django.conf import settings from django.conf import settings
from django.core.exceptions import ImproperlyConfigured from django.core.exceptions import ImproperlyConfigured, FieldError
from django.db import utils from django.db import utils
from django.db.backends import utils as backend_utils from django.db.backends import utils as backend_utils
from django.db.backends.base.operations import BaseDatabaseOperations from django.db.backends.base.operations import BaseDatabaseOperations
@ -33,15 +33,21 @@ class DatabaseOperations(BaseDatabaseOperations):
limit = 999 if len(fields) > 1 else 500 limit = 999 if len(fields) > 1 else 500
return (limit // len(fields)) if len(fields) > 0 else len(objs) return (limit // len(fields)) if len(fields) > 0 else len(objs)
def check_aggregate_support(self, aggregate): def check_expression_support(self, expression):
bad_fields = (fields.DateField, fields.DateTimeField, fields.TimeField) bad_fields = (fields.DateField, fields.DateTimeField, fields.TimeField)
bad_aggregates = (aggregates.Sum, aggregates.Avg, bad_aggregates = (aggregates.Sum, aggregates.Avg, aggregates.Variance, aggregates.StdDev)
aggregates.Variance, aggregates.StdDev) if isinstance(expression, bad_aggregates):
if aggregate.refs_field(bad_aggregates, bad_fields): try:
raise NotImplementedError( output_field = expression.input_field.output_field
'You cannot use Sum, Avg, StdDev and Variance aggregations ' if isinstance(output_field, bad_fields):
'on date/time fields in sqlite3 ' raise NotImplementedError(
'since date/time is saved as text.') 'You cannot use Sum, Avg, StdDev and Variance aggregations '
'on date/time fields in sqlite3 '
'since date/time is saved as text.')
except FieldError:
# not every sub-expression has an output_field which is fine to
# ignore
pass
def date_extract_sql(self, lookup_type, field_name): def date_extract_sql(self, lookup_type, field_name):
# sqlite doesn't support extract, so we fake it with the user-defined # sqlite doesn't support extract, so we fake it with the user-defined

View File

@ -25,17 +25,6 @@ class Aggregate(Func):
c._patch_aggregate(query) # backward-compatibility support c._patch_aggregate(query) # backward-compatibility support
return c return c
def refs_field(self, aggregate_types, field_types):
try:
return (isinstance(self, aggregate_types) and
isinstance(self.input_field._output_field_or_none, field_types))
except FieldError:
# Sometimes we don't know the input_field's output type (for example,
# doing Sum(F('datetimefield') + F('datefield'), output_type=DateTimeField())
# is OK, but the Expression(F('datetimefield') + F('datefield')) doesn't
# have any output field.
return False
@property @property
def input_field(self): def input_field(self):
return self.source_expressions[0] return self.source_expressions[0]

View File

@ -297,14 +297,6 @@ class BaseExpression(object):
return agg, lookup return agg, lookup
return False, () return False, ()
def refs_field(self, aggregate_types, field_types):
"""
Helper method for check_aggregate_support on backends
"""
return any(
node.refs_field(aggregate_types, field_types)
for node in self.get_source_expressions())
def prepare_database_save(self, field): def prepare_database_save(self, field):
return self return self
@ -401,6 +393,7 @@ class DurationExpression(Expression):
return compiler.compile(side) return compiler.compile(side)
def as_sql(self, compiler, connection): def as_sql(self, compiler, connection):
connection.ops.check_expression_support(self)
expressions = [] expressions = []
expression_params = [] expression_params = []
sql, params = self.compile(self.lhs, compiler, connection) sql, params = self.compile(self.lhs, compiler, connection)
@ -473,6 +466,7 @@ class Func(ExpressionNode):
return c return c
def as_sql(self, compiler, connection, function=None, template=None): def as_sql(self, compiler, connection, function=None, template=None):
connection.ops.check_expression_support(self)
sql_parts = [] sql_parts = []
params = [] params = []
for arg in self.source_expressions: for arg in self.source_expressions:
@ -511,6 +505,7 @@ class Value(ExpressionNode):
self.value = value self.value = value
def as_sql(self, compiler, connection): def as_sql(self, compiler, connection):
connection.ops.check_expression_support(self)
val = self.value val = self.value
# check _output_field to avoid triggering an exception # check _output_field to avoid triggering an exception
if self._output_field is not None: if self._output_field is not None:
@ -536,6 +531,7 @@ class Value(ExpressionNode):
class DurationValue(Value): class DurationValue(Value):
def as_sql(self, compiler, connection): def as_sql(self, compiler, connection):
connection.ops.check_expression_support(self)
if (connection.features.has_native_duration_field and if (connection.features.has_native_duration_field and
connection.features.driver_supports_timedelta_args): connection.features.driver_supports_timedelta_args):
return super(DurationValue, self).as_sql(compiler, connection) return super(DurationValue, self).as_sql(compiler, connection)
@ -650,6 +646,7 @@ class When(ExpressionNode):
return c return c
def as_sql(self, compiler, connection, template=None): def as_sql(self, compiler, connection, template=None):
connection.ops.check_expression_support(self)
template_params = {} template_params = {}
sql_params = [] sql_params = []
condition_sql, condition_params = compiler.compile(self.condition) condition_sql, condition_params = compiler.compile(self.condition)
@ -715,6 +712,7 @@ class Case(ExpressionNode):
return c return c
def as_sql(self, compiler, connection, template=None, extra=None): def as_sql(self, compiler, connection, template=None, extra=None):
connection.ops.check_expression_support(self)
if not self.cases: if not self.cases:
return compiler.compile(self.default) return compiler.compile(self.default)
template_params = dict(extra) if extra else {} template_params = dict(extra) if extra else {}
@ -851,6 +849,7 @@ class OrderBy(BaseExpression):
return [self.expression] return [self.expression]
def as_sql(self, compiler, connection): def as_sql(self, compiler, connection):
connection.ops.check_expression_support(self)
expression_sql, params = compiler.compile(self.expression) expression_sql, params = compiler.compile(self.expression)
placeholders = {'expression': expression_sql} placeholders = {'expression': expression_sql}
placeholders['ordering'] = 'DESC' if self.descending else 'ASC' placeholders['ordering'] = 'DESC' if self.descending else 'ASC'

View File

@ -230,11 +230,6 @@ class Query(object):
raise ValueError("Need either using or connection") raise ValueError("Need either using or connection")
if using: if using:
connection = connections[using] connection = connections[using]
# Check that the compiler will be able to execute the query
for alias, annotation in self.annotation_select.items():
connection.ops.check_aggregate_support(annotation)
return connection.ops.compiler(self.compiler)(self, connection, using) return connection.ops.compiler(self.compiler)(self, connection, using)
def get_meta(self): def get_meta(self):

View File

@ -325,17 +325,6 @@ class WhereNode(tree.Node):
def contains_aggregate(self): def contains_aggregate(self):
return self._contains_aggregate(self) return self._contains_aggregate(self)
@classmethod
def _refs_field(cls, obj, aggregate_types, field_types):
if not isinstance(obj, tree.Node):
if hasattr(obj.rhs, 'refs_field'):
return obj.rhs.refs_field(aggregate_types, field_types)
return False
return any(cls._refs_field(c, aggregate_types, field_types) for c in obj.children)
def refs_field(self, aggregate_types, field_types):
return self._refs_field(self, aggregate_types, field_types)
class EmptyWhere(WhereNode): class EmptyWhere(WhereNode):
def add(self, data, connector): def add(self, data, connector):

View File

@ -128,12 +128,19 @@ class SQLiteTests(TestCase):
#19360: Raise NotImplementedError when aggregating on date/time fields. #19360: Raise NotImplementedError when aggregating on date/time fields.
""" """
for aggregate in (Sum, Avg, Variance, StdDev): for aggregate in (Sum, Avg, Variance, StdDev):
self.assertRaises(NotImplementedError, self.assertRaises(
NotImplementedError,
models.Item.objects.all().aggregate, aggregate('time')) models.Item.objects.all().aggregate, aggregate('time'))
self.assertRaises(NotImplementedError, self.assertRaises(
NotImplementedError,
models.Item.objects.all().aggregate, aggregate('date')) models.Item.objects.all().aggregate, aggregate('date'))
self.assertRaises(NotImplementedError, self.assertRaises(
NotImplementedError,
models.Item.objects.all().aggregate, aggregate('last_modified')) models.Item.objects.all().aggregate, aggregate('last_modified'))
self.assertRaises(
NotImplementedError,
models.Item.objects.all().aggregate,
**{'complex': aggregate('last_modified') + aggregate('last_modified')})
@unittest.skipUnless(connection.vendor == 'postgresql', "Test only for PostgreSQL") @unittest.skipUnless(connection.vendor == 'postgresql', "Test only for PostgreSQL")