diff --git a/django/contrib/gis/db/backends/base/operations.py b/django/contrib/gis/db/backends/base/operations.py index d21198dc0f..3afdd44ef4 100644 --- a/django/contrib/gis/db/backends/base/operations.py +++ b/django/contrib/gis/db/backends/base/operations.py @@ -1,3 +1,6 @@ +from django.contrib.gis.db.models.functions import Distance + + class BaseSpatialOperations: # Quick booleans for the type of this spatial backend, and # an attribute for the spatial database version tuple (if applicable) @@ -113,3 +116,5 @@ class BaseSpatialOperations: def spatial_ref_sys(self): raise NotImplementedError('subclasses of BaseSpatialOperations must a provide spatial_ref_sys() method') + + distance_expr_for_lookup = staticmethod(Distance) diff --git a/django/contrib/gis/db/backends/oracle/operations.py b/django/contrib/gis/db/backends/oracle/operations.py index d40121e850..57a78ae39a 100644 --- a/django/contrib/gis/db/backends/oracle/operations.py +++ b/django/contrib/gis/db/backends/oracle/operations.py @@ -26,10 +26,6 @@ class SDOOperator(SpatialOperator): sql_template = "%(func)s(%(lhs)s, %(rhs)s) = 'TRUE'" -class SDODistance(SpatialOperator): - sql_template = "SDO_GEOM.SDO_DISTANCE(%%(lhs)s, %%(rhs)s, %s) %%(op)s %%(value)s" % DEFAULT_TOLERANCE - - class SDODWithin(SpatialOperator): sql_template = "SDO_WITHIN_DISTANCE(%(lhs)s, %(rhs)s, %%s) = 'TRUE'" @@ -104,10 +100,6 @@ class OracleOperations(BaseSpatialOperations, DatabaseOperations): 'relate': SDORelate(), # Oracle uses a different syntax, e.g., 'mask=inside+touch' 'touches': SDOOperator(func='SDO_TOUCH'), 'within': SDOOperator(func='SDO_INSIDE'), - 'distance_gt': SDODistance(op='>'), - 'distance_gte': SDODistance(op='>='), - 'distance_lt': SDODistance(op='<'), - 'distance_lte': SDODistance(op='<='), 'dwithin': SDODWithin(), } diff --git a/django/contrib/gis/db/backends/postgis/operations.py b/django/contrib/gis/db/backends/postgis/operations.py index 7e1712bc7f..335ffc8c8d 100644 --- a/django/contrib/gis/db/backends/postgis/operations.py +++ b/django/contrib/gis/db/backends/postgis/operations.py @@ -5,10 +5,12 @@ from django.contrib.gis.db.backends.base.operations import ( BaseSpatialOperations, ) from django.contrib.gis.db.backends.utils import SpatialOperator +from django.contrib.gis.db.models import GeometryField, RasterField from django.contrib.gis.gdal import GDALRaster from django.contrib.gis.measure import Distance from django.core.exceptions import ImproperlyConfigured from django.db.backends.postgresql.operations import DatabaseOperations +from django.db.models import Func, Value from django.db.utils import ProgrammingError from django.utils.functional import cached_property from django.utils.version import get_version_tuple @@ -77,26 +79,18 @@ class PostGISOperator(SpatialOperator): return template_params -class PostGISDistanceOperator(PostGISOperator): - sql_template = '%(func)s(%(lhs)s, %(rhs)s) %(op)s %(value)s' +class ST_Polygon(Func): + function = 'ST_Polygon' - def as_sql(self, connection, lookup, template_params, sql_params): - if not lookup.lhs.output_field.geography and lookup.lhs.output_field.geodetic(connection): - template_params = self.check_raster(lookup, template_params) - sql_template = self.sql_template - if len(lookup.rhs_params) == 2 and lookup.rhs_params[-1] == 'spheroid': - template_params.update({ - 'op': self.op, - 'func': connection.ops.spatial_function_name('DistanceSpheroid'), - }) - sql_template = '%(func)s(%(lhs)s, %(rhs)s, %%s) %(op)s %(value)s' - # Using DistanceSpheroid requires the spheroid of the field as - # a parameter. - sql_params.insert(1, lookup.lhs.output_field.spheroid(connection)) - else: - template_params.update({'op': self.op, 'func': connection.ops.spatial_function_name('DistanceSphere')}) - return sql_template % template_params, sql_params - return super().as_sql(connection, lookup, template_params, sql_params) + def __init__(self, expr): + super().__init__(expr) + expr = self.source_expressions[0] + if isinstance(expr, Value) and not expr._output_field_or_none: + self.source_expressions[0] = Value(expr.value, output_field=RasterField(srid=expr.value.srid)) + + @cached_property + def output_field(self): + return GeometryField(srid=self.source_expressions[0].field.srid) class PostGISOperations(BaseSpatialOperations, DatabaseOperations): @@ -134,10 +128,6 @@ class PostGISOperations(BaseSpatialOperations, DatabaseOperations): 'touches': PostGISOperator(func='ST_Touches', raster=BILATERAL), 'within': PostGISOperator(func='ST_Within', raster=BILATERAL), 'dwithin': PostGISOperator(func='ST_DWithin', geography=True, raster=BILATERAL), - 'distance_gt': PostGISDistanceOperator(func='ST_Distance', op='>', geography=True), - 'distance_gte': PostGISDistanceOperator(func='ST_Distance', op='>=', geography=True), - 'distance_lt': PostGISDistanceOperator(func='ST_Distance', op='<', geography=True), - 'distance_lte': PostGISDistanceOperator(func='ST_Distance', op='<=', geography=True), } unsupported_functions = set() @@ -375,3 +365,19 @@ class PostGISOperations(BaseSpatialOperations, DatabaseOperations): def parse_raster(self, value): """Convert a PostGIS HEX String into a dict readable by GDALRaster.""" return from_pgraster(value) + + def distance_expr_for_lookup(self, lhs, rhs, **kwargs): + return super().distance_expr_for_lookup( + self._normalize_distance_lookup_arg(lhs), + self._normalize_distance_lookup_arg(rhs), + **kwargs + ) + + @staticmethod + def _normalize_distance_lookup_arg(arg): + is_raster = ( + arg.field.geom_type == 'RASTER' + if hasattr(arg, 'field') else + isinstance(arg, GDALRaster) + ) + return ST_Polygon(arg) if is_raster else arg diff --git a/django/contrib/gis/db/backends/spatialite/operations.py b/django/contrib/gis/db/backends/spatialite/operations.py index 5575164f91..0a84ddb71e 100644 --- a/django/contrib/gis/db/backends/spatialite/operations.py +++ b/django/contrib/gis/db/backends/spatialite/operations.py @@ -17,20 +17,6 @@ from django.utils.functional import cached_property from django.utils.version import get_version_tuple -class SpatiaLiteDistanceOperator(SpatialOperator): - def as_sql(self, connection, lookup, template_params, sql_params): - if lookup.lhs.output_field.geodetic(connection): - # SpatiaLite returns NULL instead of zero on geodetic coordinates - sql_template = 'COALESCE(%(func)s(%(lhs)s, %(rhs)s, %%s), 0) %(op)s %(value)s' - template_params.update({ - 'op': self.op, - 'func': connection.ops.spatial_function_name('Distance'), - }) - sql_params.insert(1, len(lookup.rhs) == 3 and lookup.rhs[-1] == 'spheroid') - return sql_template % template_params, sql_params - return super().as_sql(connection, lookup, template_params, sql_params) - - class SpatiaLiteOperations(BaseSpatialOperations, DatabaseOperations): name = 'spatialite' spatialite = True @@ -68,10 +54,6 @@ class SpatiaLiteOperations(BaseSpatialOperations, DatabaseOperations): 'exact': SpatialOperator(func='Equals'), # Distance predicates 'dwithin': SpatialOperator(func='PtDistWithin'), - 'distance_gt': SpatiaLiteDistanceOperator(func='Distance', op='>'), - 'distance_gte': SpatiaLiteDistanceOperator(func='Distance', op='>='), - 'distance_lt': SpatiaLiteDistanceOperator(func='Distance', op='<'), - 'distance_lte': SpatiaLiteDistanceOperator(func='Distance', op='<='), } disallowed_aggregates = (aggregates.Extent3D,) diff --git a/django/contrib/gis/db/models/lookups.py b/django/contrib/gis/db/models/lookups.py index 12caf80359..3f3bc8acc8 100644 --- a/django/contrib/gis/db/models/lookups.py +++ b/django/contrib/gis/db/models/lookups.py @@ -305,22 +305,13 @@ class DistanceLookupBase(GISLookup): if len(self.rhs_params) > 1 and self.rhs_params[1] != 'spheroid': self.process_band_indices() - def process_rhs(self, compiler, connection): - params = [connection.ops.Adapter(self.rhs)] - # Getting the distance parameter in the units of the field. + def process_distance(self, compiler, connection): dist_param = self.rhs_params[0] - if hasattr(dist_param, 'resolve_expression'): - dist_param = dist_param.resolve_expression(compiler.query) - sql, expr_params = compiler.compile(dist_param) - self.template_params['value'] = sql - params.extend(expr_params) - else: - params += connection.ops.get_distance( - self.lhs.output_field, self.rhs_params, - self.lookup_name, - ) - rhs = connection.ops.get_geom_placeholder(self.lhs.output_field, params[0], compiler) - return (rhs, params) + return ( + compiler.compile(dist_param.resolve_expression(compiler.query)) + if hasattr(dist_param, 'resolve_expression') else + ('%s', connection.ops.get_distance(self.lhs.output_field, self.rhs_params, self.lookup_name)) + ) @BaseSpatialField.register_lookup @@ -328,22 +319,44 @@ class DWithinLookup(DistanceLookupBase): lookup_name = 'dwithin' sql_template = '%(func)s(%(lhs)s, %(rhs)s, %%s)' + def process_rhs(self, compiler, connection): + dist_sql, dist_params = self.process_distance(compiler, connection) + self.template_params['value'] = dist_sql + rhs = connection.ops.get_geom_placeholder(self.lhs.output_field, self.rhs, compiler) + return rhs, [connection.ops.Adapter(self.rhs)] + dist_params + + +class DistanceLookupFromFunction(DistanceLookupBase): + def as_sql(self, compiler, connection): + spheroid = (len(self.rhs_params) == 2 and self.rhs_params[-1] == 'spheroid') or None + distance_expr = connection.ops.distance_expr_for_lookup(self.lhs, self.rhs, spheroid=spheroid) + sql, params = compiler.compile(distance_expr.resolve_expression(compiler.query)) + dist_sql, dist_params = self.process_distance(compiler, connection) + return ( + '%(func)s %(op)s %(dist)s' % {'func': sql, 'op': self.op, 'dist': dist_sql}, + params + dist_params, + ) + @BaseSpatialField.register_lookup -class DistanceGTLookup(DistanceLookupBase): +class DistanceGTLookup(DistanceLookupFromFunction): lookup_name = 'distance_gt' + op = '>' @BaseSpatialField.register_lookup -class DistanceGTELookup(DistanceLookupBase): +class DistanceGTELookup(DistanceLookupFromFunction): lookup_name = 'distance_gte' + op = '>=' @BaseSpatialField.register_lookup -class DistanceLTLookup(DistanceLookupBase): +class DistanceLTLookup(DistanceLookupFromFunction): lookup_name = 'distance_lt' + op = '<' @BaseSpatialField.register_lookup -class DistanceLTELookup(DistanceLookupBase): +class DistanceLTELookup(DistanceLookupFromFunction): lookup_name = 'distance_lte' + op = '<=' diff --git a/tests/gis_tests/distapp/tests.py b/tests/gis_tests/distapp/tests.py index 8b8fb90c8c..c4003e14f3 100644 --- a/tests/gis_tests/distapp/tests.py +++ b/tests/gis_tests/distapp/tests.py @@ -1,5 +1,5 @@ from django.contrib.gis.db.models.functions import ( - Area, Distance, Length, Perimeter, Transform, + Area, Distance, Intersection, Length, Perimeter, Transform, ) from django.contrib.gis.geos import GEOSGeometry, LineString, Point from django.contrib.gis.measure import D # alias for Distance @@ -206,6 +206,13 @@ class DistanceTest(TestCase): ).order_by('name') self.assertEqual(self.get_names(qs), ['Canberra', 'Hobart', 'Melbourne']) + # With a complex geometry expression + self.assertFalse(SouthTexasCity.objects.filter(point__distance_gt=(Intersection('point', 'point'), 0))) + self.assertEqual( + SouthTexasCity.objects.filter(point__distance_lte=(Intersection('point', 'point'), 0)).count(), + SouthTexasCity.objects.count(), + ) + ''' =============================