Simplified a bit GeoAggregate classes

Thanks Josh Smeaton for the review. Refs #24152.
This commit is contained in:
Claude Paroz 2015-01-15 11:35:35 +01:00
parent 28db4af80a
commit a34fba5e59
8 changed files with 46 additions and 134 deletions

View File

@ -1,5 +1,7 @@
from functools import partial
from django.contrib.gis.db.models import aggregates
class BaseSpatialFeatures(object):
gis_enabled = True
@ -61,15 +63,15 @@ class BaseSpatialFeatures(object):
# Specifies whether the Collect and Extent aggregates are supported by the database
@property
def supports_collect_aggr(self):
return 'Collect' in self.connection.ops.valid_aggregates
return aggregates.Collect not in self.connection.ops.disallowed_aggregates
@property
def supports_extent_aggr(self):
return 'Extent' in self.connection.ops.valid_aggregates
return aggregates.Extent not in self.connection.ops.disallowed_aggregates
@property
def supports_make_line_aggr(self):
return 'MakeLine' in self.connection.ops.valid_aggregates
return aggregates.MakeLine not in self.connection.ops.disallowed_aggregates
def __init__(self, *args):
super(BaseSpatialFeatures, self).__init__(*args)

View File

@ -46,11 +46,7 @@ class BaseSpatialOperations(object):
union = False
# Aggregates
collect = False
extent = False
extent3d = False
make_line = False
unionagg = False
disallowed_aggregates = ()
# Serialization
geohash = False
@ -103,12 +99,13 @@ class BaseSpatialOperations(object):
raise NotImplementedError('subclasses of BaseSpatialOperations must provide a geo_db_placeholder() method')
def check_aggregate_support(self, aggregate):
if aggregate.contains_aggregate == 'gis':
return aggregate.name in self.valid_aggregates
return super(BaseSpatialOperations, self).check_aggregate_support(aggregate)
if isinstance(aggregate, self.disallowed_aggregates):
raise NotImplementedError(
"%s spatial aggregation is not supported by this database backend." % aggregate.name
)
super(BaseSpatialOperations, self).check_aggregate_support(aggregate)
# Spatial SQL Construction
def spatial_aggregate_sql(self, agg):
def spatial_aggregate_name(self, agg_name):
raise NotImplementedError('Aggregate support not implemented for this spatial backend.')
# Routines for getting the OGC-compliant models.

View File

@ -1,6 +1,7 @@
from django.contrib.gis.db.backends.base.adapter import WKTAdapter
from django.contrib.gis.db.backends.base.operations import BaseSpatialOperations
from django.contrib.gis.db.backends.utils import SpatialOperator
from django.contrib.gis.db.models import aggregates
from django.db.backends.mysql.operations import DatabaseOperations
@ -30,6 +31,8 @@ class MySQLOperations(BaseSpatialOperations, DatabaseOperations):
'within': SpatialOperator(func='MBRWithin'),
}
disallowed_aggregates = (aggregates.Collect, aggregates.Extent, aggregates.Extent3D, aggregates.MakeLine, aggregates.Union)
def geo_db_type(self, f):
return f.geom_type

View File

@ -12,6 +12,7 @@ import re
from django.contrib.gis.db.backends.base.operations import BaseSpatialOperations
from django.contrib.gis.db.backends.oracle.adapter import OracleSpatialAdapter
from django.contrib.gis.db.backends.utils import SpatialOperator
from django.contrib.gis.db.models import aggregates
from django.contrib.gis.geometry.backend import Geometry
from django.contrib.gis.measure import Distance
from django.db.backends.oracle.base import Database
@ -56,7 +57,7 @@ class OracleOperations(BaseSpatialOperations, DatabaseOperations):
name = 'oracle'
oracle = True
valid_aggregates = {'Union', 'Extent'}
disallowed_aggregates = (aggregates.Collect, aggregates.Extent3D, aggregates.MakeLine)
Adapter = OracleSpatialAdapter
Adaptor = Adapter # Backwards-compatibility alias.
@ -223,20 +224,12 @@ class OracleOperations(BaseSpatialOperations, DatabaseOperations):
else:
return 'SDO_GEOMETRY(%%s, %s)' % f.srid
def spatial_aggregate_sql(self, agg):
def spatial_aggregate_name(self, agg_name):
"""
Returns the spatial aggregate SQL template and function for the
given Aggregate instance.
Returns the spatial aggregate SQL name.
"""
agg_name = agg.__class__.__name__.lower()
if agg_name == 'union':
agg_name += 'agg'
if agg.is_extent:
sql_template = '%(function)s(%(expressions)s)'
else:
sql_template = '%(function)s(SDOAGGRTYPE(%(expressions)s,%(tolerance)s))'
sql_function = getattr(self, agg_name)
return sql_template, sql_function
agg_name = 'unionagg' if agg_name.lower() == 'union' else agg_name.lower()
return getattr(self, agg_name)
# Routines for getting the OGC-compliant models.
def geometry_columns(self):

View File

@ -49,7 +49,6 @@ class PostGISOperations(BaseSpatialOperations, DatabaseOperations):
geography = True
geom_func_prefix = 'ST_'
version_regex = re.compile(r'^(?P<major>\d)\.(?P<minor1>\d)\.(?P<minor2>\d+)')
valid_aggregates = {'Collect', 'Extent', 'Extent3D', 'MakeLine', 'Union'}
Adapter = PostGISAdapter
Adaptor = Adapter # Backwards-compatibility alias.
@ -360,20 +359,11 @@ class PostGISOperations(BaseSpatialOperations, DatabaseOperations):
else:
raise Exception('Could not determine PROJ.4 version from PostGIS.')
def spatial_aggregate_sql(self, agg):
"""
Returns the spatial aggregate SQL template and function for the
given Aggregate instance.
"""
agg_name = agg.__class__.__name__
if not self.check_aggregate_support(agg):
raise NotImplementedError('%s spatial aggregate is not implemented for this backend.' % agg_name)
agg_name = agg_name.lower()
if agg_name == 'union':
agg_name += 'agg'
sql_template = '%(function)s(%(expressions)s)'
sql_function = getattr(self, agg_name)
return sql_template, sql_function
def spatial_aggregate_name(self, agg_name):
if agg_name == 'Extent3D':
return self.extent3d
else:
return self.geom_func_prefix + agg_name
# Routines for getting the OGC-compliant models.
def geometry_columns(self):

View File

@ -4,6 +4,7 @@ import sys
from django.contrib.gis.db.backends.base.operations import BaseSpatialOperations
from django.contrib.gis.db.backends.utils import SpatialOperator
from django.contrib.gis.db.backends.spatialite.adapter import SpatiaLiteAdapter
from django.contrib.gis.db.models import aggregates
from django.contrib.gis.geometry.backend import Geometry
from django.contrib.gis.measure import Distance
from django.core.exceptions import ImproperlyConfigured
@ -18,13 +19,6 @@ class SpatiaLiteOperations(BaseSpatialOperations, DatabaseOperations):
spatialite = True
version_regex = re.compile(r'^(?P<major>\d)\.(?P<minor1>\d)\.(?P<minor2>\d+)')
@property
def valid_aggregates(self):
if self.spatial_version >= (3, 0, 0):
return {'Collect', 'Extent', 'Union'}
else:
return {'Union'}
Adapter = SpatiaLiteAdapter
Adaptor = Adapter # Backwards-compatibility alias.
@ -109,6 +103,13 @@ class SpatiaLiteOperations(BaseSpatialOperations, DatabaseOperations):
return False
return True
@cached_property
def disallowed_aggregates(self):
disallowed = (aggregates.Extent3D, aggregates.MakeLine)
if self.spatial_version < (3, 0, 0):
disallowed += (aggregates.Collect, aggregates.Extent)
return disallowed
@cached_property
def gml(self):
return 'AsGML' if self._version_greater_2_4_0_rc4 else None
@ -237,20 +238,13 @@ class SpatiaLiteOperations(BaseSpatialOperations, DatabaseOperations):
return (version, major, minor1, minor2)
def spatial_aggregate_sql(self, agg):
def spatial_aggregate_name(self, agg_name):
"""
Returns the spatial aggregate SQL template and function for the
given Aggregate instance.
"""
agg_name = agg.__class__.__name__
if not self.check_aggregate_support(agg):
raise NotImplementedError('%s spatial aggregate is not implemented for this backend.' % agg_name)
agg_name = agg_name.lower()
if agg_name == 'union':
agg_name += 'agg'
sql_template = '%(function)s(%(expressions)s)'
sql_function = getattr(self, agg_name)
return sql_template, sql_function
agg_name = 'unionagg' if agg_name.lower() == 'union' else agg_name.lower()
return getattr(self, agg_name)
# Routines for getting the OGC-compliant models.
def geometry_columns(self):

View File

@ -5,24 +5,21 @@ __all__ = ['Collect', 'Extent', 'Extent3D', 'MakeLine', 'Union']
class GeoAggregate(Aggregate):
template = None
function = None
contains_aggregate = 'gis'
is_extent = False
def as_sql(self, compiler, connection):
if connection.ops.oracle:
if not hasattr(self, 'tolerance'):
self.tolerance = 0.05
self.extra['tolerance'] = self.tolerance
template, function = connection.ops.spatial_aggregate_sql(self)
if template is None:
template = '%(function)s(%(expressions)s)'
self.extra['template'] = self.extra.get('template', template)
self.extra['function'] = self.extra.get('function', function)
self.function = connection.ops.spatial_aggregate_name(self.name)
return super(GeoAggregate, self).as_sql(compiler, connection)
def as_oracle(self, compiler, connection):
if not hasattr(self, 'tolerance'):
self.tolerance = 0.05
self.extra['tolerance'] = self.tolerance
if not self.is_extent:
self.template = '%(function)s(SDOAGGRTYPE(%(expressions)s,%(tolerance)s))'
return self.as_sql(compiler, connection)
def prepare(self, query=None, allow_joins=True, reuse=None, summarize=False):
c = super(GeoAggregate, self).prepare(query, allow_joins, reuse, summarize)
if not isinstance(self.expressions[0].output_field, GeometryField):

View File

@ -1,6 +1,5 @@
from django.db.models.sql import aggregates
from django.db.models.sql.aggregates import * # NOQA
from django.contrib.gis.db.models.fields import GeometryField
__all__ = ['Collect', 'Extent', 'Extent3D', 'MakeLine', 'Union'] + aggregates.__all__
@ -10,66 +9,3 @@ warnings.warn(
"django.contrib.gis.db.models.sql.aggregates is deprecated. Use "
"django.contrib.gis.db.models.aggregates instead.",
RemovedInDjango20Warning, stacklevel=2)
class GeoAggregate(Aggregate):
# Default SQL template for spatial aggregates.
sql_template = '%(function)s(%(expressions)s)'
# Flags for indicating the type of the aggregate.
is_extent = False
def __init__(self, col, source=None, is_summary=False, tolerance=0.05, **extra):
super(GeoAggregate, self).__init__(col, source, is_summary, **extra)
# Required by some Oracle aggregates.
self.tolerance = tolerance
# Can't use geographic aggregates on non-geometry fields.
if not isinstance(self.source, GeometryField):
raise ValueError('Geospatial aggregates only allowed on geometry fields.')
def as_sql(self, compiler, connection):
"Return the aggregate, rendered as SQL with parameters."
if connection.ops.oracle:
self.extra['tolerance'] = self.tolerance
params = []
if hasattr(self.col, 'as_sql'):
field_name, params = self.col.as_sql(compiler, connection)
elif isinstance(self.col, (list, tuple)):
field_name = '.'.join(compiler.quote_name_unless_alias(c) for c in self.col)
else:
field_name = self.col
sql_template, sql_function = connection.ops.spatial_aggregate_sql(self)
substitutions = {
'function': sql_function,
'expressions': field_name
}
substitutions.update(self.extra)
return sql_template % substitutions, params
class Collect(GeoAggregate):
pass
class Extent(GeoAggregate):
is_extent = '2D'
class Extent3D(GeoAggregate):
is_extent = '3D'
class MakeLine(GeoAggregate):
pass
class Union(GeoAggregate):
pass