Simplified a bit GeoAggregate classes

Thanks Josh Smeaton for the review. Refs #24152.
This commit is contained in:
Claude Paroz 2015-01-15 11:35:35 +01:00
parent 28db4af80a
commit a34fba5e59
8 changed files with 46 additions and 134 deletions

View File

@ -1,5 +1,7 @@
from functools import partial from functools import partial
from django.contrib.gis.db.models import aggregates
class BaseSpatialFeatures(object): class BaseSpatialFeatures(object):
gis_enabled = True gis_enabled = True
@ -61,15 +63,15 @@ class BaseSpatialFeatures(object):
# Specifies whether the Collect and Extent aggregates are supported by the database # Specifies whether the Collect and Extent aggregates are supported by the database
@property @property
def supports_collect_aggr(self): def supports_collect_aggr(self):
return 'Collect' in self.connection.ops.valid_aggregates return aggregates.Collect not in self.connection.ops.disallowed_aggregates
@property @property
def supports_extent_aggr(self): def supports_extent_aggr(self):
return 'Extent' in self.connection.ops.valid_aggregates return aggregates.Extent not in self.connection.ops.disallowed_aggregates
@property @property
def supports_make_line_aggr(self): def supports_make_line_aggr(self):
return 'MakeLine' in self.connection.ops.valid_aggregates return aggregates.MakeLine not in self.connection.ops.disallowed_aggregates
def __init__(self, *args): def __init__(self, *args):
super(BaseSpatialFeatures, self).__init__(*args) super(BaseSpatialFeatures, self).__init__(*args)

View File

@ -46,11 +46,7 @@ class BaseSpatialOperations(object):
union = False union = False
# Aggregates # Aggregates
collect = False disallowed_aggregates = ()
extent = False
extent3d = False
make_line = False
unionagg = False
# Serialization # Serialization
geohash = False geohash = False
@ -103,12 +99,13 @@ class BaseSpatialOperations(object):
raise NotImplementedError('subclasses of BaseSpatialOperations must provide a geo_db_placeholder() method') raise NotImplementedError('subclasses of BaseSpatialOperations must provide a geo_db_placeholder() method')
def check_aggregate_support(self, aggregate): def check_aggregate_support(self, aggregate):
if aggregate.contains_aggregate == 'gis': if isinstance(aggregate, self.disallowed_aggregates):
return aggregate.name in self.valid_aggregates raise NotImplementedError(
return super(BaseSpatialOperations, self).check_aggregate_support(aggregate) "%s spatial aggregation is not supported by this database backend." % aggregate.name
)
super(BaseSpatialOperations, self).check_aggregate_support(aggregate)
# Spatial SQL Construction def spatial_aggregate_name(self, agg_name):
def spatial_aggregate_sql(self, agg):
raise NotImplementedError('Aggregate support not implemented for this spatial backend.') raise NotImplementedError('Aggregate support not implemented for this spatial backend.')
# Routines for getting the OGC-compliant models. # Routines for getting the OGC-compliant models.

View File

@ -1,6 +1,7 @@
from django.contrib.gis.db.backends.base.adapter import WKTAdapter from django.contrib.gis.db.backends.base.adapter import WKTAdapter
from django.contrib.gis.db.backends.base.operations import BaseSpatialOperations from django.contrib.gis.db.backends.base.operations import BaseSpatialOperations
from django.contrib.gis.db.backends.utils import SpatialOperator from django.contrib.gis.db.backends.utils import SpatialOperator
from django.contrib.gis.db.models import aggregates
from django.db.backends.mysql.operations import DatabaseOperations from django.db.backends.mysql.operations import DatabaseOperations
@ -30,6 +31,8 @@ class MySQLOperations(BaseSpatialOperations, DatabaseOperations):
'within': SpatialOperator(func='MBRWithin'), 'within': SpatialOperator(func='MBRWithin'),
} }
disallowed_aggregates = (aggregates.Collect, aggregates.Extent, aggregates.Extent3D, aggregates.MakeLine, aggregates.Union)
def geo_db_type(self, f): def geo_db_type(self, f):
return f.geom_type return f.geom_type

View File

@ -12,6 +12,7 @@ import re
from django.contrib.gis.db.backends.base.operations import BaseSpatialOperations from django.contrib.gis.db.backends.base.operations import BaseSpatialOperations
from django.contrib.gis.db.backends.oracle.adapter import OracleSpatialAdapter from django.contrib.gis.db.backends.oracle.adapter import OracleSpatialAdapter
from django.contrib.gis.db.backends.utils import SpatialOperator from django.contrib.gis.db.backends.utils import SpatialOperator
from django.contrib.gis.db.models import aggregates
from django.contrib.gis.geometry.backend import Geometry from django.contrib.gis.geometry.backend import Geometry
from django.contrib.gis.measure import Distance from django.contrib.gis.measure import Distance
from django.db.backends.oracle.base import Database from django.db.backends.oracle.base import Database
@ -56,7 +57,7 @@ class OracleOperations(BaseSpatialOperations, DatabaseOperations):
name = 'oracle' name = 'oracle'
oracle = True oracle = True
valid_aggregates = {'Union', 'Extent'} disallowed_aggregates = (aggregates.Collect, aggregates.Extent3D, aggregates.MakeLine)
Adapter = OracleSpatialAdapter Adapter = OracleSpatialAdapter
Adaptor = Adapter # Backwards-compatibility alias. Adaptor = Adapter # Backwards-compatibility alias.
@ -223,20 +224,12 @@ class OracleOperations(BaseSpatialOperations, DatabaseOperations):
else: else:
return 'SDO_GEOMETRY(%%s, %s)' % f.srid return 'SDO_GEOMETRY(%%s, %s)' % f.srid
def spatial_aggregate_sql(self, agg): def spatial_aggregate_name(self, agg_name):
""" """
Returns the spatial aggregate SQL template and function for the Returns the spatial aggregate SQL name.
given Aggregate instance.
""" """
agg_name = agg.__class__.__name__.lower() agg_name = 'unionagg' if agg_name.lower() == 'union' else agg_name.lower()
if agg_name == 'union': return getattr(self, agg_name)
agg_name += 'agg'
if agg.is_extent:
sql_template = '%(function)s(%(expressions)s)'
else:
sql_template = '%(function)s(SDOAGGRTYPE(%(expressions)s,%(tolerance)s))'
sql_function = getattr(self, agg_name)
return sql_template, sql_function
# Routines for getting the OGC-compliant models. # Routines for getting the OGC-compliant models.
def geometry_columns(self): def geometry_columns(self):

View File

@ -49,7 +49,6 @@ class PostGISOperations(BaseSpatialOperations, DatabaseOperations):
geography = True geography = True
geom_func_prefix = 'ST_' geom_func_prefix = 'ST_'
version_regex = re.compile(r'^(?P<major>\d)\.(?P<minor1>\d)\.(?P<minor2>\d+)') version_regex = re.compile(r'^(?P<major>\d)\.(?P<minor1>\d)\.(?P<minor2>\d+)')
valid_aggregates = {'Collect', 'Extent', 'Extent3D', 'MakeLine', 'Union'}
Adapter = PostGISAdapter Adapter = PostGISAdapter
Adaptor = Adapter # Backwards-compatibility alias. Adaptor = Adapter # Backwards-compatibility alias.
@ -360,20 +359,11 @@ class PostGISOperations(BaseSpatialOperations, DatabaseOperations):
else: else:
raise Exception('Could not determine PROJ.4 version from PostGIS.') raise Exception('Could not determine PROJ.4 version from PostGIS.')
def spatial_aggregate_sql(self, agg): def spatial_aggregate_name(self, agg_name):
""" if agg_name == 'Extent3D':
Returns the spatial aggregate SQL template and function for the return self.extent3d
given Aggregate instance. else:
""" return self.geom_func_prefix + agg_name
agg_name = agg.__class__.__name__
if not self.check_aggregate_support(agg):
raise NotImplementedError('%s spatial aggregate is not implemented for this backend.' % agg_name)
agg_name = agg_name.lower()
if agg_name == 'union':
agg_name += 'agg'
sql_template = '%(function)s(%(expressions)s)'
sql_function = getattr(self, agg_name)
return sql_template, sql_function
# Routines for getting the OGC-compliant models. # Routines for getting the OGC-compliant models.
def geometry_columns(self): def geometry_columns(self):

View File

@ -4,6 +4,7 @@ import sys
from django.contrib.gis.db.backends.base.operations import BaseSpatialOperations from django.contrib.gis.db.backends.base.operations import BaseSpatialOperations
from django.contrib.gis.db.backends.utils import SpatialOperator from django.contrib.gis.db.backends.utils import SpatialOperator
from django.contrib.gis.db.backends.spatialite.adapter import SpatiaLiteAdapter from django.contrib.gis.db.backends.spatialite.adapter import SpatiaLiteAdapter
from django.contrib.gis.db.models import aggregates
from django.contrib.gis.geometry.backend import Geometry from django.contrib.gis.geometry.backend import Geometry
from django.contrib.gis.measure import Distance from django.contrib.gis.measure import Distance
from django.core.exceptions import ImproperlyConfigured from django.core.exceptions import ImproperlyConfigured
@ -18,13 +19,6 @@ class SpatiaLiteOperations(BaseSpatialOperations, DatabaseOperations):
spatialite = True spatialite = True
version_regex = re.compile(r'^(?P<major>\d)\.(?P<minor1>\d)\.(?P<minor2>\d+)') version_regex = re.compile(r'^(?P<major>\d)\.(?P<minor1>\d)\.(?P<minor2>\d+)')
@property
def valid_aggregates(self):
if self.spatial_version >= (3, 0, 0):
return {'Collect', 'Extent', 'Union'}
else:
return {'Union'}
Adapter = SpatiaLiteAdapter Adapter = SpatiaLiteAdapter
Adaptor = Adapter # Backwards-compatibility alias. Adaptor = Adapter # Backwards-compatibility alias.
@ -109,6 +103,13 @@ class SpatiaLiteOperations(BaseSpatialOperations, DatabaseOperations):
return False return False
return True return True
@cached_property
def disallowed_aggregates(self):
disallowed = (aggregates.Extent3D, aggregates.MakeLine)
if self.spatial_version < (3, 0, 0):
disallowed += (aggregates.Collect, aggregates.Extent)
return disallowed
@cached_property @cached_property
def gml(self): def gml(self):
return 'AsGML' if self._version_greater_2_4_0_rc4 else None return 'AsGML' if self._version_greater_2_4_0_rc4 else None
@ -237,20 +238,13 @@ class SpatiaLiteOperations(BaseSpatialOperations, DatabaseOperations):
return (version, major, minor1, minor2) return (version, major, minor1, minor2)
def spatial_aggregate_sql(self, agg): def spatial_aggregate_name(self, agg_name):
""" """
Returns the spatial aggregate SQL template and function for the Returns the spatial aggregate SQL template and function for the
given Aggregate instance. given Aggregate instance.
""" """
agg_name = agg.__class__.__name__ agg_name = 'unionagg' if agg_name.lower() == 'union' else agg_name.lower()
if not self.check_aggregate_support(agg): return getattr(self, agg_name)
raise NotImplementedError('%s spatial aggregate is not implemented for this backend.' % agg_name)
agg_name = agg_name.lower()
if agg_name == 'union':
agg_name += 'agg'
sql_template = '%(function)s(%(expressions)s)'
sql_function = getattr(self, agg_name)
return sql_template, sql_function
# Routines for getting the OGC-compliant models. # Routines for getting the OGC-compliant models.
def geometry_columns(self): def geometry_columns(self):

View File

@ -5,24 +5,21 @@ __all__ = ['Collect', 'Extent', 'Extent3D', 'MakeLine', 'Union']
class GeoAggregate(Aggregate): class GeoAggregate(Aggregate):
template = None
function = None function = None
contains_aggregate = 'gis'
is_extent = False is_extent = False
def as_sql(self, compiler, connection): def as_sql(self, compiler, connection):
if connection.ops.oracle: self.function = connection.ops.spatial_aggregate_name(self.name)
if not hasattr(self, 'tolerance'):
self.tolerance = 0.05
self.extra['tolerance'] = self.tolerance
template, function = connection.ops.spatial_aggregate_sql(self)
if template is None:
template = '%(function)s(%(expressions)s)'
self.extra['template'] = self.extra.get('template', template)
self.extra['function'] = self.extra.get('function', function)
return super(GeoAggregate, self).as_sql(compiler, connection) return super(GeoAggregate, self).as_sql(compiler, connection)
def as_oracle(self, compiler, connection):
if not hasattr(self, 'tolerance'):
self.tolerance = 0.05
self.extra['tolerance'] = self.tolerance
if not self.is_extent:
self.template = '%(function)s(SDOAGGRTYPE(%(expressions)s,%(tolerance)s))'
return self.as_sql(compiler, connection)
def prepare(self, query=None, allow_joins=True, reuse=None, summarize=False): def prepare(self, query=None, allow_joins=True, reuse=None, summarize=False):
c = super(GeoAggregate, self).prepare(query, allow_joins, reuse, summarize) c = super(GeoAggregate, self).prepare(query, allow_joins, reuse, summarize)
if not isinstance(self.expressions[0].output_field, GeometryField): if not isinstance(self.expressions[0].output_field, GeometryField):

View File

@ -1,6 +1,5 @@
from django.db.models.sql import aggregates from django.db.models.sql import aggregates
from django.db.models.sql.aggregates import * # NOQA from django.db.models.sql.aggregates import * # NOQA
from django.contrib.gis.db.models.fields import GeometryField
__all__ = ['Collect', 'Extent', 'Extent3D', 'MakeLine', 'Union'] + aggregates.__all__ __all__ = ['Collect', 'Extent', 'Extent3D', 'MakeLine', 'Union'] + aggregates.__all__
@ -10,66 +9,3 @@ warnings.warn(
"django.contrib.gis.db.models.sql.aggregates is deprecated. Use " "django.contrib.gis.db.models.sql.aggregates is deprecated. Use "
"django.contrib.gis.db.models.aggregates instead.", "django.contrib.gis.db.models.aggregates instead.",
RemovedInDjango20Warning, stacklevel=2) RemovedInDjango20Warning, stacklevel=2)
class GeoAggregate(Aggregate):
# Default SQL template for spatial aggregates.
sql_template = '%(function)s(%(expressions)s)'
# Flags for indicating the type of the aggregate.
is_extent = False
def __init__(self, col, source=None, is_summary=False, tolerance=0.05, **extra):
super(GeoAggregate, self).__init__(col, source, is_summary, **extra)
# Required by some Oracle aggregates.
self.tolerance = tolerance
# Can't use geographic aggregates on non-geometry fields.
if not isinstance(self.source, GeometryField):
raise ValueError('Geospatial aggregates only allowed on geometry fields.')
def as_sql(self, compiler, connection):
"Return the aggregate, rendered as SQL with parameters."
if connection.ops.oracle:
self.extra['tolerance'] = self.tolerance
params = []
if hasattr(self.col, 'as_sql'):
field_name, params = self.col.as_sql(compiler, connection)
elif isinstance(self.col, (list, tuple)):
field_name = '.'.join(compiler.quote_name_unless_alias(c) for c in self.col)
else:
field_name = self.col
sql_template, sql_function = connection.ops.spatial_aggregate_sql(self)
substitutions = {
'function': sql_function,
'expressions': field_name
}
substitutions.update(self.extra)
return sql_template % substitutions, params
class Collect(GeoAggregate):
pass
class Extent(GeoAggregate):
is_extent = '2D'
class Extent3D(GeoAggregate):
is_extent = '3D'
class MakeLine(GeoAggregate):
pass
class Union(GeoAggregate):
pass