Fixed #28432 -- Allowed geometry expressions to be used with distance lookups.

Distance lookups use the Distance function for decreased code redundancy.
This commit is contained in:
Sergey Fedoseev 2017-07-24 23:19:28 +05:00 committed by Tim Graham
parent c7d58c6f43
commit 38af496b98
6 changed files with 74 additions and 69 deletions

View File

@ -1,3 +1,6 @@
from django.contrib.gis.db.models.functions import Distance
class BaseSpatialOperations:
# Quick booleans for the type of this spatial backend, and
# an attribute for the spatial database version tuple (if applicable)
@ -113,3 +116,5 @@ class BaseSpatialOperations:
def spatial_ref_sys(self):
raise NotImplementedError('subclasses of BaseSpatialOperations must a provide spatial_ref_sys() method')
distance_expr_for_lookup = staticmethod(Distance)

View File

@ -26,10 +26,6 @@ class SDOOperator(SpatialOperator):
sql_template = "%(func)s(%(lhs)s, %(rhs)s) = 'TRUE'"
class SDODistance(SpatialOperator):
sql_template = "SDO_GEOM.SDO_DISTANCE(%%(lhs)s, %%(rhs)s, %s) %%(op)s %%(value)s" % DEFAULT_TOLERANCE
class SDODWithin(SpatialOperator):
sql_template = "SDO_WITHIN_DISTANCE(%(lhs)s, %(rhs)s, %%s) = 'TRUE'"
@ -104,10 +100,6 @@ class OracleOperations(BaseSpatialOperations, DatabaseOperations):
'relate': SDORelate(), # Oracle uses a different syntax, e.g., 'mask=inside+touch'
'touches': SDOOperator(func='SDO_TOUCH'),
'within': SDOOperator(func='SDO_INSIDE'),
'distance_gt': SDODistance(op='>'),
'distance_gte': SDODistance(op='>='),
'distance_lt': SDODistance(op='<'),
'distance_lte': SDODistance(op='<='),
'dwithin': SDODWithin(),
}

View File

@ -5,10 +5,12 @@ from django.contrib.gis.db.backends.base.operations import (
BaseSpatialOperations,
)
from django.contrib.gis.db.backends.utils import SpatialOperator
from django.contrib.gis.db.models import GeometryField, RasterField
from django.contrib.gis.gdal import GDALRaster
from django.contrib.gis.measure import Distance
from django.core.exceptions import ImproperlyConfigured
from django.db.backends.postgresql.operations import DatabaseOperations
from django.db.models import Func, Value
from django.db.utils import ProgrammingError
from django.utils.functional import cached_property
from django.utils.version import get_version_tuple
@ -77,26 +79,18 @@ class PostGISOperator(SpatialOperator):
return template_params
class PostGISDistanceOperator(PostGISOperator):
sql_template = '%(func)s(%(lhs)s, %(rhs)s) %(op)s %(value)s'
class ST_Polygon(Func):
function = 'ST_Polygon'
def as_sql(self, connection, lookup, template_params, sql_params):
if not lookup.lhs.output_field.geography and lookup.lhs.output_field.geodetic(connection):
template_params = self.check_raster(lookup, template_params)
sql_template = self.sql_template
if len(lookup.rhs_params) == 2 and lookup.rhs_params[-1] == 'spheroid':
template_params.update({
'op': self.op,
'func': connection.ops.spatial_function_name('DistanceSpheroid'),
})
sql_template = '%(func)s(%(lhs)s, %(rhs)s, %%s) %(op)s %(value)s'
# Using DistanceSpheroid requires the spheroid of the field as
# a parameter.
sql_params.insert(1, lookup.lhs.output_field.spheroid(connection))
else:
template_params.update({'op': self.op, 'func': connection.ops.spatial_function_name('DistanceSphere')})
return sql_template % template_params, sql_params
return super().as_sql(connection, lookup, template_params, sql_params)
def __init__(self, expr):
super().__init__(expr)
expr = self.source_expressions[0]
if isinstance(expr, Value) and not expr._output_field_or_none:
self.source_expressions[0] = Value(expr.value, output_field=RasterField(srid=expr.value.srid))
@cached_property
def output_field(self):
return GeometryField(srid=self.source_expressions[0].field.srid)
class PostGISOperations(BaseSpatialOperations, DatabaseOperations):
@ -134,10 +128,6 @@ class PostGISOperations(BaseSpatialOperations, DatabaseOperations):
'touches': PostGISOperator(func='ST_Touches', raster=BILATERAL),
'within': PostGISOperator(func='ST_Within', raster=BILATERAL),
'dwithin': PostGISOperator(func='ST_DWithin', geography=True, raster=BILATERAL),
'distance_gt': PostGISDistanceOperator(func='ST_Distance', op='>', geography=True),
'distance_gte': PostGISDistanceOperator(func='ST_Distance', op='>=', geography=True),
'distance_lt': PostGISDistanceOperator(func='ST_Distance', op='<', geography=True),
'distance_lte': PostGISDistanceOperator(func='ST_Distance', op='<=', geography=True),
}
unsupported_functions = set()
@ -375,3 +365,19 @@ class PostGISOperations(BaseSpatialOperations, DatabaseOperations):
def parse_raster(self, value):
"""Convert a PostGIS HEX String into a dict readable by GDALRaster."""
return from_pgraster(value)
def distance_expr_for_lookup(self, lhs, rhs, **kwargs):
return super().distance_expr_for_lookup(
self._normalize_distance_lookup_arg(lhs),
self._normalize_distance_lookup_arg(rhs),
**kwargs
)
@staticmethod
def _normalize_distance_lookup_arg(arg):
is_raster = (
arg.field.geom_type == 'RASTER'
if hasattr(arg, 'field') else
isinstance(arg, GDALRaster)
)
return ST_Polygon(arg) if is_raster else arg

View File

@ -17,20 +17,6 @@ from django.utils.functional import cached_property
from django.utils.version import get_version_tuple
class SpatiaLiteDistanceOperator(SpatialOperator):
def as_sql(self, connection, lookup, template_params, sql_params):
if lookup.lhs.output_field.geodetic(connection):
# SpatiaLite returns NULL instead of zero on geodetic coordinates
sql_template = 'COALESCE(%(func)s(%(lhs)s, %(rhs)s, %%s), 0) %(op)s %(value)s'
template_params.update({
'op': self.op,
'func': connection.ops.spatial_function_name('Distance'),
})
sql_params.insert(1, len(lookup.rhs) == 3 and lookup.rhs[-1] == 'spheroid')
return sql_template % template_params, sql_params
return super().as_sql(connection, lookup, template_params, sql_params)
class SpatiaLiteOperations(BaseSpatialOperations, DatabaseOperations):
name = 'spatialite'
spatialite = True
@ -68,10 +54,6 @@ class SpatiaLiteOperations(BaseSpatialOperations, DatabaseOperations):
'exact': SpatialOperator(func='Equals'),
# Distance predicates
'dwithin': SpatialOperator(func='PtDistWithin'),
'distance_gt': SpatiaLiteDistanceOperator(func='Distance', op='>'),
'distance_gte': SpatiaLiteDistanceOperator(func='Distance', op='>='),
'distance_lt': SpatiaLiteDistanceOperator(func='Distance', op='<'),
'distance_lte': SpatiaLiteDistanceOperator(func='Distance', op='<='),
}
disallowed_aggregates = (aggregates.Extent3D,)

View File

@ -305,22 +305,13 @@ class DistanceLookupBase(GISLookup):
if len(self.rhs_params) > 1 and self.rhs_params[1] != 'spheroid':
self.process_band_indices()
def process_rhs(self, compiler, connection):
params = [connection.ops.Adapter(self.rhs)]
# Getting the distance parameter in the units of the field.
def process_distance(self, compiler, connection):
dist_param = self.rhs_params[0]
if hasattr(dist_param, 'resolve_expression'):
dist_param = dist_param.resolve_expression(compiler.query)
sql, expr_params = compiler.compile(dist_param)
self.template_params['value'] = sql
params.extend(expr_params)
else:
params += connection.ops.get_distance(
self.lhs.output_field, self.rhs_params,
self.lookup_name,
)
rhs = connection.ops.get_geom_placeholder(self.lhs.output_field, params[0], compiler)
return (rhs, params)
return (
compiler.compile(dist_param.resolve_expression(compiler.query))
if hasattr(dist_param, 'resolve_expression') else
('%s', connection.ops.get_distance(self.lhs.output_field, self.rhs_params, self.lookup_name))
)
@BaseSpatialField.register_lookup
@ -328,22 +319,44 @@ class DWithinLookup(DistanceLookupBase):
lookup_name = 'dwithin'
sql_template = '%(func)s(%(lhs)s, %(rhs)s, %%s)'
def process_rhs(self, compiler, connection):
dist_sql, dist_params = self.process_distance(compiler, connection)
self.template_params['value'] = dist_sql
rhs = connection.ops.get_geom_placeholder(self.lhs.output_field, self.rhs, compiler)
return rhs, [connection.ops.Adapter(self.rhs)] + dist_params
class DistanceLookupFromFunction(DistanceLookupBase):
def as_sql(self, compiler, connection):
spheroid = (len(self.rhs_params) == 2 and self.rhs_params[-1] == 'spheroid') or None
distance_expr = connection.ops.distance_expr_for_lookup(self.lhs, self.rhs, spheroid=spheroid)
sql, params = compiler.compile(distance_expr.resolve_expression(compiler.query))
dist_sql, dist_params = self.process_distance(compiler, connection)
return (
'%(func)s %(op)s %(dist)s' % {'func': sql, 'op': self.op, 'dist': dist_sql},
params + dist_params,
)
@BaseSpatialField.register_lookup
class DistanceGTLookup(DistanceLookupBase):
class DistanceGTLookup(DistanceLookupFromFunction):
lookup_name = 'distance_gt'
op = '>'
@BaseSpatialField.register_lookup
class DistanceGTELookup(DistanceLookupBase):
class DistanceGTELookup(DistanceLookupFromFunction):
lookup_name = 'distance_gte'
op = '>='
@BaseSpatialField.register_lookup
class DistanceLTLookup(DistanceLookupBase):
class DistanceLTLookup(DistanceLookupFromFunction):
lookup_name = 'distance_lt'
op = '<'
@BaseSpatialField.register_lookup
class DistanceLTELookup(DistanceLookupBase):
class DistanceLTELookup(DistanceLookupFromFunction):
lookup_name = 'distance_lte'
op = '<='

View File

@ -1,5 +1,5 @@
from django.contrib.gis.db.models.functions import (
Area, Distance, Length, Perimeter, Transform,
Area, Distance, Intersection, Length, Perimeter, Transform,
)
from django.contrib.gis.geos import GEOSGeometry, LineString, Point
from django.contrib.gis.measure import D # alias for Distance
@ -206,6 +206,13 @@ class DistanceTest(TestCase):
).order_by('name')
self.assertEqual(self.get_names(qs), ['Canberra', 'Hobart', 'Melbourne'])
# With a complex geometry expression
self.assertFalse(SouthTexasCity.objects.filter(point__distance_gt=(Intersection('point', 'point'), 0)))
self.assertEqual(
SouthTexasCity.objects.filter(point__distance_lte=(Intersection('point', 'point'), 0)).count(),
SouthTexasCity.objects.count(),
)
'''
=============================