From fd892f3443fe9a35684b7b798a8fe1b07d118e3c Mon Sep 17 00:00:00 2001 From: Sergey Fedoseev Date: Fri, 7 Apr 2017 04:27:45 +0500 Subject: [PATCH] Fixed #28006 -- Allowed using D with lookups on Distance annotations. --- django/contrib/gis/db/models/functions.py | 31 ++++++++++--------- .../contrib/gis/db/models/sql/conversion.py | 22 ++++++++++--- tests/gis_tests/distapp/tests.py | 17 ++++++++++ 3 files changed, 50 insertions(+), 20 deletions(-) diff --git a/django/contrib/gis/db/models/functions.py b/django/contrib/gis/db/models/functions.py index 04e81e5428..0f18e7bbde 100644 --- a/django/contrib/gis/db/models/functions.py +++ b/django/contrib/gis/db/models/functions.py @@ -1,7 +1,7 @@ from decimal import Decimal from django.contrib.gis.db.models.fields import BaseSpatialField, GeometryField -from django.contrib.gis.db.models.sql import AreaField +from django.contrib.gis.db.models.sql import AreaField, DistanceField from django.contrib.gis.geometry.backend import Geometry from django.contrib.gis.measure import ( Area as AreaMeasure, Distance as DistanceMeasure, @@ -126,7 +126,7 @@ class OracleToleranceMixin: def as_oracle(self, compiler, connection): tol = self.extra.get('tolerance', self.tolerance) - return super().as_sql(compiler, connection, template="%%(function)s(%%(expressions)s, %s)" % tol) + return self.as_sql(compiler, connection, template="%%(function)s(%%(expressions)s, %s)" % tol) class Area(OracleToleranceMixin, GeoFunc): @@ -247,27 +247,31 @@ class Difference(OracleToleranceMixin, GeomOutputGeoFunc): class DistanceResultMixin: + output_field_class = DistanceField + def source_is_geography(self): return self.get_source_fields()[0].geography and self.srid == 4326 - def convert_value(self, value, expression, connection, context): - if value is None: - return None + def distance_att(self, connection): dist_att = None geo_field = self.geo_field if geo_field.geodetic(connection): if connection.features.supports_distance_geodetic: dist_att = 'm' else: - dist_att = geo_field.units_name(connection) - if dist_att: - return DistanceMeasure(**{dist_att: value}) - return value + units = geo_field.units_name(connection) + if units: + dist_att = DistanceMeasure.unit_attname(units) + return dist_att + + def as_sql(self, compiler, connection, **extra_context): + clone = self.copy() + clone.output_field.distance_att = self.distance_att(connection) + return super(DistanceResultMixin, clone).as_sql(compiler, connection, **extra_context) class Distance(DistanceResultMixin, OracleToleranceMixin, GeoFunc): geom_param_pos = (0, 1) - output_field_class = FloatField spheroid = None def __init__(self, expr1, expr2, spheroid=None, **extra): @@ -358,17 +362,15 @@ class IsValid(OracleToleranceMixin, GeoFuncMixin, Transform): class Length(DistanceResultMixin, OracleToleranceMixin, GeoFunc): - output_field_class = FloatField - def __init__(self, expr1, spheroid=True, **extra): self.spheroid = spheroid super().__init__(expr1, **extra) - def as_sql(self, compiler, connection): + def as_sql(self, compiler, connection, **extra_context): geo_field = GeometryField(srid=self.srid) # Fake field to get SRID info if geo_field.geodetic(connection) and not connection.features.supports_length_geodetic: raise NotImplementedError("This backend doesn't support Length on geodetic fields") - return super().as_sql(compiler, connection) + return super().as_sql(compiler, connection, **extra_context) def as_postgresql(self, compiler, connection): function = None @@ -425,7 +427,6 @@ class NumPoints(GeoFunc): class Perimeter(DistanceResultMixin, OracleToleranceMixin, GeoFunc): - output_field_class = FloatField arity = 1 def as_postgresql(self, compiler, connection): diff --git a/django/contrib/gis/db/models/sql/conversion.py b/django/contrib/gis/db/models/sql/conversion.py index adae01fe50..75e4adaccf 100644 --- a/django/contrib/gis/db/models/sql/conversion.py +++ b/django/contrib/gis/db/models/sql/conversion.py @@ -49,15 +49,27 @@ class AreaField(models.FloatField): return 'AreaField' -class DistanceField(BaseField): +class DistanceField(models.FloatField): "Wrapper for Distance values." - def __init__(self, distance_att): + def __init__(self, distance_att=None): self.distance_att = distance_att + def get_prep_value(self, value): + if isinstance(value, Distance): + return value + return super().get_prep_value(value) + + def get_db_prep_value(self, value, connection, prepared=False): + if not isinstance(value, Distance): + return value + if not self.distance_att: + raise ValueError('Distance measure is supplied, but units are unknown for result.') + return getattr(value, self.distance_att) + def from_db_value(self, value, expression, connection, context): - if value is not None: - value = Distance(**{self.distance_att: value}) - return value + if value is None or not self.distance_att: + return value + return Distance(**{self.distance_att: value}) def get_internal_type(self): return 'DistanceField' diff --git a/tests/gis_tests/distapp/tests.py b/tests/gis_tests/distapp/tests.py index aafd5f8035..7865bc4d6c 100644 --- a/tests/gis_tests/distapp/tests.py +++ b/tests/gis_tests/distapp/tests.py @@ -368,6 +368,23 @@ class DistanceFunctionsTests(TestCase): ).first().d self.assertEqual(distance, 1) + @skipUnlessDBFeature("has_Distance_function") + def test_distance_function_d_lookup(self): + qs = Interstate.objects.annotate( + d=Distance(Point(0, 0, srid=3857), Point(0, 1, srid=3857)), + ).filter(d=D(m=1)) + self.assertTrue(qs.exists()) + + @skipIfDBFeature("supports_distance_geodetic") + @skipUnlessDBFeature("has_Distance_function") + def test_distance_function_raw_result_d_lookup(self): + qs = Interstate.objects.annotate( + d=Distance(Point(0, 0, srid=4326), Point(0, 1, srid=4326)), + ).filter(d=D(m=1)) + msg = 'Distance measure is supplied, but units are unknown for result.' + with self.assertRaisesMessage(ValueError, msg): + list(qs) + @no_oracle # Oracle already handles geographic distance calculation. @skipUnlessDBFeature("has_Distance_function", 'has_Transform_function') def test_distance_transform(self):