Fixed #25499 -- Added the ability to pass an expression in distance lookups

Thanks Bibhas Debnath for the report and Tim Graham for the review.
This commit is contained in:
Claude Paroz 2015-10-06 22:05:53 +02:00
parent 4a7b58210d
commit 37d06cfc46
8 changed files with 84 additions and 33 deletions

View File

@ -195,7 +195,7 @@ class OracleOperations(BaseSpatialOperations, DatabaseOperations):
"""
return 'MDSYS.SDO_GEOMETRY'
def get_distance(self, f, value, lookup_type):
def get_distance(self, f, value, lookup_type, **kwargs):
"""
Returns the distance parameters given the value and the lookup type.
On Oracle, geometry columns with a geodetic coordinate system behave

View File

@ -34,14 +34,17 @@ class PostGISOperator(SpatialOperator):
class PostGISDistanceOperator(PostGISOperator):
sql_template = '%(func)s(%(lhs)s, %(rhs)s) %(op)s %%s'
sql_template = '%(func)s(%(lhs)s, %(rhs)s) %(op)s %(value)s'
def as_sql(self, connection, lookup, template_params, sql_params):
if not lookup.lhs.output_field.geography and lookup.lhs.output_field.geodetic(connection):
sql_template = self.sql_template
if len(lookup.rhs) == 3 and lookup.rhs[-1] == 'spheroid':
template_params.update({'op': self.op, 'func': 'ST_Distance_Spheroid'})
sql_template = '%(func)s(%(lhs)s, %(rhs)s, %%s) %(op)s %%s'
sql_template = '%(func)s(%(lhs)s, %(rhs)s, %%s) %(op)s %(value)s'
# Using distance_spheroid requires the spheroid of the field as
# a parameter.
sql_params.insert(1, lookup.lhs.output_field._spheroid)
else:
template_params.update({'op': self.op, 'func': 'ST_Distance_Sphere'})
return sql_template % template_params, sql_params
@ -226,7 +229,7 @@ class PostGISOperations(BaseSpatialOperations, DatabaseOperations):
geom_type = f.geom_type
return 'geometry(%s,%d)' % (geom_type, f.srid)
def get_distance(self, f, dist_val, lookup_type):
def get_distance(self, f, dist_val, lookup_type, handle_spheroid=True):
"""
Retrieve the distance parameters for the given geometry field,
distance lookup value, and the distance lookup type.
@ -236,11 +239,8 @@ class PostGISOperations(BaseSpatialOperations, DatabaseOperations):
projected geometry columns. In addition, it has to take into account
the geography column type.
"""
# Getting the distance parameter and any options.
if len(dist_val) == 1:
value, option = dist_val[0], None
else:
value, option = dist_val
# Getting the distance parameter
value = dist_val[0]
# Shorthand boolean flags.
geodetic = f.geodetic(self.connection)
@ -260,13 +260,17 @@ class PostGISOperations(BaseSpatialOperations, DatabaseOperations):
# Assuming the distance is in the units of the field.
dist_param = value
params = [dist_param]
# handle_spheroid *might* be dropped in Django 2.0 as PostGISDistanceOperator
# also handles it (#25524).
if handle_spheroid and len(dist_val) > 1:
option = dist_val[1]
if (not geography and geodetic and lookup_type != 'dwithin'
and option == 'spheroid'):
# using distance_spheroid requires the spheroid of the field as
# a parameter.
return [f._spheroid, dist_param]
else:
return [dist_param]
params.insert(0, f._spheroid)
return params
def get_geom_placeholder(self, f, value, compiler):
"""

View File

@ -175,7 +175,7 @@ class SpatiaLiteOperations(BaseSpatialOperations, DatabaseOperations):
"""
return None
def get_distance(self, f, value, lookup_type):
def get_distance(self, f, value, lookup_type, **kwargs):
"""
Returns the distance parameters for the given geometry field,
lookup value, and lookup type. SpatiaLite only supports regular

View File

@ -16,6 +16,10 @@ class GISLookup(Lookup):
transform_func = None
distance = False
def __init__(self, *args, **kwargs):
super(GISLookup, self).__init__(*args, **kwargs)
self.template_params = {}
@classmethod
def _check_geo_field(cls, opts, lookup):
"""
@ -98,7 +102,8 @@ class GISLookup(Lookup):
rhs_sql, rhs_params = self.process_rhs(compiler, connection)
sql_params.extend(rhs_params)
template_params = {'lhs': lhs_sql, 'rhs': rhs_sql}
template_params = {'lhs': lhs_sql, 'rhs': rhs_sql, 'value': '%s'}
template_params.update(self.template_params)
rhs_op = self.get_rhs_op(connection, rhs_sql)
return rhs_op.as_sql(connection, self, template_params, sql_params)
@ -302,18 +307,26 @@ gis_lookups['within'] = WithinLookup
class DistanceLookupBase(GISLookup):
distance = True
sql_template = '%(func)s(%(lhs)s, %(rhs)s) %(op)s %%s'
sql_template = '%(func)s(%(lhs)s, %(rhs)s) %(op)s %(value)s'
def get_db_prep_lookup(self, value, connection):
if isinstance(value, (tuple, list)):
if not 2 <= len(value) <= 3:
def process_rhs(self, compiler, connection):
if not isinstance(self.rhs, (tuple, list)) or not 2 <= len(self.rhs) <= 3:
raise ValueError("2 or 3-element tuple required for '%s' lookup." % self.lookup_name)
params = [connection.ops.Adapter(value[0])]
params = [connection.ops.Adapter(self.rhs[0])]
# Getting the distance parameter in the units of the field.
params += connection.ops.get_distance(self.lhs.output_field, value[1:], self.lookup_name)
return ('%s', params)
dist_param = self.rhs[1]
if hasattr(dist_param, 'resolve_expression'):
dist_param = dist_param.resolve_expression(compiler.query)
sql, expr_params = compiler.compile(dist_param)
self.template_params['value'] = sql
params.extend(expr_params)
else:
return super(DistanceLookupBase, self).get_db_prep_lookup(value, connection)
params += connection.ops.get_distance(
self.lhs.output_field, (dist_param,) + self.rhs[2:],
self.lookup_name, handle_spheroid=False
)
rhs = connection.ops.get_geom_placeholder(self.lhs.output_field, params[0], compiler)
return (rhs, params)
class DWithinLookup(DistanceLookupBase):

View File

@ -515,14 +515,20 @@ Distance lookups take the following form::
The value passed into a distance lookup is a tuple; the first two
values are mandatory, and are the geometry to calculate distances to,
and a distance value (either a number in units of the field or a
:class:`~django.contrib.gis.measure.Distance` object). On every
distance lookup but :lookup:`dwithin`, an optional
and a distance value (either a number in units of the field, a
:class:`~django.contrib.gis.measure.Distance` object, or a `query expression
<ref/models/expressions>`).
With PostGIS, on every distance lookup but :lookup:`dwithin`, an optional
third element, ``'spheroid'``, may be included to tell GeoDjango
to use the more accurate spheroid distance calculation functions on
fields with a geodetic coordinate system (e.g., ``ST_Distance_Spheroid``
would be used instead of ``ST_Distance_Sphere``).
.. versionadded:: 1.10
The ability to pass an expression as the distance value was added.
.. fieldlookup:: distance_gt
distance_gt

View File

@ -59,7 +59,8 @@ Minor features
:mod:`django.contrib.gis`
^^^^^^^^^^^^^^^^^^^^^^^^^^
* ...
* :ref:`Distance lookups <distance-lookups>` now accept expressions as the
distance value parameter.
:mod:`django.contrib.messages`
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

View File

@ -21,6 +21,7 @@ class NamedModel(models.Model):
class SouthTexasCity(NamedModel):
"City model on projected coordinate system for South Texas."
point = models.PointField(srid=32140)
radius = models.IntegerField(default=10000)
class SouthTexasCityFt(NamedModel):
@ -31,6 +32,7 @@ class SouthTexasCityFt(NamedModel):
class AustraliaCity(NamedModel):
"City model for Australia, using WGS84."
point = models.PointField()
radius = models.IntegerField(default=10000)
class CensusZipcode(NamedModel):

View File

@ -6,7 +6,7 @@ from django.contrib.gis.db.models.functions import (
from django.contrib.gis.geos import GEOSGeometry, LineString, Point
from django.contrib.gis.measure import D # alias for Distance
from django.db import connection
from django.db.models import Q
from django.db.models import F, Q
from django.test import TestCase, ignore_warnings, skipUnlessDBFeature
from django.utils.deprecation import RemovedInDjango20Warning
@ -323,6 +323,31 @@ class DistanceTest(TestCase):
cities = self.get_names(qs)
self.assertEqual(cities, ['Adelaide', 'Hobart', 'Shellharbour', 'Thirroul'])
@skipUnlessDBFeature("supports_distances_lookups")
def test_distance_lookups_with_expression_rhs(self):
qs = SouthTexasCity.objects.filter(
point__distance_lte=(self.stx_pnt, F('radius')),
).order_by('name')
self.assertEqual(
self.get_names(qs),
['Bellaire', 'Downtown Houston', 'Southside Place', 'West University Place']
)
# With a combined expression
qs = SouthTexasCity.objects.filter(
point__distance_lte=(self.stx_pnt, F('radius') * 2),
).order_by('name')
self.assertEqual(len(qs), 5)
self.assertIn('Pearland', self.get_names(qs))
# With spheroid param
if connection.features.supports_distance_geodetic:
hobart = AustraliaCity.objects.get(name='Hobart')
qs = AustraliaCity.objects.filter(
point__distance_lte=(hobart.point, F('radius') * 70, 'spheroid'),
).order_by('name')
self.assertEqual(self.get_names(qs), ['Canberra', 'Hobart', 'Melbourne'])
@skipUnlessDBFeature("has_area_method")
@ignore_warnings(category=RemovedInDjango20Warning)
def test_area(self):