Fixed #28432 -- Allowed geometry expressions to be used with distance lookups.
Distance lookups use the Distance function for decreased code redundancy.
This commit is contained in:
parent
c7d58c6f43
commit
38af496b98
|
@ -1,3 +1,6 @@
|
||||||
|
from django.contrib.gis.db.models.functions import Distance
|
||||||
|
|
||||||
|
|
||||||
class BaseSpatialOperations:
|
class BaseSpatialOperations:
|
||||||
# Quick booleans for the type of this spatial backend, and
|
# Quick booleans for the type of this spatial backend, and
|
||||||
# an attribute for the spatial database version tuple (if applicable)
|
# an attribute for the spatial database version tuple (if applicable)
|
||||||
|
@ -113,3 +116,5 @@ class BaseSpatialOperations:
|
||||||
|
|
||||||
def spatial_ref_sys(self):
|
def spatial_ref_sys(self):
|
||||||
raise NotImplementedError('subclasses of BaseSpatialOperations must a provide spatial_ref_sys() method')
|
raise NotImplementedError('subclasses of BaseSpatialOperations must a provide spatial_ref_sys() method')
|
||||||
|
|
||||||
|
distance_expr_for_lookup = staticmethod(Distance)
|
||||||
|
|
|
@ -26,10 +26,6 @@ class SDOOperator(SpatialOperator):
|
||||||
sql_template = "%(func)s(%(lhs)s, %(rhs)s) = 'TRUE'"
|
sql_template = "%(func)s(%(lhs)s, %(rhs)s) = 'TRUE'"
|
||||||
|
|
||||||
|
|
||||||
class SDODistance(SpatialOperator):
|
|
||||||
sql_template = "SDO_GEOM.SDO_DISTANCE(%%(lhs)s, %%(rhs)s, %s) %%(op)s %%(value)s" % DEFAULT_TOLERANCE
|
|
||||||
|
|
||||||
|
|
||||||
class SDODWithin(SpatialOperator):
|
class SDODWithin(SpatialOperator):
|
||||||
sql_template = "SDO_WITHIN_DISTANCE(%(lhs)s, %(rhs)s, %%s) = 'TRUE'"
|
sql_template = "SDO_WITHIN_DISTANCE(%(lhs)s, %(rhs)s, %%s) = 'TRUE'"
|
||||||
|
|
||||||
|
@ -104,10 +100,6 @@ class OracleOperations(BaseSpatialOperations, DatabaseOperations):
|
||||||
'relate': SDORelate(), # Oracle uses a different syntax, e.g., 'mask=inside+touch'
|
'relate': SDORelate(), # Oracle uses a different syntax, e.g., 'mask=inside+touch'
|
||||||
'touches': SDOOperator(func='SDO_TOUCH'),
|
'touches': SDOOperator(func='SDO_TOUCH'),
|
||||||
'within': SDOOperator(func='SDO_INSIDE'),
|
'within': SDOOperator(func='SDO_INSIDE'),
|
||||||
'distance_gt': SDODistance(op='>'),
|
|
||||||
'distance_gte': SDODistance(op='>='),
|
|
||||||
'distance_lt': SDODistance(op='<'),
|
|
||||||
'distance_lte': SDODistance(op='<='),
|
|
||||||
'dwithin': SDODWithin(),
|
'dwithin': SDODWithin(),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -5,10 +5,12 @@ from django.contrib.gis.db.backends.base.operations import (
|
||||||
BaseSpatialOperations,
|
BaseSpatialOperations,
|
||||||
)
|
)
|
||||||
from django.contrib.gis.db.backends.utils import SpatialOperator
|
from django.contrib.gis.db.backends.utils import SpatialOperator
|
||||||
|
from django.contrib.gis.db.models import GeometryField, RasterField
|
||||||
from django.contrib.gis.gdal import GDALRaster
|
from django.contrib.gis.gdal import GDALRaster
|
||||||
from django.contrib.gis.measure import Distance
|
from django.contrib.gis.measure import Distance
|
||||||
from django.core.exceptions import ImproperlyConfigured
|
from django.core.exceptions import ImproperlyConfigured
|
||||||
from django.db.backends.postgresql.operations import DatabaseOperations
|
from django.db.backends.postgresql.operations import DatabaseOperations
|
||||||
|
from django.db.models import Func, Value
|
||||||
from django.db.utils import ProgrammingError
|
from django.db.utils import ProgrammingError
|
||||||
from django.utils.functional import cached_property
|
from django.utils.functional import cached_property
|
||||||
from django.utils.version import get_version_tuple
|
from django.utils.version import get_version_tuple
|
||||||
|
@ -77,26 +79,18 @@ class PostGISOperator(SpatialOperator):
|
||||||
return template_params
|
return template_params
|
||||||
|
|
||||||
|
|
||||||
class PostGISDistanceOperator(PostGISOperator):
|
class ST_Polygon(Func):
|
||||||
sql_template = '%(func)s(%(lhs)s, %(rhs)s) %(op)s %(value)s'
|
function = 'ST_Polygon'
|
||||||
|
|
||||||
def as_sql(self, connection, lookup, template_params, sql_params):
|
def __init__(self, expr):
|
||||||
if not lookup.lhs.output_field.geography and lookup.lhs.output_field.geodetic(connection):
|
super().__init__(expr)
|
||||||
template_params = self.check_raster(lookup, template_params)
|
expr = self.source_expressions[0]
|
||||||
sql_template = self.sql_template
|
if isinstance(expr, Value) and not expr._output_field_or_none:
|
||||||
if len(lookup.rhs_params) == 2 and lookup.rhs_params[-1] == 'spheroid':
|
self.source_expressions[0] = Value(expr.value, output_field=RasterField(srid=expr.value.srid))
|
||||||
template_params.update({
|
|
||||||
'op': self.op,
|
@cached_property
|
||||||
'func': connection.ops.spatial_function_name('DistanceSpheroid'),
|
def output_field(self):
|
||||||
})
|
return GeometryField(srid=self.source_expressions[0].field.srid)
|
||||||
sql_template = '%(func)s(%(lhs)s, %(rhs)s, %%s) %(op)s %(value)s'
|
|
||||||
# Using DistanceSpheroid requires the spheroid of the field as
|
|
||||||
# a parameter.
|
|
||||||
sql_params.insert(1, lookup.lhs.output_field.spheroid(connection))
|
|
||||||
else:
|
|
||||||
template_params.update({'op': self.op, 'func': connection.ops.spatial_function_name('DistanceSphere')})
|
|
||||||
return sql_template % template_params, sql_params
|
|
||||||
return super().as_sql(connection, lookup, template_params, sql_params)
|
|
||||||
|
|
||||||
|
|
||||||
class PostGISOperations(BaseSpatialOperations, DatabaseOperations):
|
class PostGISOperations(BaseSpatialOperations, DatabaseOperations):
|
||||||
|
@ -134,10 +128,6 @@ class PostGISOperations(BaseSpatialOperations, DatabaseOperations):
|
||||||
'touches': PostGISOperator(func='ST_Touches', raster=BILATERAL),
|
'touches': PostGISOperator(func='ST_Touches', raster=BILATERAL),
|
||||||
'within': PostGISOperator(func='ST_Within', raster=BILATERAL),
|
'within': PostGISOperator(func='ST_Within', raster=BILATERAL),
|
||||||
'dwithin': PostGISOperator(func='ST_DWithin', geography=True, raster=BILATERAL),
|
'dwithin': PostGISOperator(func='ST_DWithin', geography=True, raster=BILATERAL),
|
||||||
'distance_gt': PostGISDistanceOperator(func='ST_Distance', op='>', geography=True),
|
|
||||||
'distance_gte': PostGISDistanceOperator(func='ST_Distance', op='>=', geography=True),
|
|
||||||
'distance_lt': PostGISDistanceOperator(func='ST_Distance', op='<', geography=True),
|
|
||||||
'distance_lte': PostGISDistanceOperator(func='ST_Distance', op='<=', geography=True),
|
|
||||||
}
|
}
|
||||||
|
|
||||||
unsupported_functions = set()
|
unsupported_functions = set()
|
||||||
|
@ -375,3 +365,19 @@ class PostGISOperations(BaseSpatialOperations, DatabaseOperations):
|
||||||
def parse_raster(self, value):
|
def parse_raster(self, value):
|
||||||
"""Convert a PostGIS HEX String into a dict readable by GDALRaster."""
|
"""Convert a PostGIS HEX String into a dict readable by GDALRaster."""
|
||||||
return from_pgraster(value)
|
return from_pgraster(value)
|
||||||
|
|
||||||
|
def distance_expr_for_lookup(self, lhs, rhs, **kwargs):
|
||||||
|
return super().distance_expr_for_lookup(
|
||||||
|
self._normalize_distance_lookup_arg(lhs),
|
||||||
|
self._normalize_distance_lookup_arg(rhs),
|
||||||
|
**kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _normalize_distance_lookup_arg(arg):
|
||||||
|
is_raster = (
|
||||||
|
arg.field.geom_type == 'RASTER'
|
||||||
|
if hasattr(arg, 'field') else
|
||||||
|
isinstance(arg, GDALRaster)
|
||||||
|
)
|
||||||
|
return ST_Polygon(arg) if is_raster else arg
|
||||||
|
|
|
@ -17,20 +17,6 @@ from django.utils.functional import cached_property
|
||||||
from django.utils.version import get_version_tuple
|
from django.utils.version import get_version_tuple
|
||||||
|
|
||||||
|
|
||||||
class SpatiaLiteDistanceOperator(SpatialOperator):
|
|
||||||
def as_sql(self, connection, lookup, template_params, sql_params):
|
|
||||||
if lookup.lhs.output_field.geodetic(connection):
|
|
||||||
# SpatiaLite returns NULL instead of zero on geodetic coordinates
|
|
||||||
sql_template = 'COALESCE(%(func)s(%(lhs)s, %(rhs)s, %%s), 0) %(op)s %(value)s'
|
|
||||||
template_params.update({
|
|
||||||
'op': self.op,
|
|
||||||
'func': connection.ops.spatial_function_name('Distance'),
|
|
||||||
})
|
|
||||||
sql_params.insert(1, len(lookup.rhs) == 3 and lookup.rhs[-1] == 'spheroid')
|
|
||||||
return sql_template % template_params, sql_params
|
|
||||||
return super().as_sql(connection, lookup, template_params, sql_params)
|
|
||||||
|
|
||||||
|
|
||||||
class SpatiaLiteOperations(BaseSpatialOperations, DatabaseOperations):
|
class SpatiaLiteOperations(BaseSpatialOperations, DatabaseOperations):
|
||||||
name = 'spatialite'
|
name = 'spatialite'
|
||||||
spatialite = True
|
spatialite = True
|
||||||
|
@ -68,10 +54,6 @@ class SpatiaLiteOperations(BaseSpatialOperations, DatabaseOperations):
|
||||||
'exact': SpatialOperator(func='Equals'),
|
'exact': SpatialOperator(func='Equals'),
|
||||||
# Distance predicates
|
# Distance predicates
|
||||||
'dwithin': SpatialOperator(func='PtDistWithin'),
|
'dwithin': SpatialOperator(func='PtDistWithin'),
|
||||||
'distance_gt': SpatiaLiteDistanceOperator(func='Distance', op='>'),
|
|
||||||
'distance_gte': SpatiaLiteDistanceOperator(func='Distance', op='>='),
|
|
||||||
'distance_lt': SpatiaLiteDistanceOperator(func='Distance', op='<'),
|
|
||||||
'distance_lte': SpatiaLiteDistanceOperator(func='Distance', op='<='),
|
|
||||||
}
|
}
|
||||||
|
|
||||||
disallowed_aggregates = (aggregates.Extent3D,)
|
disallowed_aggregates = (aggregates.Extent3D,)
|
||||||
|
|
|
@ -305,22 +305,13 @@ class DistanceLookupBase(GISLookup):
|
||||||
if len(self.rhs_params) > 1 and self.rhs_params[1] != 'spheroid':
|
if len(self.rhs_params) > 1 and self.rhs_params[1] != 'spheroid':
|
||||||
self.process_band_indices()
|
self.process_band_indices()
|
||||||
|
|
||||||
def process_rhs(self, compiler, connection):
|
def process_distance(self, compiler, connection):
|
||||||
params = [connection.ops.Adapter(self.rhs)]
|
|
||||||
# Getting the distance parameter in the units of the field.
|
|
||||||
dist_param = self.rhs_params[0]
|
dist_param = self.rhs_params[0]
|
||||||
if hasattr(dist_param, 'resolve_expression'):
|
return (
|
||||||
dist_param = dist_param.resolve_expression(compiler.query)
|
compiler.compile(dist_param.resolve_expression(compiler.query))
|
||||||
sql, expr_params = compiler.compile(dist_param)
|
if hasattr(dist_param, 'resolve_expression') else
|
||||||
self.template_params['value'] = sql
|
('%s', connection.ops.get_distance(self.lhs.output_field, self.rhs_params, self.lookup_name))
|
||||||
params.extend(expr_params)
|
)
|
||||||
else:
|
|
||||||
params += connection.ops.get_distance(
|
|
||||||
self.lhs.output_field, self.rhs_params,
|
|
||||||
self.lookup_name,
|
|
||||||
)
|
|
||||||
rhs = connection.ops.get_geom_placeholder(self.lhs.output_field, params[0], compiler)
|
|
||||||
return (rhs, params)
|
|
||||||
|
|
||||||
|
|
||||||
@BaseSpatialField.register_lookup
|
@BaseSpatialField.register_lookup
|
||||||
|
@ -328,22 +319,44 @@ class DWithinLookup(DistanceLookupBase):
|
||||||
lookup_name = 'dwithin'
|
lookup_name = 'dwithin'
|
||||||
sql_template = '%(func)s(%(lhs)s, %(rhs)s, %%s)'
|
sql_template = '%(func)s(%(lhs)s, %(rhs)s, %%s)'
|
||||||
|
|
||||||
|
def process_rhs(self, compiler, connection):
|
||||||
|
dist_sql, dist_params = self.process_distance(compiler, connection)
|
||||||
|
self.template_params['value'] = dist_sql
|
||||||
|
rhs = connection.ops.get_geom_placeholder(self.lhs.output_field, self.rhs, compiler)
|
||||||
|
return rhs, [connection.ops.Adapter(self.rhs)] + dist_params
|
||||||
|
|
||||||
|
|
||||||
|
class DistanceLookupFromFunction(DistanceLookupBase):
|
||||||
|
def as_sql(self, compiler, connection):
|
||||||
|
spheroid = (len(self.rhs_params) == 2 and self.rhs_params[-1] == 'spheroid') or None
|
||||||
|
distance_expr = connection.ops.distance_expr_for_lookup(self.lhs, self.rhs, spheroid=spheroid)
|
||||||
|
sql, params = compiler.compile(distance_expr.resolve_expression(compiler.query))
|
||||||
|
dist_sql, dist_params = self.process_distance(compiler, connection)
|
||||||
|
return (
|
||||||
|
'%(func)s %(op)s %(dist)s' % {'func': sql, 'op': self.op, 'dist': dist_sql},
|
||||||
|
params + dist_params,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@BaseSpatialField.register_lookup
|
@BaseSpatialField.register_lookup
|
||||||
class DistanceGTLookup(DistanceLookupBase):
|
class DistanceGTLookup(DistanceLookupFromFunction):
|
||||||
lookup_name = 'distance_gt'
|
lookup_name = 'distance_gt'
|
||||||
|
op = '>'
|
||||||
|
|
||||||
|
|
||||||
@BaseSpatialField.register_lookup
|
@BaseSpatialField.register_lookup
|
||||||
class DistanceGTELookup(DistanceLookupBase):
|
class DistanceGTELookup(DistanceLookupFromFunction):
|
||||||
lookup_name = 'distance_gte'
|
lookup_name = 'distance_gte'
|
||||||
|
op = '>='
|
||||||
|
|
||||||
|
|
||||||
@BaseSpatialField.register_lookup
|
@BaseSpatialField.register_lookup
|
||||||
class DistanceLTLookup(DistanceLookupBase):
|
class DistanceLTLookup(DistanceLookupFromFunction):
|
||||||
lookup_name = 'distance_lt'
|
lookup_name = 'distance_lt'
|
||||||
|
op = '<'
|
||||||
|
|
||||||
|
|
||||||
@BaseSpatialField.register_lookup
|
@BaseSpatialField.register_lookup
|
||||||
class DistanceLTELookup(DistanceLookupBase):
|
class DistanceLTELookup(DistanceLookupFromFunction):
|
||||||
lookup_name = 'distance_lte'
|
lookup_name = 'distance_lte'
|
||||||
|
op = '<='
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
from django.contrib.gis.db.models.functions import (
|
from django.contrib.gis.db.models.functions import (
|
||||||
Area, Distance, Length, Perimeter, Transform,
|
Area, Distance, Intersection, Length, Perimeter, Transform,
|
||||||
)
|
)
|
||||||
from django.contrib.gis.geos import GEOSGeometry, LineString, Point
|
from django.contrib.gis.geos import GEOSGeometry, LineString, Point
|
||||||
from django.contrib.gis.measure import D # alias for Distance
|
from django.contrib.gis.measure import D # alias for Distance
|
||||||
|
@ -206,6 +206,13 @@ class DistanceTest(TestCase):
|
||||||
).order_by('name')
|
).order_by('name')
|
||||||
self.assertEqual(self.get_names(qs), ['Canberra', 'Hobart', 'Melbourne'])
|
self.assertEqual(self.get_names(qs), ['Canberra', 'Hobart', 'Melbourne'])
|
||||||
|
|
||||||
|
# With a complex geometry expression
|
||||||
|
self.assertFalse(SouthTexasCity.objects.filter(point__distance_gt=(Intersection('point', 'point'), 0)))
|
||||||
|
self.assertEqual(
|
||||||
|
SouthTexasCity.objects.filter(point__distance_lte=(Intersection('point', 'point'), 0)).count(),
|
||||||
|
SouthTexasCity.objects.count(),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
'''
|
'''
|
||||||
=============================
|
=============================
|
||||||
|
|
Loading…
Reference in New Issue