Fixed #25499 -- Added the ability to pass an expression in distance lookups

Thanks Bibhas Debnath for the report and Tim Graham for the review.
This commit is contained in:
Claude Paroz 2015-10-06 22:05:53 +02:00
parent 4a7b58210d
commit 37d06cfc46
8 changed files with 84 additions and 33 deletions

View File

@ -195,7 +195,7 @@ class OracleOperations(BaseSpatialOperations, DatabaseOperations):
""" """
return 'MDSYS.SDO_GEOMETRY' return 'MDSYS.SDO_GEOMETRY'
def get_distance(self, f, value, lookup_type): def get_distance(self, f, value, lookup_type, **kwargs):
""" """
Returns the distance parameters given the value and the lookup type. Returns the distance parameters given the value and the lookup type.
On Oracle, geometry columns with a geodetic coordinate system behave On Oracle, geometry columns with a geodetic coordinate system behave

View File

@ -34,14 +34,17 @@ class PostGISOperator(SpatialOperator):
class PostGISDistanceOperator(PostGISOperator): class PostGISDistanceOperator(PostGISOperator):
sql_template = '%(func)s(%(lhs)s, %(rhs)s) %(op)s %%s' sql_template = '%(func)s(%(lhs)s, %(rhs)s) %(op)s %(value)s'
def as_sql(self, connection, lookup, template_params, sql_params): def as_sql(self, connection, lookup, template_params, sql_params):
if not lookup.lhs.output_field.geography and lookup.lhs.output_field.geodetic(connection): if not lookup.lhs.output_field.geography and lookup.lhs.output_field.geodetic(connection):
sql_template = self.sql_template sql_template = self.sql_template
if len(lookup.rhs) == 3 and lookup.rhs[-1] == 'spheroid': if len(lookup.rhs) == 3 and lookup.rhs[-1] == 'spheroid':
template_params.update({'op': self.op, 'func': 'ST_Distance_Spheroid'}) template_params.update({'op': self.op, 'func': 'ST_Distance_Spheroid'})
sql_template = '%(func)s(%(lhs)s, %(rhs)s, %%s) %(op)s %%s' sql_template = '%(func)s(%(lhs)s, %(rhs)s, %%s) %(op)s %(value)s'
# Using distance_spheroid requires the spheroid of the field as
# a parameter.
sql_params.insert(1, lookup.lhs.output_field._spheroid)
else: else:
template_params.update({'op': self.op, 'func': 'ST_Distance_Sphere'}) template_params.update({'op': self.op, 'func': 'ST_Distance_Sphere'})
return sql_template % template_params, sql_params return sql_template % template_params, sql_params
@ -226,7 +229,7 @@ class PostGISOperations(BaseSpatialOperations, DatabaseOperations):
geom_type = f.geom_type geom_type = f.geom_type
return 'geometry(%s,%d)' % (geom_type, f.srid) return 'geometry(%s,%d)' % (geom_type, f.srid)
def get_distance(self, f, dist_val, lookup_type): def get_distance(self, f, dist_val, lookup_type, handle_spheroid=True):
""" """
Retrieve the distance parameters for the given geometry field, Retrieve the distance parameters for the given geometry field,
distance lookup value, and the distance lookup type. distance lookup value, and the distance lookup type.
@ -236,11 +239,8 @@ class PostGISOperations(BaseSpatialOperations, DatabaseOperations):
projected geometry columns. In addition, it has to take into account projected geometry columns. In addition, it has to take into account
the geography column type. the geography column type.
""" """
# Getting the distance parameter and any options. # Getting the distance parameter
if len(dist_val) == 1: value = dist_val[0]
value, option = dist_val[0], None
else:
value, option = dist_val
# Shorthand boolean flags. # Shorthand boolean flags.
geodetic = f.geodetic(self.connection) geodetic = f.geodetic(self.connection)
@ -260,13 +260,17 @@ class PostGISOperations(BaseSpatialOperations, DatabaseOperations):
# Assuming the distance is in the units of the field. # Assuming the distance is in the units of the field.
dist_param = value dist_param = value
if (not geography and geodetic and lookup_type != 'dwithin' params = [dist_param]
and option == 'spheroid'): # handle_spheroid *might* be dropped in Django 2.0 as PostGISDistanceOperator
# using distance_spheroid requires the spheroid of the field as # also handles it (#25524).
# a parameter. if handle_spheroid and len(dist_val) > 1:
return [f._spheroid, dist_param] option = dist_val[1]
else: if (not geography and geodetic and lookup_type != 'dwithin'
return [dist_param] and option == 'spheroid'):
# using distance_spheroid requires the spheroid of the field as
# a parameter.
params.insert(0, f._spheroid)
return params
def get_geom_placeholder(self, f, value, compiler): def get_geom_placeholder(self, f, value, compiler):
""" """

View File

@ -175,7 +175,7 @@ class SpatiaLiteOperations(BaseSpatialOperations, DatabaseOperations):
""" """
return None return None
def get_distance(self, f, value, lookup_type): def get_distance(self, f, value, lookup_type, **kwargs):
""" """
Returns the distance parameters for the given geometry field, Returns the distance parameters for the given geometry field,
lookup value, and lookup type. SpatiaLite only supports regular lookup value, and lookup type. SpatiaLite only supports regular

View File

@ -16,6 +16,10 @@ class GISLookup(Lookup):
transform_func = None transform_func = None
distance = False distance = False
def __init__(self, *args, **kwargs):
super(GISLookup, self).__init__(*args, **kwargs)
self.template_params = {}
@classmethod @classmethod
def _check_geo_field(cls, opts, lookup): def _check_geo_field(cls, opts, lookup):
""" """
@ -98,7 +102,8 @@ class GISLookup(Lookup):
rhs_sql, rhs_params = self.process_rhs(compiler, connection) rhs_sql, rhs_params = self.process_rhs(compiler, connection)
sql_params.extend(rhs_params) sql_params.extend(rhs_params)
template_params = {'lhs': lhs_sql, 'rhs': rhs_sql} template_params = {'lhs': lhs_sql, 'rhs': rhs_sql, 'value': '%s'}
template_params.update(self.template_params)
rhs_op = self.get_rhs_op(connection, rhs_sql) rhs_op = self.get_rhs_op(connection, rhs_sql)
return rhs_op.as_sql(connection, self, template_params, sql_params) return rhs_op.as_sql(connection, self, template_params, sql_params)
@ -302,18 +307,26 @@ gis_lookups['within'] = WithinLookup
class DistanceLookupBase(GISLookup): class DistanceLookupBase(GISLookup):
distance = True distance = True
sql_template = '%(func)s(%(lhs)s, %(rhs)s) %(op)s %%s' sql_template = '%(func)s(%(lhs)s, %(rhs)s) %(op)s %(value)s'
def get_db_prep_lookup(self, value, connection): def process_rhs(self, compiler, connection):
if isinstance(value, (tuple, list)): if not isinstance(self.rhs, (tuple, list)) or not 2 <= len(self.rhs) <= 3:
if not 2 <= len(value) <= 3: raise ValueError("2 or 3-element tuple required for '%s' lookup." % self.lookup_name)
raise ValueError("2 or 3-element tuple required for '%s' lookup." % self.lookup_name) params = [connection.ops.Adapter(self.rhs[0])]
params = [connection.ops.Adapter(value[0])] # Getting the distance parameter in the units of the field.
# Getting the distance parameter in the units of the field. dist_param = self.rhs[1]
params += connection.ops.get_distance(self.lhs.output_field, value[1:], self.lookup_name) if hasattr(dist_param, 'resolve_expression'):
return ('%s', params) dist_param = dist_param.resolve_expression(compiler.query)
sql, expr_params = compiler.compile(dist_param)
self.template_params['value'] = sql
params.extend(expr_params)
else: else:
return super(DistanceLookupBase, self).get_db_prep_lookup(value, connection) params += connection.ops.get_distance(
self.lhs.output_field, (dist_param,) + self.rhs[2:],
self.lookup_name, handle_spheroid=False
)
rhs = connection.ops.get_geom_placeholder(self.lhs.output_field, params[0], compiler)
return (rhs, params)
class DWithinLookup(DistanceLookupBase): class DWithinLookup(DistanceLookupBase):

View File

@ -515,14 +515,20 @@ Distance lookups take the following form::
The value passed into a distance lookup is a tuple; the first two The value passed into a distance lookup is a tuple; the first two
values are mandatory, and are the geometry to calculate distances to, values are mandatory, and are the geometry to calculate distances to,
and a distance value (either a number in units of the field or a and a distance value (either a number in units of the field, a
:class:`~django.contrib.gis.measure.Distance` object). On every :class:`~django.contrib.gis.measure.Distance` object, or a `query expression
distance lookup but :lookup:`dwithin`, an optional <ref/models/expressions>`).
With PostGIS, on every distance lookup but :lookup:`dwithin`, an optional
third element, ``'spheroid'``, may be included to tell GeoDjango third element, ``'spheroid'``, may be included to tell GeoDjango
to use the more accurate spheroid distance calculation functions on to use the more accurate spheroid distance calculation functions on
fields with a geodetic coordinate system (e.g., ``ST_Distance_Spheroid`` fields with a geodetic coordinate system (e.g., ``ST_Distance_Spheroid``
would be used instead of ``ST_Distance_Sphere``). would be used instead of ``ST_Distance_Sphere``).
.. versionadded:: 1.10
The ability to pass an expression as the distance value was added.
.. fieldlookup:: distance_gt .. fieldlookup:: distance_gt
distance_gt distance_gt

View File

@ -59,7 +59,8 @@ Minor features
:mod:`django.contrib.gis` :mod:`django.contrib.gis`
^^^^^^^^^^^^^^^^^^^^^^^^^^ ^^^^^^^^^^^^^^^^^^^^^^^^^^
* ... * :ref:`Distance lookups <distance-lookups>` now accept expressions as the
distance value parameter.
:mod:`django.contrib.messages` :mod:`django.contrib.messages`
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

View File

@ -21,6 +21,7 @@ class NamedModel(models.Model):
class SouthTexasCity(NamedModel): class SouthTexasCity(NamedModel):
"City model on projected coordinate system for South Texas." "City model on projected coordinate system for South Texas."
point = models.PointField(srid=32140) point = models.PointField(srid=32140)
radius = models.IntegerField(default=10000)
class SouthTexasCityFt(NamedModel): class SouthTexasCityFt(NamedModel):
@ -31,6 +32,7 @@ class SouthTexasCityFt(NamedModel):
class AustraliaCity(NamedModel): class AustraliaCity(NamedModel):
"City model for Australia, using WGS84." "City model for Australia, using WGS84."
point = models.PointField() point = models.PointField()
radius = models.IntegerField(default=10000)
class CensusZipcode(NamedModel): class CensusZipcode(NamedModel):

View File

@ -6,7 +6,7 @@ from django.contrib.gis.db.models.functions import (
from django.contrib.gis.geos import GEOSGeometry, LineString, Point from django.contrib.gis.geos import GEOSGeometry, LineString, Point
from django.contrib.gis.measure import D # alias for Distance from django.contrib.gis.measure import D # alias for Distance
from django.db import connection from django.db import connection
from django.db.models import Q from django.db.models import F, Q
from django.test import TestCase, ignore_warnings, skipUnlessDBFeature from django.test import TestCase, ignore_warnings, skipUnlessDBFeature
from django.utils.deprecation import RemovedInDjango20Warning from django.utils.deprecation import RemovedInDjango20Warning
@ -323,6 +323,31 @@ class DistanceTest(TestCase):
cities = self.get_names(qs) cities = self.get_names(qs)
self.assertEqual(cities, ['Adelaide', 'Hobart', 'Shellharbour', 'Thirroul']) self.assertEqual(cities, ['Adelaide', 'Hobart', 'Shellharbour', 'Thirroul'])
@skipUnlessDBFeature("supports_distances_lookups")
def test_distance_lookups_with_expression_rhs(self):
qs = SouthTexasCity.objects.filter(
point__distance_lte=(self.stx_pnt, F('radius')),
).order_by('name')
self.assertEqual(
self.get_names(qs),
['Bellaire', 'Downtown Houston', 'Southside Place', 'West University Place']
)
# With a combined expression
qs = SouthTexasCity.objects.filter(
point__distance_lte=(self.stx_pnt, F('radius') * 2),
).order_by('name')
self.assertEqual(len(qs), 5)
self.assertIn('Pearland', self.get_names(qs))
# With spheroid param
if connection.features.supports_distance_geodetic:
hobart = AustraliaCity.objects.get(name='Hobart')
qs = AustraliaCity.objects.filter(
point__distance_lte=(hobart.point, F('radius') * 70, 'spheroid'),
).order_by('name')
self.assertEqual(self.get_names(qs), ['Canberra', 'Hobart', 'Melbourne'])
@skipUnlessDBFeature("has_area_method") @skipUnlessDBFeature("has_area_method")
@ignore_warnings(category=RemovedInDjango20Warning) @ignore_warnings(category=RemovedInDjango20Warning)
def test_area(self): def test_area(self):