Fixed #28006 -- Allowed using D with lookups on Distance annotations.

This commit is contained in:
Sergey Fedoseev 2017-04-07 04:27:45 +05:00 committed by Tim Graham
parent dbfcedb499
commit fd892f3443
3 changed files with 50 additions and 20 deletions

View File

@ -1,7 +1,7 @@
from decimal import Decimal from decimal import Decimal
from django.contrib.gis.db.models.fields import BaseSpatialField, GeometryField from django.contrib.gis.db.models.fields import BaseSpatialField, GeometryField
from django.contrib.gis.db.models.sql import AreaField from django.contrib.gis.db.models.sql import AreaField, DistanceField
from django.contrib.gis.geometry.backend import Geometry from django.contrib.gis.geometry.backend import Geometry
from django.contrib.gis.measure import ( from django.contrib.gis.measure import (
Area as AreaMeasure, Distance as DistanceMeasure, Area as AreaMeasure, Distance as DistanceMeasure,
@ -126,7 +126,7 @@ class OracleToleranceMixin:
def as_oracle(self, compiler, connection): def as_oracle(self, compiler, connection):
tol = self.extra.get('tolerance', self.tolerance) tol = self.extra.get('tolerance', self.tolerance)
return super().as_sql(compiler, connection, template="%%(function)s(%%(expressions)s, %s)" % tol) return self.as_sql(compiler, connection, template="%%(function)s(%%(expressions)s, %s)" % tol)
class Area(OracleToleranceMixin, GeoFunc): class Area(OracleToleranceMixin, GeoFunc):
@ -247,27 +247,31 @@ class Difference(OracleToleranceMixin, GeomOutputGeoFunc):
class DistanceResultMixin: class DistanceResultMixin:
output_field_class = DistanceField
def source_is_geography(self): def source_is_geography(self):
return self.get_source_fields()[0].geography and self.srid == 4326 return self.get_source_fields()[0].geography and self.srid == 4326
def convert_value(self, value, expression, connection, context): def distance_att(self, connection):
if value is None:
return None
dist_att = None dist_att = None
geo_field = self.geo_field geo_field = self.geo_field
if geo_field.geodetic(connection): if geo_field.geodetic(connection):
if connection.features.supports_distance_geodetic: if connection.features.supports_distance_geodetic:
dist_att = 'm' dist_att = 'm'
else: else:
dist_att = geo_field.units_name(connection) units = geo_field.units_name(connection)
if dist_att: if units:
return DistanceMeasure(**{dist_att: value}) dist_att = DistanceMeasure.unit_attname(units)
return value return dist_att
def as_sql(self, compiler, connection, **extra_context):
clone = self.copy()
clone.output_field.distance_att = self.distance_att(connection)
return super(DistanceResultMixin, clone).as_sql(compiler, connection, **extra_context)
class Distance(DistanceResultMixin, OracleToleranceMixin, GeoFunc): class Distance(DistanceResultMixin, OracleToleranceMixin, GeoFunc):
geom_param_pos = (0, 1) geom_param_pos = (0, 1)
output_field_class = FloatField
spheroid = None spheroid = None
def __init__(self, expr1, expr2, spheroid=None, **extra): def __init__(self, expr1, expr2, spheroid=None, **extra):
@ -358,17 +362,15 @@ class IsValid(OracleToleranceMixin, GeoFuncMixin, Transform):
class Length(DistanceResultMixin, OracleToleranceMixin, GeoFunc): class Length(DistanceResultMixin, OracleToleranceMixin, GeoFunc):
output_field_class = FloatField
def __init__(self, expr1, spheroid=True, **extra): def __init__(self, expr1, spheroid=True, **extra):
self.spheroid = spheroid self.spheroid = spheroid
super().__init__(expr1, **extra) super().__init__(expr1, **extra)
def as_sql(self, compiler, connection): def as_sql(self, compiler, connection, **extra_context):
geo_field = GeometryField(srid=self.srid) # Fake field to get SRID info geo_field = GeometryField(srid=self.srid) # Fake field to get SRID info
if geo_field.geodetic(connection) and not connection.features.supports_length_geodetic: if geo_field.geodetic(connection) and not connection.features.supports_length_geodetic:
raise NotImplementedError("This backend doesn't support Length on geodetic fields") raise NotImplementedError("This backend doesn't support Length on geodetic fields")
return super().as_sql(compiler, connection) return super().as_sql(compiler, connection, **extra_context)
def as_postgresql(self, compiler, connection): def as_postgresql(self, compiler, connection):
function = None function = None
@ -425,7 +427,6 @@ class NumPoints(GeoFunc):
class Perimeter(DistanceResultMixin, OracleToleranceMixin, GeoFunc): class Perimeter(DistanceResultMixin, OracleToleranceMixin, GeoFunc):
output_field_class = FloatField
arity = 1 arity = 1
def as_postgresql(self, compiler, connection): def as_postgresql(self, compiler, connection):

View File

@ -49,15 +49,27 @@ class AreaField(models.FloatField):
return 'AreaField' return 'AreaField'
class DistanceField(BaseField): class DistanceField(models.FloatField):
"Wrapper for Distance values." "Wrapper for Distance values."
def __init__(self, distance_att): def __init__(self, distance_att=None):
self.distance_att = distance_att self.distance_att = distance_att
def get_prep_value(self, value):
if isinstance(value, Distance):
return value
return super().get_prep_value(value)
def get_db_prep_value(self, value, connection, prepared=False):
if not isinstance(value, Distance):
return value
if not self.distance_att:
raise ValueError('Distance measure is supplied, but units are unknown for result.')
return getattr(value, self.distance_att)
def from_db_value(self, value, expression, connection, context): def from_db_value(self, value, expression, connection, context):
if value is not None: if value is None or not self.distance_att:
value = Distance(**{self.distance_att: value}) return value
return value return Distance(**{self.distance_att: value})
def get_internal_type(self): def get_internal_type(self):
return 'DistanceField' return 'DistanceField'

View File

@ -368,6 +368,23 @@ class DistanceFunctionsTests(TestCase):
).first().d ).first().d
self.assertEqual(distance, 1) self.assertEqual(distance, 1)
@skipUnlessDBFeature("has_Distance_function")
def test_distance_function_d_lookup(self):
qs = Interstate.objects.annotate(
d=Distance(Point(0, 0, srid=3857), Point(0, 1, srid=3857)),
).filter(d=D(m=1))
self.assertTrue(qs.exists())
@skipIfDBFeature("supports_distance_geodetic")
@skipUnlessDBFeature("has_Distance_function")
def test_distance_function_raw_result_d_lookup(self):
qs = Interstate.objects.annotate(
d=Distance(Point(0, 0, srid=4326), Point(0, 1, srid=4326)),
).filter(d=D(m=1))
msg = 'Distance measure is supplied, but units are unknown for result.'
with self.assertRaisesMessage(ValueError, msg):
list(qs)
@no_oracle # Oracle already handles geographic distance calculation. @no_oracle # Oracle already handles geographic distance calculation.
@skipUnlessDBFeature("has_Distance_function", 'has_Transform_function') @skipUnlessDBFeature("has_Distance_function", 'has_Transform_function')
def test_distance_transform(self): def test_distance_transform(self):