Fixed #28432 -- Allowed geometry expressions to be used with distance lookups.

Distance lookups use the Distance function for decreased code redundancy.
This commit is contained in:
Sergey Fedoseev 2017-07-24 23:19:28 +05:00 committed by Tim Graham
parent c7d58c6f43
commit 38af496b98
6 changed files with 74 additions and 69 deletions

View File

@ -1,3 +1,6 @@
from django.contrib.gis.db.models.functions import Distance
class BaseSpatialOperations: class BaseSpatialOperations:
# Quick booleans for the type of this spatial backend, and # Quick booleans for the type of this spatial backend, and
# an attribute for the spatial database version tuple (if applicable) # an attribute for the spatial database version tuple (if applicable)
@ -113,3 +116,5 @@ class BaseSpatialOperations:
def spatial_ref_sys(self): def spatial_ref_sys(self):
raise NotImplementedError('subclasses of BaseSpatialOperations must a provide spatial_ref_sys() method') raise NotImplementedError('subclasses of BaseSpatialOperations must a provide spatial_ref_sys() method')
distance_expr_for_lookup = staticmethod(Distance)

View File

@ -26,10 +26,6 @@ class SDOOperator(SpatialOperator):
sql_template = "%(func)s(%(lhs)s, %(rhs)s) = 'TRUE'" sql_template = "%(func)s(%(lhs)s, %(rhs)s) = 'TRUE'"
class SDODistance(SpatialOperator):
sql_template = "SDO_GEOM.SDO_DISTANCE(%%(lhs)s, %%(rhs)s, %s) %%(op)s %%(value)s" % DEFAULT_TOLERANCE
class SDODWithin(SpatialOperator): class SDODWithin(SpatialOperator):
sql_template = "SDO_WITHIN_DISTANCE(%(lhs)s, %(rhs)s, %%s) = 'TRUE'" sql_template = "SDO_WITHIN_DISTANCE(%(lhs)s, %(rhs)s, %%s) = 'TRUE'"
@ -104,10 +100,6 @@ class OracleOperations(BaseSpatialOperations, DatabaseOperations):
'relate': SDORelate(), # Oracle uses a different syntax, e.g., 'mask=inside+touch' 'relate': SDORelate(), # Oracle uses a different syntax, e.g., 'mask=inside+touch'
'touches': SDOOperator(func='SDO_TOUCH'), 'touches': SDOOperator(func='SDO_TOUCH'),
'within': SDOOperator(func='SDO_INSIDE'), 'within': SDOOperator(func='SDO_INSIDE'),
'distance_gt': SDODistance(op='>'),
'distance_gte': SDODistance(op='>='),
'distance_lt': SDODistance(op='<'),
'distance_lte': SDODistance(op='<='),
'dwithin': SDODWithin(), 'dwithin': SDODWithin(),
} }

View File

@ -5,10 +5,12 @@ from django.contrib.gis.db.backends.base.operations import (
BaseSpatialOperations, BaseSpatialOperations,
) )
from django.contrib.gis.db.backends.utils import SpatialOperator from django.contrib.gis.db.backends.utils import SpatialOperator
from django.contrib.gis.db.models import GeometryField, RasterField
from django.contrib.gis.gdal import GDALRaster from django.contrib.gis.gdal import GDALRaster
from django.contrib.gis.measure import Distance from django.contrib.gis.measure import Distance
from django.core.exceptions import ImproperlyConfigured from django.core.exceptions import ImproperlyConfigured
from django.db.backends.postgresql.operations import DatabaseOperations from django.db.backends.postgresql.operations import DatabaseOperations
from django.db.models import Func, Value
from django.db.utils import ProgrammingError from django.db.utils import ProgrammingError
from django.utils.functional import cached_property from django.utils.functional import cached_property
from django.utils.version import get_version_tuple from django.utils.version import get_version_tuple
@ -77,26 +79,18 @@ class PostGISOperator(SpatialOperator):
return template_params return template_params
class PostGISDistanceOperator(PostGISOperator): class ST_Polygon(Func):
sql_template = '%(func)s(%(lhs)s, %(rhs)s) %(op)s %(value)s' function = 'ST_Polygon'
def as_sql(self, connection, lookup, template_params, sql_params): def __init__(self, expr):
if not lookup.lhs.output_field.geography and lookup.lhs.output_field.geodetic(connection): super().__init__(expr)
template_params = self.check_raster(lookup, template_params) expr = self.source_expressions[0]
sql_template = self.sql_template if isinstance(expr, Value) and not expr._output_field_or_none:
if len(lookup.rhs_params) == 2 and lookup.rhs_params[-1] == 'spheroid': self.source_expressions[0] = Value(expr.value, output_field=RasterField(srid=expr.value.srid))
template_params.update({
'op': self.op, @cached_property
'func': connection.ops.spatial_function_name('DistanceSpheroid'), def output_field(self):
}) return GeometryField(srid=self.source_expressions[0].field.srid)
sql_template = '%(func)s(%(lhs)s, %(rhs)s, %%s) %(op)s %(value)s'
# Using DistanceSpheroid requires the spheroid of the field as
# a parameter.
sql_params.insert(1, lookup.lhs.output_field.spheroid(connection))
else:
template_params.update({'op': self.op, 'func': connection.ops.spatial_function_name('DistanceSphere')})
return sql_template % template_params, sql_params
return super().as_sql(connection, lookup, template_params, sql_params)
class PostGISOperations(BaseSpatialOperations, DatabaseOperations): class PostGISOperations(BaseSpatialOperations, DatabaseOperations):
@ -134,10 +128,6 @@ class PostGISOperations(BaseSpatialOperations, DatabaseOperations):
'touches': PostGISOperator(func='ST_Touches', raster=BILATERAL), 'touches': PostGISOperator(func='ST_Touches', raster=BILATERAL),
'within': PostGISOperator(func='ST_Within', raster=BILATERAL), 'within': PostGISOperator(func='ST_Within', raster=BILATERAL),
'dwithin': PostGISOperator(func='ST_DWithin', geography=True, raster=BILATERAL), 'dwithin': PostGISOperator(func='ST_DWithin', geography=True, raster=BILATERAL),
'distance_gt': PostGISDistanceOperator(func='ST_Distance', op='>', geography=True),
'distance_gte': PostGISDistanceOperator(func='ST_Distance', op='>=', geography=True),
'distance_lt': PostGISDistanceOperator(func='ST_Distance', op='<', geography=True),
'distance_lte': PostGISDistanceOperator(func='ST_Distance', op='<=', geography=True),
} }
unsupported_functions = set() unsupported_functions = set()
@ -375,3 +365,19 @@ class PostGISOperations(BaseSpatialOperations, DatabaseOperations):
def parse_raster(self, value): def parse_raster(self, value):
"""Convert a PostGIS HEX String into a dict readable by GDALRaster.""" """Convert a PostGIS HEX String into a dict readable by GDALRaster."""
return from_pgraster(value) return from_pgraster(value)
def distance_expr_for_lookup(self, lhs, rhs, **kwargs):
return super().distance_expr_for_lookup(
self._normalize_distance_lookup_arg(lhs),
self._normalize_distance_lookup_arg(rhs),
**kwargs
)
@staticmethod
def _normalize_distance_lookup_arg(arg):
is_raster = (
arg.field.geom_type == 'RASTER'
if hasattr(arg, 'field') else
isinstance(arg, GDALRaster)
)
return ST_Polygon(arg) if is_raster else arg

View File

@ -17,20 +17,6 @@ from django.utils.functional import cached_property
from django.utils.version import get_version_tuple from django.utils.version import get_version_tuple
class SpatiaLiteDistanceOperator(SpatialOperator):
def as_sql(self, connection, lookup, template_params, sql_params):
if lookup.lhs.output_field.geodetic(connection):
# SpatiaLite returns NULL instead of zero on geodetic coordinates
sql_template = 'COALESCE(%(func)s(%(lhs)s, %(rhs)s, %%s), 0) %(op)s %(value)s'
template_params.update({
'op': self.op,
'func': connection.ops.spatial_function_name('Distance'),
})
sql_params.insert(1, len(lookup.rhs) == 3 and lookup.rhs[-1] == 'spheroid')
return sql_template % template_params, sql_params
return super().as_sql(connection, lookup, template_params, sql_params)
class SpatiaLiteOperations(BaseSpatialOperations, DatabaseOperations): class SpatiaLiteOperations(BaseSpatialOperations, DatabaseOperations):
name = 'spatialite' name = 'spatialite'
spatialite = True spatialite = True
@ -68,10 +54,6 @@ class SpatiaLiteOperations(BaseSpatialOperations, DatabaseOperations):
'exact': SpatialOperator(func='Equals'), 'exact': SpatialOperator(func='Equals'),
# Distance predicates # Distance predicates
'dwithin': SpatialOperator(func='PtDistWithin'), 'dwithin': SpatialOperator(func='PtDistWithin'),
'distance_gt': SpatiaLiteDistanceOperator(func='Distance', op='>'),
'distance_gte': SpatiaLiteDistanceOperator(func='Distance', op='>='),
'distance_lt': SpatiaLiteDistanceOperator(func='Distance', op='<'),
'distance_lte': SpatiaLiteDistanceOperator(func='Distance', op='<='),
} }
disallowed_aggregates = (aggregates.Extent3D,) disallowed_aggregates = (aggregates.Extent3D,)

View File

@ -305,22 +305,13 @@ class DistanceLookupBase(GISLookup):
if len(self.rhs_params) > 1 and self.rhs_params[1] != 'spheroid': if len(self.rhs_params) > 1 and self.rhs_params[1] != 'spheroid':
self.process_band_indices() self.process_band_indices()
def process_rhs(self, compiler, connection): def process_distance(self, compiler, connection):
params = [connection.ops.Adapter(self.rhs)]
# Getting the distance parameter in the units of the field.
dist_param = self.rhs_params[0] dist_param = self.rhs_params[0]
if hasattr(dist_param, 'resolve_expression'): return (
dist_param = dist_param.resolve_expression(compiler.query) compiler.compile(dist_param.resolve_expression(compiler.query))
sql, expr_params = compiler.compile(dist_param) if hasattr(dist_param, 'resolve_expression') else
self.template_params['value'] = sql ('%s', connection.ops.get_distance(self.lhs.output_field, self.rhs_params, self.lookup_name))
params.extend(expr_params)
else:
params += connection.ops.get_distance(
self.lhs.output_field, self.rhs_params,
self.lookup_name,
) )
rhs = connection.ops.get_geom_placeholder(self.lhs.output_field, params[0], compiler)
return (rhs, params)
@BaseSpatialField.register_lookup @BaseSpatialField.register_lookup
@ -328,22 +319,44 @@ class DWithinLookup(DistanceLookupBase):
lookup_name = 'dwithin' lookup_name = 'dwithin'
sql_template = '%(func)s(%(lhs)s, %(rhs)s, %%s)' sql_template = '%(func)s(%(lhs)s, %(rhs)s, %%s)'
def process_rhs(self, compiler, connection):
dist_sql, dist_params = self.process_distance(compiler, connection)
self.template_params['value'] = dist_sql
rhs = connection.ops.get_geom_placeholder(self.lhs.output_field, self.rhs, compiler)
return rhs, [connection.ops.Adapter(self.rhs)] + dist_params
class DistanceLookupFromFunction(DistanceLookupBase):
def as_sql(self, compiler, connection):
spheroid = (len(self.rhs_params) == 2 and self.rhs_params[-1] == 'spheroid') or None
distance_expr = connection.ops.distance_expr_for_lookup(self.lhs, self.rhs, spheroid=spheroid)
sql, params = compiler.compile(distance_expr.resolve_expression(compiler.query))
dist_sql, dist_params = self.process_distance(compiler, connection)
return (
'%(func)s %(op)s %(dist)s' % {'func': sql, 'op': self.op, 'dist': dist_sql},
params + dist_params,
)
@BaseSpatialField.register_lookup @BaseSpatialField.register_lookup
class DistanceGTLookup(DistanceLookupBase): class DistanceGTLookup(DistanceLookupFromFunction):
lookup_name = 'distance_gt' lookup_name = 'distance_gt'
op = '>'
@BaseSpatialField.register_lookup @BaseSpatialField.register_lookup
class DistanceGTELookup(DistanceLookupBase): class DistanceGTELookup(DistanceLookupFromFunction):
lookup_name = 'distance_gte' lookup_name = 'distance_gte'
op = '>='
@BaseSpatialField.register_lookup @BaseSpatialField.register_lookup
class DistanceLTLookup(DistanceLookupBase): class DistanceLTLookup(DistanceLookupFromFunction):
lookup_name = 'distance_lt' lookup_name = 'distance_lt'
op = '<'
@BaseSpatialField.register_lookup @BaseSpatialField.register_lookup
class DistanceLTELookup(DistanceLookupBase): class DistanceLTELookup(DistanceLookupFromFunction):
lookup_name = 'distance_lte' lookup_name = 'distance_lte'
op = '<='

View File

@ -1,5 +1,5 @@
from django.contrib.gis.db.models.functions import ( from django.contrib.gis.db.models.functions import (
Area, Distance, Length, Perimeter, Transform, Area, Distance, Intersection, Length, Perimeter, Transform,
) )
from django.contrib.gis.geos import GEOSGeometry, LineString, Point from django.contrib.gis.geos import GEOSGeometry, LineString, Point
from django.contrib.gis.measure import D # alias for Distance from django.contrib.gis.measure import D # alias for Distance
@ -206,6 +206,13 @@ class DistanceTest(TestCase):
).order_by('name') ).order_by('name')
self.assertEqual(self.get_names(qs), ['Canberra', 'Hobart', 'Melbourne']) self.assertEqual(self.get_names(qs), ['Canberra', 'Hobart', 'Melbourne'])
# With a complex geometry expression
self.assertFalse(SouthTexasCity.objects.filter(point__distance_gt=(Intersection('point', 'point'), 0)))
self.assertEqual(
SouthTexasCity.objects.filter(point__distance_lte=(Intersection('point', 'point'), 0)).count(),
SouthTexasCity.objects.count(),
)
''' '''
============================= =============================