mirror of https://github.com/django/django.git
Fixed #25499 -- Added the ability to pass an expression in distance lookups
Thanks Bibhas Debnath for the report and Tim Graham for the review.
This commit is contained in:
parent
4a7b58210d
commit
37d06cfc46
|
@ -195,7 +195,7 @@ class OracleOperations(BaseSpatialOperations, DatabaseOperations):
|
|||
"""
|
||||
return 'MDSYS.SDO_GEOMETRY'
|
||||
|
||||
def get_distance(self, f, value, lookup_type):
|
||||
def get_distance(self, f, value, lookup_type, **kwargs):
|
||||
"""
|
||||
Returns the distance parameters given the value and the lookup type.
|
||||
On Oracle, geometry columns with a geodetic coordinate system behave
|
||||
|
|
|
@ -34,14 +34,17 @@ class PostGISOperator(SpatialOperator):
|
|||
|
||||
|
||||
class PostGISDistanceOperator(PostGISOperator):
|
||||
sql_template = '%(func)s(%(lhs)s, %(rhs)s) %(op)s %%s'
|
||||
sql_template = '%(func)s(%(lhs)s, %(rhs)s) %(op)s %(value)s'
|
||||
|
||||
def as_sql(self, connection, lookup, template_params, sql_params):
|
||||
if not lookup.lhs.output_field.geography and lookup.lhs.output_field.geodetic(connection):
|
||||
sql_template = self.sql_template
|
||||
if len(lookup.rhs) == 3 and lookup.rhs[-1] == 'spheroid':
|
||||
template_params.update({'op': self.op, 'func': 'ST_Distance_Spheroid'})
|
||||
sql_template = '%(func)s(%(lhs)s, %(rhs)s, %%s) %(op)s %%s'
|
||||
sql_template = '%(func)s(%(lhs)s, %(rhs)s, %%s) %(op)s %(value)s'
|
||||
# Using distance_spheroid requires the spheroid of the field as
|
||||
# a parameter.
|
||||
sql_params.insert(1, lookup.lhs.output_field._spheroid)
|
||||
else:
|
||||
template_params.update({'op': self.op, 'func': 'ST_Distance_Sphere'})
|
||||
return sql_template % template_params, sql_params
|
||||
|
@ -226,7 +229,7 @@ class PostGISOperations(BaseSpatialOperations, DatabaseOperations):
|
|||
geom_type = f.geom_type
|
||||
return 'geometry(%s,%d)' % (geom_type, f.srid)
|
||||
|
||||
def get_distance(self, f, dist_val, lookup_type):
|
||||
def get_distance(self, f, dist_val, lookup_type, handle_spheroid=True):
|
||||
"""
|
||||
Retrieve the distance parameters for the given geometry field,
|
||||
distance lookup value, and the distance lookup type.
|
||||
|
@ -236,11 +239,8 @@ class PostGISOperations(BaseSpatialOperations, DatabaseOperations):
|
|||
projected geometry columns. In addition, it has to take into account
|
||||
the geography column type.
|
||||
"""
|
||||
# Getting the distance parameter and any options.
|
||||
if len(dist_val) == 1:
|
||||
value, option = dist_val[0], None
|
||||
else:
|
||||
value, option = dist_val
|
||||
# Getting the distance parameter
|
||||
value = dist_val[0]
|
||||
|
||||
# Shorthand boolean flags.
|
||||
geodetic = f.geodetic(self.connection)
|
||||
|
@ -260,13 +260,17 @@ class PostGISOperations(BaseSpatialOperations, DatabaseOperations):
|
|||
# Assuming the distance is in the units of the field.
|
||||
dist_param = value
|
||||
|
||||
params = [dist_param]
|
||||
# handle_spheroid *might* be dropped in Django 2.0 as PostGISDistanceOperator
|
||||
# also handles it (#25524).
|
||||
if handle_spheroid and len(dist_val) > 1:
|
||||
option = dist_val[1]
|
||||
if (not geography and geodetic and lookup_type != 'dwithin'
|
||||
and option == 'spheroid'):
|
||||
# using distance_spheroid requires the spheroid of the field as
|
||||
# a parameter.
|
||||
return [f._spheroid, dist_param]
|
||||
else:
|
||||
return [dist_param]
|
||||
params.insert(0, f._spheroid)
|
||||
return params
|
||||
|
||||
def get_geom_placeholder(self, f, value, compiler):
|
||||
"""
|
||||
|
|
|
@ -175,7 +175,7 @@ class SpatiaLiteOperations(BaseSpatialOperations, DatabaseOperations):
|
|||
"""
|
||||
return None
|
||||
|
||||
def get_distance(self, f, value, lookup_type):
|
||||
def get_distance(self, f, value, lookup_type, **kwargs):
|
||||
"""
|
||||
Returns the distance parameters for the given geometry field,
|
||||
lookup value, and lookup type. SpatiaLite only supports regular
|
||||
|
|
|
@ -16,6 +16,10 @@ class GISLookup(Lookup):
|
|||
transform_func = None
|
||||
distance = False
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(GISLookup, self).__init__(*args, **kwargs)
|
||||
self.template_params = {}
|
||||
|
||||
@classmethod
|
||||
def _check_geo_field(cls, opts, lookup):
|
||||
"""
|
||||
|
@ -98,7 +102,8 @@ class GISLookup(Lookup):
|
|||
rhs_sql, rhs_params = self.process_rhs(compiler, connection)
|
||||
sql_params.extend(rhs_params)
|
||||
|
||||
template_params = {'lhs': lhs_sql, 'rhs': rhs_sql}
|
||||
template_params = {'lhs': lhs_sql, 'rhs': rhs_sql, 'value': '%s'}
|
||||
template_params.update(self.template_params)
|
||||
rhs_op = self.get_rhs_op(connection, rhs_sql)
|
||||
return rhs_op.as_sql(connection, self, template_params, sql_params)
|
||||
|
||||
|
@ -302,18 +307,26 @@ gis_lookups['within'] = WithinLookup
|
|||
|
||||
class DistanceLookupBase(GISLookup):
|
||||
distance = True
|
||||
sql_template = '%(func)s(%(lhs)s, %(rhs)s) %(op)s %%s'
|
||||
sql_template = '%(func)s(%(lhs)s, %(rhs)s) %(op)s %(value)s'
|
||||
|
||||
def get_db_prep_lookup(self, value, connection):
|
||||
if isinstance(value, (tuple, list)):
|
||||
if not 2 <= len(value) <= 3:
|
||||
def process_rhs(self, compiler, connection):
|
||||
if not isinstance(self.rhs, (tuple, list)) or not 2 <= len(self.rhs) <= 3:
|
||||
raise ValueError("2 or 3-element tuple required for '%s' lookup." % self.lookup_name)
|
||||
params = [connection.ops.Adapter(value[0])]
|
||||
params = [connection.ops.Adapter(self.rhs[0])]
|
||||
# Getting the distance parameter in the units of the field.
|
||||
params += connection.ops.get_distance(self.lhs.output_field, value[1:], self.lookup_name)
|
||||
return ('%s', params)
|
||||
dist_param = self.rhs[1]
|
||||
if hasattr(dist_param, 'resolve_expression'):
|
||||
dist_param = dist_param.resolve_expression(compiler.query)
|
||||
sql, expr_params = compiler.compile(dist_param)
|
||||
self.template_params['value'] = sql
|
||||
params.extend(expr_params)
|
||||
else:
|
||||
return super(DistanceLookupBase, self).get_db_prep_lookup(value, connection)
|
||||
params += connection.ops.get_distance(
|
||||
self.lhs.output_field, (dist_param,) + self.rhs[2:],
|
||||
self.lookup_name, handle_spheroid=False
|
||||
)
|
||||
rhs = connection.ops.get_geom_placeholder(self.lhs.output_field, params[0], compiler)
|
||||
return (rhs, params)
|
||||
|
||||
|
||||
class DWithinLookup(DistanceLookupBase):
|
||||
|
|
|
@ -515,14 +515,20 @@ Distance lookups take the following form::
|
|||
|
||||
The value passed into a distance lookup is a tuple; the first two
|
||||
values are mandatory, and are the geometry to calculate distances to,
|
||||
and a distance value (either a number in units of the field or a
|
||||
:class:`~django.contrib.gis.measure.Distance` object). On every
|
||||
distance lookup but :lookup:`dwithin`, an optional
|
||||
and a distance value (either a number in units of the field, a
|
||||
:class:`~django.contrib.gis.measure.Distance` object, or a `query expression
|
||||
<ref/models/expressions>`).
|
||||
|
||||
With PostGIS, on every distance lookup but :lookup:`dwithin`, an optional
|
||||
third element, ``'spheroid'``, may be included to tell GeoDjango
|
||||
to use the more accurate spheroid distance calculation functions on
|
||||
fields with a geodetic coordinate system (e.g., ``ST_Distance_Spheroid``
|
||||
would be used instead of ``ST_Distance_Sphere``).
|
||||
|
||||
.. versionadded:: 1.10
|
||||
|
||||
The ability to pass an expression as the distance value was added.
|
||||
|
||||
.. fieldlookup:: distance_gt
|
||||
|
||||
distance_gt
|
||||
|
|
|
@ -59,7 +59,8 @@ Minor features
|
|||
:mod:`django.contrib.gis`
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
* ...
|
||||
* :ref:`Distance lookups <distance-lookups>` now accept expressions as the
|
||||
distance value parameter.
|
||||
|
||||
:mod:`django.contrib.messages`
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
|
|
|
@ -21,6 +21,7 @@ class NamedModel(models.Model):
|
|||
class SouthTexasCity(NamedModel):
|
||||
"City model on projected coordinate system for South Texas."
|
||||
point = models.PointField(srid=32140)
|
||||
radius = models.IntegerField(default=10000)
|
||||
|
||||
|
||||
class SouthTexasCityFt(NamedModel):
|
||||
|
@ -31,6 +32,7 @@ class SouthTexasCityFt(NamedModel):
|
|||
class AustraliaCity(NamedModel):
|
||||
"City model for Australia, using WGS84."
|
||||
point = models.PointField()
|
||||
radius = models.IntegerField(default=10000)
|
||||
|
||||
|
||||
class CensusZipcode(NamedModel):
|
||||
|
|
|
@ -6,7 +6,7 @@ from django.contrib.gis.db.models.functions import (
|
|||
from django.contrib.gis.geos import GEOSGeometry, LineString, Point
|
||||
from django.contrib.gis.measure import D # alias for Distance
|
||||
from django.db import connection
|
||||
from django.db.models import Q
|
||||
from django.db.models import F, Q
|
||||
from django.test import TestCase, ignore_warnings, skipUnlessDBFeature
|
||||
from django.utils.deprecation import RemovedInDjango20Warning
|
||||
|
||||
|
@ -323,6 +323,31 @@ class DistanceTest(TestCase):
|
|||
cities = self.get_names(qs)
|
||||
self.assertEqual(cities, ['Adelaide', 'Hobart', 'Shellharbour', 'Thirroul'])
|
||||
|
||||
@skipUnlessDBFeature("supports_distances_lookups")
|
||||
def test_distance_lookups_with_expression_rhs(self):
|
||||
qs = SouthTexasCity.objects.filter(
|
||||
point__distance_lte=(self.stx_pnt, F('radius')),
|
||||
).order_by('name')
|
||||
self.assertEqual(
|
||||
self.get_names(qs),
|
||||
['Bellaire', 'Downtown Houston', 'Southside Place', 'West University Place']
|
||||
)
|
||||
|
||||
# With a combined expression
|
||||
qs = SouthTexasCity.objects.filter(
|
||||
point__distance_lte=(self.stx_pnt, F('radius') * 2),
|
||||
).order_by('name')
|
||||
self.assertEqual(len(qs), 5)
|
||||
self.assertIn('Pearland', self.get_names(qs))
|
||||
|
||||
# With spheroid param
|
||||
if connection.features.supports_distance_geodetic:
|
||||
hobart = AustraliaCity.objects.get(name='Hobart')
|
||||
qs = AustraliaCity.objects.filter(
|
||||
point__distance_lte=(hobart.point, F('radius') * 70, 'spheroid'),
|
||||
).order_by('name')
|
||||
self.assertEqual(self.get_names(qs), ['Canberra', 'Hobart', 'Melbourne'])
|
||||
|
||||
@skipUnlessDBFeature("has_area_method")
|
||||
@ignore_warnings(category=RemovedInDjango20Warning)
|
||||
def test_area(self):
|
||||
|
|
Loading…
Reference in New Issue