Fixed #28432 -- Allowed geometry expressions to be used with distance lookups.
Distance lookups use the Distance function for decreased code redundancy.
This commit is contained in:
parent
c7d58c6f43
commit
38af496b98
|
@ -1,3 +1,6 @@
|
|||
from django.contrib.gis.db.models.functions import Distance
|
||||
|
||||
|
||||
class BaseSpatialOperations:
|
||||
# Quick booleans for the type of this spatial backend, and
|
||||
# an attribute for the spatial database version tuple (if applicable)
|
||||
|
@ -113,3 +116,5 @@ class BaseSpatialOperations:
|
|||
|
||||
def spatial_ref_sys(self):
|
||||
raise NotImplementedError('subclasses of BaseSpatialOperations must a provide spatial_ref_sys() method')
|
||||
|
||||
distance_expr_for_lookup = staticmethod(Distance)
|
||||
|
|
|
@ -26,10 +26,6 @@ class SDOOperator(SpatialOperator):
|
|||
sql_template = "%(func)s(%(lhs)s, %(rhs)s) = 'TRUE'"
|
||||
|
||||
|
||||
class SDODistance(SpatialOperator):
|
||||
sql_template = "SDO_GEOM.SDO_DISTANCE(%%(lhs)s, %%(rhs)s, %s) %%(op)s %%(value)s" % DEFAULT_TOLERANCE
|
||||
|
||||
|
||||
class SDODWithin(SpatialOperator):
|
||||
sql_template = "SDO_WITHIN_DISTANCE(%(lhs)s, %(rhs)s, %%s) = 'TRUE'"
|
||||
|
||||
|
@ -104,10 +100,6 @@ class OracleOperations(BaseSpatialOperations, DatabaseOperations):
|
|||
'relate': SDORelate(), # Oracle uses a different syntax, e.g., 'mask=inside+touch'
|
||||
'touches': SDOOperator(func='SDO_TOUCH'),
|
||||
'within': SDOOperator(func='SDO_INSIDE'),
|
||||
'distance_gt': SDODistance(op='>'),
|
||||
'distance_gte': SDODistance(op='>='),
|
||||
'distance_lt': SDODistance(op='<'),
|
||||
'distance_lte': SDODistance(op='<='),
|
||||
'dwithin': SDODWithin(),
|
||||
}
|
||||
|
||||
|
|
|
@ -5,10 +5,12 @@ from django.contrib.gis.db.backends.base.operations import (
|
|||
BaseSpatialOperations,
|
||||
)
|
||||
from django.contrib.gis.db.backends.utils import SpatialOperator
|
||||
from django.contrib.gis.db.models import GeometryField, RasterField
|
||||
from django.contrib.gis.gdal import GDALRaster
|
||||
from django.contrib.gis.measure import Distance
|
||||
from django.core.exceptions import ImproperlyConfigured
|
||||
from django.db.backends.postgresql.operations import DatabaseOperations
|
||||
from django.db.models import Func, Value
|
||||
from django.db.utils import ProgrammingError
|
||||
from django.utils.functional import cached_property
|
||||
from django.utils.version import get_version_tuple
|
||||
|
@ -77,26 +79,18 @@ class PostGISOperator(SpatialOperator):
|
|||
return template_params
|
||||
|
||||
|
||||
class PostGISDistanceOperator(PostGISOperator):
|
||||
sql_template = '%(func)s(%(lhs)s, %(rhs)s) %(op)s %(value)s'
|
||||
class ST_Polygon(Func):
|
||||
function = 'ST_Polygon'
|
||||
|
||||
def as_sql(self, connection, lookup, template_params, sql_params):
|
||||
if not lookup.lhs.output_field.geography and lookup.lhs.output_field.geodetic(connection):
|
||||
template_params = self.check_raster(lookup, template_params)
|
||||
sql_template = self.sql_template
|
||||
if len(lookup.rhs_params) == 2 and lookup.rhs_params[-1] == 'spheroid':
|
||||
template_params.update({
|
||||
'op': self.op,
|
||||
'func': connection.ops.spatial_function_name('DistanceSpheroid'),
|
||||
})
|
||||
sql_template = '%(func)s(%(lhs)s, %(rhs)s, %%s) %(op)s %(value)s'
|
||||
# Using DistanceSpheroid requires the spheroid of the field as
|
||||
# a parameter.
|
||||
sql_params.insert(1, lookup.lhs.output_field.spheroid(connection))
|
||||
else:
|
||||
template_params.update({'op': self.op, 'func': connection.ops.spatial_function_name('DistanceSphere')})
|
||||
return sql_template % template_params, sql_params
|
||||
return super().as_sql(connection, lookup, template_params, sql_params)
|
||||
def __init__(self, expr):
|
||||
super().__init__(expr)
|
||||
expr = self.source_expressions[0]
|
||||
if isinstance(expr, Value) and not expr._output_field_or_none:
|
||||
self.source_expressions[0] = Value(expr.value, output_field=RasterField(srid=expr.value.srid))
|
||||
|
||||
@cached_property
|
||||
def output_field(self):
|
||||
return GeometryField(srid=self.source_expressions[0].field.srid)
|
||||
|
||||
|
||||
class PostGISOperations(BaseSpatialOperations, DatabaseOperations):
|
||||
|
@ -134,10 +128,6 @@ class PostGISOperations(BaseSpatialOperations, DatabaseOperations):
|
|||
'touches': PostGISOperator(func='ST_Touches', raster=BILATERAL),
|
||||
'within': PostGISOperator(func='ST_Within', raster=BILATERAL),
|
||||
'dwithin': PostGISOperator(func='ST_DWithin', geography=True, raster=BILATERAL),
|
||||
'distance_gt': PostGISDistanceOperator(func='ST_Distance', op='>', geography=True),
|
||||
'distance_gte': PostGISDistanceOperator(func='ST_Distance', op='>=', geography=True),
|
||||
'distance_lt': PostGISDistanceOperator(func='ST_Distance', op='<', geography=True),
|
||||
'distance_lte': PostGISDistanceOperator(func='ST_Distance', op='<=', geography=True),
|
||||
}
|
||||
|
||||
unsupported_functions = set()
|
||||
|
@ -375,3 +365,19 @@ class PostGISOperations(BaseSpatialOperations, DatabaseOperations):
|
|||
def parse_raster(self, value):
|
||||
"""Convert a PostGIS HEX String into a dict readable by GDALRaster."""
|
||||
return from_pgraster(value)
|
||||
|
||||
def distance_expr_for_lookup(self, lhs, rhs, **kwargs):
|
||||
return super().distance_expr_for_lookup(
|
||||
self._normalize_distance_lookup_arg(lhs),
|
||||
self._normalize_distance_lookup_arg(rhs),
|
||||
**kwargs
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _normalize_distance_lookup_arg(arg):
|
||||
is_raster = (
|
||||
arg.field.geom_type == 'RASTER'
|
||||
if hasattr(arg, 'field') else
|
||||
isinstance(arg, GDALRaster)
|
||||
)
|
||||
return ST_Polygon(arg) if is_raster else arg
|
||||
|
|
|
@ -17,20 +17,6 @@ from django.utils.functional import cached_property
|
|||
from django.utils.version import get_version_tuple
|
||||
|
||||
|
||||
class SpatiaLiteDistanceOperator(SpatialOperator):
|
||||
def as_sql(self, connection, lookup, template_params, sql_params):
|
||||
if lookup.lhs.output_field.geodetic(connection):
|
||||
# SpatiaLite returns NULL instead of zero on geodetic coordinates
|
||||
sql_template = 'COALESCE(%(func)s(%(lhs)s, %(rhs)s, %%s), 0) %(op)s %(value)s'
|
||||
template_params.update({
|
||||
'op': self.op,
|
||||
'func': connection.ops.spatial_function_name('Distance'),
|
||||
})
|
||||
sql_params.insert(1, len(lookup.rhs) == 3 and lookup.rhs[-1] == 'spheroid')
|
||||
return sql_template % template_params, sql_params
|
||||
return super().as_sql(connection, lookup, template_params, sql_params)
|
||||
|
||||
|
||||
class SpatiaLiteOperations(BaseSpatialOperations, DatabaseOperations):
|
||||
name = 'spatialite'
|
||||
spatialite = True
|
||||
|
@ -68,10 +54,6 @@ class SpatiaLiteOperations(BaseSpatialOperations, DatabaseOperations):
|
|||
'exact': SpatialOperator(func='Equals'),
|
||||
# Distance predicates
|
||||
'dwithin': SpatialOperator(func='PtDistWithin'),
|
||||
'distance_gt': SpatiaLiteDistanceOperator(func='Distance', op='>'),
|
||||
'distance_gte': SpatiaLiteDistanceOperator(func='Distance', op='>='),
|
||||
'distance_lt': SpatiaLiteDistanceOperator(func='Distance', op='<'),
|
||||
'distance_lte': SpatiaLiteDistanceOperator(func='Distance', op='<='),
|
||||
}
|
||||
|
||||
disallowed_aggregates = (aggregates.Extent3D,)
|
||||
|
|
|
@ -305,22 +305,13 @@ class DistanceLookupBase(GISLookup):
|
|||
if len(self.rhs_params) > 1 and self.rhs_params[1] != 'spheroid':
|
||||
self.process_band_indices()
|
||||
|
||||
def process_rhs(self, compiler, connection):
|
||||
params = [connection.ops.Adapter(self.rhs)]
|
||||
# Getting the distance parameter in the units of the field.
|
||||
def process_distance(self, compiler, connection):
|
||||
dist_param = self.rhs_params[0]
|
||||
if hasattr(dist_param, 'resolve_expression'):
|
||||
dist_param = dist_param.resolve_expression(compiler.query)
|
||||
sql, expr_params = compiler.compile(dist_param)
|
||||
self.template_params['value'] = sql
|
||||
params.extend(expr_params)
|
||||
else:
|
||||
params += connection.ops.get_distance(
|
||||
self.lhs.output_field, self.rhs_params,
|
||||
self.lookup_name,
|
||||
)
|
||||
rhs = connection.ops.get_geom_placeholder(self.lhs.output_field, params[0], compiler)
|
||||
return (rhs, params)
|
||||
return (
|
||||
compiler.compile(dist_param.resolve_expression(compiler.query))
|
||||
if hasattr(dist_param, 'resolve_expression') else
|
||||
('%s', connection.ops.get_distance(self.lhs.output_field, self.rhs_params, self.lookup_name))
|
||||
)
|
||||
|
||||
|
||||
@BaseSpatialField.register_lookup
|
||||
|
@ -328,22 +319,44 @@ class DWithinLookup(DistanceLookupBase):
|
|||
lookup_name = 'dwithin'
|
||||
sql_template = '%(func)s(%(lhs)s, %(rhs)s, %%s)'
|
||||
|
||||
def process_rhs(self, compiler, connection):
|
||||
dist_sql, dist_params = self.process_distance(compiler, connection)
|
||||
self.template_params['value'] = dist_sql
|
||||
rhs = connection.ops.get_geom_placeholder(self.lhs.output_field, self.rhs, compiler)
|
||||
return rhs, [connection.ops.Adapter(self.rhs)] + dist_params
|
||||
|
||||
|
||||
class DistanceLookupFromFunction(DistanceLookupBase):
|
||||
def as_sql(self, compiler, connection):
|
||||
spheroid = (len(self.rhs_params) == 2 and self.rhs_params[-1] == 'spheroid') or None
|
||||
distance_expr = connection.ops.distance_expr_for_lookup(self.lhs, self.rhs, spheroid=spheroid)
|
||||
sql, params = compiler.compile(distance_expr.resolve_expression(compiler.query))
|
||||
dist_sql, dist_params = self.process_distance(compiler, connection)
|
||||
return (
|
||||
'%(func)s %(op)s %(dist)s' % {'func': sql, 'op': self.op, 'dist': dist_sql},
|
||||
params + dist_params,
|
||||
)
|
||||
|
||||
|
||||
@BaseSpatialField.register_lookup
|
||||
class DistanceGTLookup(DistanceLookupBase):
|
||||
class DistanceGTLookup(DistanceLookupFromFunction):
|
||||
lookup_name = 'distance_gt'
|
||||
op = '>'
|
||||
|
||||
|
||||
@BaseSpatialField.register_lookup
|
||||
class DistanceGTELookup(DistanceLookupBase):
|
||||
class DistanceGTELookup(DistanceLookupFromFunction):
|
||||
lookup_name = 'distance_gte'
|
||||
op = '>='
|
||||
|
||||
|
||||
@BaseSpatialField.register_lookup
|
||||
class DistanceLTLookup(DistanceLookupBase):
|
||||
class DistanceLTLookup(DistanceLookupFromFunction):
|
||||
lookup_name = 'distance_lt'
|
||||
op = '<'
|
||||
|
||||
|
||||
@BaseSpatialField.register_lookup
|
||||
class DistanceLTELookup(DistanceLookupBase):
|
||||
class DistanceLTELookup(DistanceLookupFromFunction):
|
||||
lookup_name = 'distance_lte'
|
||||
op = '<='
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
from django.contrib.gis.db.models.functions import (
|
||||
Area, Distance, Length, Perimeter, Transform,
|
||||
Area, Distance, Intersection, Length, Perimeter, Transform,
|
||||
)
|
||||
from django.contrib.gis.geos import GEOSGeometry, LineString, Point
|
||||
from django.contrib.gis.measure import D # alias for Distance
|
||||
|
@ -206,6 +206,13 @@ class DistanceTest(TestCase):
|
|||
).order_by('name')
|
||||
self.assertEqual(self.get_names(qs), ['Canberra', 'Hobart', 'Melbourne'])
|
||||
|
||||
# With a complex geometry expression
|
||||
self.assertFalse(SouthTexasCity.objects.filter(point__distance_gt=(Intersection('point', 'point'), 0)))
|
||||
self.assertEqual(
|
||||
SouthTexasCity.objects.filter(point__distance_lte=(Intersection('point', 'point'), 0)).count(),
|
||||
SouthTexasCity.objects.count(),
|
||||
)
|
||||
|
||||
|
||||
'''
|
||||
=============================
|
||||
|
|
Loading…
Reference in New Issue