Fixed #28006 -- Allowed using D with lookups on Distance annotations.

This commit is contained in:
Sergey Fedoseev 2017-04-07 04:27:45 +05:00 committed by Tim Graham
parent dbfcedb499
commit fd892f3443
3 changed files with 50 additions and 20 deletions

View File

@ -1,7 +1,7 @@
from decimal import Decimal
from django.contrib.gis.db.models.fields import BaseSpatialField, GeometryField
from django.contrib.gis.db.models.sql import AreaField
from django.contrib.gis.db.models.sql import AreaField, DistanceField
from django.contrib.gis.geometry.backend import Geometry
from django.contrib.gis.measure import (
Area as AreaMeasure, Distance as DistanceMeasure,
@ -126,7 +126,7 @@ class OracleToleranceMixin:
def as_oracle(self, compiler, connection):
tol = self.extra.get('tolerance', self.tolerance)
return super().as_sql(compiler, connection, template="%%(function)s(%%(expressions)s, %s)" % tol)
return self.as_sql(compiler, connection, template="%%(function)s(%%(expressions)s, %s)" % tol)
class Area(OracleToleranceMixin, GeoFunc):
@ -247,27 +247,31 @@ class Difference(OracleToleranceMixin, GeomOutputGeoFunc):
class DistanceResultMixin:
output_field_class = DistanceField
def source_is_geography(self):
return self.get_source_fields()[0].geography and self.srid == 4326
def convert_value(self, value, expression, connection, context):
if value is None:
return None
def distance_att(self, connection):
dist_att = None
geo_field = self.geo_field
if geo_field.geodetic(connection):
if connection.features.supports_distance_geodetic:
dist_att = 'm'
else:
dist_att = geo_field.units_name(connection)
if dist_att:
return DistanceMeasure(**{dist_att: value})
return value
units = geo_field.units_name(connection)
if units:
dist_att = DistanceMeasure.unit_attname(units)
return dist_att
def as_sql(self, compiler, connection, **extra_context):
clone = self.copy()
clone.output_field.distance_att = self.distance_att(connection)
return super(DistanceResultMixin, clone).as_sql(compiler, connection, **extra_context)
class Distance(DistanceResultMixin, OracleToleranceMixin, GeoFunc):
geom_param_pos = (0, 1)
output_field_class = FloatField
spheroid = None
def __init__(self, expr1, expr2, spheroid=None, **extra):
@ -358,17 +362,15 @@ class IsValid(OracleToleranceMixin, GeoFuncMixin, Transform):
class Length(DistanceResultMixin, OracleToleranceMixin, GeoFunc):
output_field_class = FloatField
def __init__(self, expr1, spheroid=True, **extra):
self.spheroid = spheroid
super().__init__(expr1, **extra)
def as_sql(self, compiler, connection):
def as_sql(self, compiler, connection, **extra_context):
geo_field = GeometryField(srid=self.srid) # Fake field to get SRID info
if geo_field.geodetic(connection) and not connection.features.supports_length_geodetic:
raise NotImplementedError("This backend doesn't support Length on geodetic fields")
return super().as_sql(compiler, connection)
return super().as_sql(compiler, connection, **extra_context)
def as_postgresql(self, compiler, connection):
function = None
@ -425,7 +427,6 @@ class NumPoints(GeoFunc):
class Perimeter(DistanceResultMixin, OracleToleranceMixin, GeoFunc):
output_field_class = FloatField
arity = 1
def as_postgresql(self, compiler, connection):

View File

@ -49,15 +49,27 @@ class AreaField(models.FloatField):
return 'AreaField'
class DistanceField(BaseField):
class DistanceField(models.FloatField):
"Wrapper for Distance values."
def __init__(self, distance_att):
def __init__(self, distance_att=None):
self.distance_att = distance_att
def from_db_value(self, value, expression, connection, context):
if value is not None:
value = Distance(**{self.distance_att: value})
def get_prep_value(self, value):
if isinstance(value, Distance):
return value
return super().get_prep_value(value)
def get_db_prep_value(self, value, connection, prepared=False):
if not isinstance(value, Distance):
return value
if not self.distance_att:
raise ValueError('Distance measure is supplied, but units are unknown for result.')
return getattr(value, self.distance_att)
def from_db_value(self, value, expression, connection, context):
if value is None or not self.distance_att:
return value
return Distance(**{self.distance_att: value})
def get_internal_type(self):
return 'DistanceField'

View File

@ -368,6 +368,23 @@ class DistanceFunctionsTests(TestCase):
).first().d
self.assertEqual(distance, 1)
@skipUnlessDBFeature("has_Distance_function")
def test_distance_function_d_lookup(self):
qs = Interstate.objects.annotate(
d=Distance(Point(0, 0, srid=3857), Point(0, 1, srid=3857)),
).filter(d=D(m=1))
self.assertTrue(qs.exists())
@skipIfDBFeature("supports_distance_geodetic")
@skipUnlessDBFeature("has_Distance_function")
def test_distance_function_raw_result_d_lookup(self):
qs = Interstate.objects.annotate(
d=Distance(Point(0, 0, srid=4326), Point(0, 1, srid=4326)),
).filter(d=D(m=1))
msg = 'Distance measure is supplied, but units are unknown for result.'
with self.assertRaisesMessage(ValueError, msg):
list(qs)
@no_oracle # Oracle already handles geographic distance calculation.
@skipUnlessDBFeature("has_Distance_function", 'has_Transform_function')
def test_distance_transform(self):