Fixed #28006 -- Allowed using D with lookups on Distance annotations.
This commit is contained in:
parent
dbfcedb499
commit
fd892f3443
|
@ -1,7 +1,7 @@
|
|||
from decimal import Decimal
|
||||
|
||||
from django.contrib.gis.db.models.fields import BaseSpatialField, GeometryField
|
||||
from django.contrib.gis.db.models.sql import AreaField
|
||||
from django.contrib.gis.db.models.sql import AreaField, DistanceField
|
||||
from django.contrib.gis.geometry.backend import Geometry
|
||||
from django.contrib.gis.measure import (
|
||||
Area as AreaMeasure, Distance as DistanceMeasure,
|
||||
|
@ -126,7 +126,7 @@ class OracleToleranceMixin:
|
|||
|
||||
def as_oracle(self, compiler, connection):
|
||||
tol = self.extra.get('tolerance', self.tolerance)
|
||||
return super().as_sql(compiler, connection, template="%%(function)s(%%(expressions)s, %s)" % tol)
|
||||
return self.as_sql(compiler, connection, template="%%(function)s(%%(expressions)s, %s)" % tol)
|
||||
|
||||
|
||||
class Area(OracleToleranceMixin, GeoFunc):
|
||||
|
@ -247,27 +247,31 @@ class Difference(OracleToleranceMixin, GeomOutputGeoFunc):
|
|||
|
||||
|
||||
class DistanceResultMixin:
|
||||
output_field_class = DistanceField
|
||||
|
||||
def source_is_geography(self):
|
||||
return self.get_source_fields()[0].geography and self.srid == 4326
|
||||
|
||||
def convert_value(self, value, expression, connection, context):
|
||||
if value is None:
|
||||
return None
|
||||
def distance_att(self, connection):
|
||||
dist_att = None
|
||||
geo_field = self.geo_field
|
||||
if geo_field.geodetic(connection):
|
||||
if connection.features.supports_distance_geodetic:
|
||||
dist_att = 'm'
|
||||
else:
|
||||
dist_att = geo_field.units_name(connection)
|
||||
if dist_att:
|
||||
return DistanceMeasure(**{dist_att: value})
|
||||
return value
|
||||
units = geo_field.units_name(connection)
|
||||
if units:
|
||||
dist_att = DistanceMeasure.unit_attname(units)
|
||||
return dist_att
|
||||
|
||||
def as_sql(self, compiler, connection, **extra_context):
|
||||
clone = self.copy()
|
||||
clone.output_field.distance_att = self.distance_att(connection)
|
||||
return super(DistanceResultMixin, clone).as_sql(compiler, connection, **extra_context)
|
||||
|
||||
|
||||
class Distance(DistanceResultMixin, OracleToleranceMixin, GeoFunc):
|
||||
geom_param_pos = (0, 1)
|
||||
output_field_class = FloatField
|
||||
spheroid = None
|
||||
|
||||
def __init__(self, expr1, expr2, spheroid=None, **extra):
|
||||
|
@ -358,17 +362,15 @@ class IsValid(OracleToleranceMixin, GeoFuncMixin, Transform):
|
|||
|
||||
|
||||
class Length(DistanceResultMixin, OracleToleranceMixin, GeoFunc):
|
||||
output_field_class = FloatField
|
||||
|
||||
def __init__(self, expr1, spheroid=True, **extra):
|
||||
self.spheroid = spheroid
|
||||
super().__init__(expr1, **extra)
|
||||
|
||||
def as_sql(self, compiler, connection):
|
||||
def as_sql(self, compiler, connection, **extra_context):
|
||||
geo_field = GeometryField(srid=self.srid) # Fake field to get SRID info
|
||||
if geo_field.geodetic(connection) and not connection.features.supports_length_geodetic:
|
||||
raise NotImplementedError("This backend doesn't support Length on geodetic fields")
|
||||
return super().as_sql(compiler, connection)
|
||||
return super().as_sql(compiler, connection, **extra_context)
|
||||
|
||||
def as_postgresql(self, compiler, connection):
|
||||
function = None
|
||||
|
@ -425,7 +427,6 @@ class NumPoints(GeoFunc):
|
|||
|
||||
|
||||
class Perimeter(DistanceResultMixin, OracleToleranceMixin, GeoFunc):
|
||||
output_field_class = FloatField
|
||||
arity = 1
|
||||
|
||||
def as_postgresql(self, compiler, connection):
|
||||
|
|
|
@ -49,15 +49,27 @@ class AreaField(models.FloatField):
|
|||
return 'AreaField'
|
||||
|
||||
|
||||
class DistanceField(BaseField):
|
||||
class DistanceField(models.FloatField):
|
||||
"Wrapper for Distance values."
|
||||
def __init__(self, distance_att):
|
||||
def __init__(self, distance_att=None):
|
||||
self.distance_att = distance_att
|
||||
|
||||
def from_db_value(self, value, expression, connection, context):
|
||||
if value is not None:
|
||||
value = Distance(**{self.distance_att: value})
|
||||
def get_prep_value(self, value):
|
||||
if isinstance(value, Distance):
|
||||
return value
|
||||
return super().get_prep_value(value)
|
||||
|
||||
def get_db_prep_value(self, value, connection, prepared=False):
|
||||
if not isinstance(value, Distance):
|
||||
return value
|
||||
if not self.distance_att:
|
||||
raise ValueError('Distance measure is supplied, but units are unknown for result.')
|
||||
return getattr(value, self.distance_att)
|
||||
|
||||
def from_db_value(self, value, expression, connection, context):
|
||||
if value is None or not self.distance_att:
|
||||
return value
|
||||
return Distance(**{self.distance_att: value})
|
||||
|
||||
def get_internal_type(self):
|
||||
return 'DistanceField'
|
||||
|
|
|
@ -368,6 +368,23 @@ class DistanceFunctionsTests(TestCase):
|
|||
).first().d
|
||||
self.assertEqual(distance, 1)
|
||||
|
||||
@skipUnlessDBFeature("has_Distance_function")
|
||||
def test_distance_function_d_lookup(self):
|
||||
qs = Interstate.objects.annotate(
|
||||
d=Distance(Point(0, 0, srid=3857), Point(0, 1, srid=3857)),
|
||||
).filter(d=D(m=1))
|
||||
self.assertTrue(qs.exists())
|
||||
|
||||
@skipIfDBFeature("supports_distance_geodetic")
|
||||
@skipUnlessDBFeature("has_Distance_function")
|
||||
def test_distance_function_raw_result_d_lookup(self):
|
||||
qs = Interstate.objects.annotate(
|
||||
d=Distance(Point(0, 0, srid=4326), Point(0, 1, srid=4326)),
|
||||
).filter(d=D(m=1))
|
||||
msg = 'Distance measure is supplied, but units are unknown for result.'
|
||||
with self.assertRaisesMessage(ValueError, msg):
|
||||
list(qs)
|
||||
|
||||
@no_oracle # Oracle already handles geographic distance calculation.
|
||||
@skipUnlessDBFeature("has_Distance_function", 'has_Transform_function')
|
||||
def test_distance_transform(self):
|
||||
|
|
Loading…
Reference in New Issue