Fixed #24154 -- Backends can now check support for expressions
This commit is contained in:
parent
511be35779
commit
8196e4bdf4
|
@ -98,12 +98,12 @@ class BaseSpatialOperations(object):
|
|||
"""
|
||||
raise NotImplementedError('subclasses of BaseSpatialOperations must provide a geo_db_placeholder() method')
|
||||
|
||||
def check_aggregate_support(self, aggregate):
|
||||
if isinstance(aggregate, self.disallowed_aggregates):
|
||||
def check_expression_support(self, expression):
|
||||
if isinstance(expression, self.disallowed_aggregates):
|
||||
raise NotImplementedError(
|
||||
"%s spatial aggregation is not supported by this database backend." % aggregate.name
|
||||
"%s spatial aggregation is not supported by this database backend." % expression.name
|
||||
)
|
||||
super(BaseSpatialOperations, self).check_aggregate_support(aggregate)
|
||||
super(BaseSpatialOperations, self).check_expression_support(expression)
|
||||
|
||||
def spatial_aggregate_name(self, agg_name):
|
||||
raise NotImplementedError('Aggregate support not implemented for this spatial backend.')
|
||||
|
|
|
@ -9,6 +9,9 @@ class GeoAggregate(Aggregate):
|
|||
is_extent = False
|
||||
|
||||
def as_sql(self, compiler, connection):
|
||||
# this will be called again in parent, but it's needed now - before
|
||||
# we get the spatial_aggregate_name
|
||||
connection.ops.check_expression_support(self)
|
||||
self.function = connection.ops.spatial_aggregate_name(self.name)
|
||||
return super(GeoAggregate, self).as_sql(compiler, connection)
|
||||
|
||||
|
|
|
@ -1,3 +1,5 @@
|
|||
from django.db.models.aggregates import StdDev
|
||||
from django.db.models.expressions import Value
|
||||
from django.db.utils import ProgrammingError
|
||||
from django.utils.functional import cached_property
|
||||
|
||||
|
@ -226,12 +228,8 @@ class BaseDatabaseFeatures(object):
|
|||
@cached_property
|
||||
def supports_stddev(self):
|
||||
"""Confirm support for STDDEV and related stats functions."""
|
||||
class StdDevPop(object):
|
||||
contains_aggregate = True
|
||||
sql_function = 'STDDEV_POP'
|
||||
|
||||
try:
|
||||
self.connection.ops.check_aggregate_support(StdDevPop())
|
||||
self.connection.ops.check_expression_support(StdDev(Value(1)))
|
||||
return True
|
||||
except NotImplementedError:
|
||||
return False
|
||||
|
|
|
@ -1,12 +1,14 @@
|
|||
import datetime
|
||||
import decimal
|
||||
from importlib import import_module
|
||||
import warnings
|
||||
|
||||
from django.conf import settings
|
||||
from django.core.exceptions import ImproperlyConfigured
|
||||
from django.db.backends import utils
|
||||
from django.utils import six, timezone
|
||||
from django.utils.dateparse import parse_duration
|
||||
from django.utils.deprecation import RemovedInDjango21Warning
|
||||
from django.utils.encoding import force_text
|
||||
|
||||
|
||||
|
@ -517,12 +519,20 @@ class BaseDatabaseOperations(object):
|
|||
return value
|
||||
|
||||
def check_aggregate_support(self, aggregate_func):
|
||||
"""Check that the backend supports the provided aggregate
|
||||
warnings.warn(
|
||||
"check_aggregate_support has been deprecated. Use "
|
||||
"check_expression_support instead.",
|
||||
RemovedInDjango21Warning, stacklevel=2)
|
||||
return self.check_expression_support(aggregate_func)
|
||||
|
||||
This is used on specific backends to rule out known aggregates
|
||||
that are known to have faulty implementations. If the named
|
||||
aggregate function has a known problem, the backend should
|
||||
raise NotImplementedError.
|
||||
def check_expression_support(self, expression):
|
||||
"""
|
||||
Check that the backend supports the provided expression.
|
||||
|
||||
This is used on specific backends to rule out known expressions
|
||||
that have problematic or nonexistent implementations. If the
|
||||
expression has a known problem, the backend should raise
|
||||
NotImplementedError.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
|
|
@ -60,7 +60,7 @@ class DatabaseFeatures(BaseDatabaseFeatures):
|
|||
"""Confirm support for STDDEV and related stats functions
|
||||
|
||||
SQLite supports STDDEV as an extension package; so
|
||||
connection.ops.check_aggregate_support() can't unilaterally
|
||||
connection.ops.check_expression_support() can't unilaterally
|
||||
rule out support for STDDEV. We need to manually check
|
||||
whether the call works.
|
||||
"""
|
||||
|
|
|
@ -4,7 +4,7 @@ import datetime
|
|||
import uuid
|
||||
|
||||
from django.conf import settings
|
||||
from django.core.exceptions import ImproperlyConfigured
|
||||
from django.core.exceptions import ImproperlyConfigured, FieldError
|
||||
from django.db import utils
|
||||
from django.db.backends import utils as backend_utils
|
||||
from django.db.backends.base.operations import BaseDatabaseOperations
|
||||
|
@ -33,15 +33,21 @@ class DatabaseOperations(BaseDatabaseOperations):
|
|||
limit = 999 if len(fields) > 1 else 500
|
||||
return (limit // len(fields)) if len(fields) > 0 else len(objs)
|
||||
|
||||
def check_aggregate_support(self, aggregate):
|
||||
def check_expression_support(self, expression):
|
||||
bad_fields = (fields.DateField, fields.DateTimeField, fields.TimeField)
|
||||
bad_aggregates = (aggregates.Sum, aggregates.Avg,
|
||||
aggregates.Variance, aggregates.StdDev)
|
||||
if aggregate.refs_field(bad_aggregates, bad_fields):
|
||||
raise NotImplementedError(
|
||||
'You cannot use Sum, Avg, StdDev and Variance aggregations '
|
||||
'on date/time fields in sqlite3 '
|
||||
'since date/time is saved as text.')
|
||||
bad_aggregates = (aggregates.Sum, aggregates.Avg, aggregates.Variance, aggregates.StdDev)
|
||||
if isinstance(expression, bad_aggregates):
|
||||
try:
|
||||
output_field = expression.input_field.output_field
|
||||
if isinstance(output_field, bad_fields):
|
||||
raise NotImplementedError(
|
||||
'You cannot use Sum, Avg, StdDev and Variance aggregations '
|
||||
'on date/time fields in sqlite3 '
|
||||
'since date/time is saved as text.')
|
||||
except FieldError:
|
||||
# not every sub-expression has an output_field which is fine to
|
||||
# ignore
|
||||
pass
|
||||
|
||||
def date_extract_sql(self, lookup_type, field_name):
|
||||
# sqlite doesn't support extract, so we fake it with the user-defined
|
||||
|
|
|
@ -25,17 +25,6 @@ class Aggregate(Func):
|
|||
c._patch_aggregate(query) # backward-compatibility support
|
||||
return c
|
||||
|
||||
def refs_field(self, aggregate_types, field_types):
|
||||
try:
|
||||
return (isinstance(self, aggregate_types) and
|
||||
isinstance(self.input_field._output_field_or_none, field_types))
|
||||
except FieldError:
|
||||
# Sometimes we don't know the input_field's output type (for example,
|
||||
# doing Sum(F('datetimefield') + F('datefield'), output_type=DateTimeField())
|
||||
# is OK, but the Expression(F('datetimefield') + F('datefield')) doesn't
|
||||
# have any output field.
|
||||
return False
|
||||
|
||||
@property
|
||||
def input_field(self):
|
||||
return self.source_expressions[0]
|
||||
|
|
|
@ -297,14 +297,6 @@ class BaseExpression(object):
|
|||
return agg, lookup
|
||||
return False, ()
|
||||
|
||||
def refs_field(self, aggregate_types, field_types):
|
||||
"""
|
||||
Helper method for check_aggregate_support on backends
|
||||
"""
|
||||
return any(
|
||||
node.refs_field(aggregate_types, field_types)
|
||||
for node in self.get_source_expressions())
|
||||
|
||||
def prepare_database_save(self, field):
|
||||
return self
|
||||
|
||||
|
@ -401,6 +393,7 @@ class DurationExpression(Expression):
|
|||
return compiler.compile(side)
|
||||
|
||||
def as_sql(self, compiler, connection):
|
||||
connection.ops.check_expression_support(self)
|
||||
expressions = []
|
||||
expression_params = []
|
||||
sql, params = self.compile(self.lhs, compiler, connection)
|
||||
|
@ -473,6 +466,7 @@ class Func(ExpressionNode):
|
|||
return c
|
||||
|
||||
def as_sql(self, compiler, connection, function=None, template=None):
|
||||
connection.ops.check_expression_support(self)
|
||||
sql_parts = []
|
||||
params = []
|
||||
for arg in self.source_expressions:
|
||||
|
@ -511,6 +505,7 @@ class Value(ExpressionNode):
|
|||
self.value = value
|
||||
|
||||
def as_sql(self, compiler, connection):
|
||||
connection.ops.check_expression_support(self)
|
||||
val = self.value
|
||||
# check _output_field to avoid triggering an exception
|
||||
if self._output_field is not None:
|
||||
|
@ -536,6 +531,7 @@ class Value(ExpressionNode):
|
|||
|
||||
class DurationValue(Value):
|
||||
def as_sql(self, compiler, connection):
|
||||
connection.ops.check_expression_support(self)
|
||||
if (connection.features.has_native_duration_field and
|
||||
connection.features.driver_supports_timedelta_args):
|
||||
return super(DurationValue, self).as_sql(compiler, connection)
|
||||
|
@ -650,6 +646,7 @@ class When(ExpressionNode):
|
|||
return c
|
||||
|
||||
def as_sql(self, compiler, connection, template=None):
|
||||
connection.ops.check_expression_support(self)
|
||||
template_params = {}
|
||||
sql_params = []
|
||||
condition_sql, condition_params = compiler.compile(self.condition)
|
||||
|
@ -715,6 +712,7 @@ class Case(ExpressionNode):
|
|||
return c
|
||||
|
||||
def as_sql(self, compiler, connection, template=None, extra=None):
|
||||
connection.ops.check_expression_support(self)
|
||||
if not self.cases:
|
||||
return compiler.compile(self.default)
|
||||
template_params = dict(extra) if extra else {}
|
||||
|
@ -851,6 +849,7 @@ class OrderBy(BaseExpression):
|
|||
return [self.expression]
|
||||
|
||||
def as_sql(self, compiler, connection):
|
||||
connection.ops.check_expression_support(self)
|
||||
expression_sql, params = compiler.compile(self.expression)
|
||||
placeholders = {'expression': expression_sql}
|
||||
placeholders['ordering'] = 'DESC' if self.descending else 'ASC'
|
||||
|
|
|
@ -230,11 +230,6 @@ class Query(object):
|
|||
raise ValueError("Need either using or connection")
|
||||
if using:
|
||||
connection = connections[using]
|
||||
|
||||
# Check that the compiler will be able to execute the query
|
||||
for alias, annotation in self.annotation_select.items():
|
||||
connection.ops.check_aggregate_support(annotation)
|
||||
|
||||
return connection.ops.compiler(self.compiler)(self, connection, using)
|
||||
|
||||
def get_meta(self):
|
||||
|
|
|
@ -153,17 +153,6 @@ class WhereNode(tree.Node):
|
|||
def contains_aggregate(self):
|
||||
return self._contains_aggregate(self)
|
||||
|
||||
@classmethod
|
||||
def _refs_field(cls, obj, aggregate_types, field_types):
|
||||
if not isinstance(obj, tree.Node):
|
||||
if hasattr(obj.rhs, 'refs_field'):
|
||||
return obj.rhs.refs_field(aggregate_types, field_types)
|
||||
return False
|
||||
return any(cls._refs_field(c, aggregate_types, field_types) for c in obj.children)
|
||||
|
||||
def refs_field(self, aggregate_types, field_types):
|
||||
return self._refs_field(self, aggregate_types, field_types)
|
||||
|
||||
|
||||
class EmptyWhere(WhereNode):
|
||||
def add(self, data, connector):
|
||||
|
|
|
@ -127,12 +127,19 @@ class SQLiteTests(TestCase):
|
|||
#19360: Raise NotImplementedError when aggregating on date/time fields.
|
||||
"""
|
||||
for aggregate in (Sum, Avg, Variance, StdDev):
|
||||
self.assertRaises(NotImplementedError,
|
||||
self.assertRaises(
|
||||
NotImplementedError,
|
||||
models.Item.objects.all().aggregate, aggregate('time'))
|
||||
self.assertRaises(NotImplementedError,
|
||||
self.assertRaises(
|
||||
NotImplementedError,
|
||||
models.Item.objects.all().aggregate, aggregate('date'))
|
||||
self.assertRaises(NotImplementedError,
|
||||
self.assertRaises(
|
||||
NotImplementedError,
|
||||
models.Item.objects.all().aggregate, aggregate('last_modified'))
|
||||
self.assertRaises(
|
||||
NotImplementedError,
|
||||
models.Item.objects.all().aggregate,
|
||||
**{'complex': aggregate('last_modified') + aggregate('last_modified')})
|
||||
|
||||
|
||||
@unittest.skipUnless(connection.vendor == 'postgresql', "Test only for PostgreSQL")
|
||||
|
|
Loading…
Reference in New Issue