Fixed #28006 -- Allowed using D with lookups on Distance annotations.
This commit is contained in:
parent
dbfcedb499
commit
fd892f3443
|
@ -1,7 +1,7 @@
|
||||||
from decimal import Decimal
|
from decimal import Decimal
|
||||||
|
|
||||||
from django.contrib.gis.db.models.fields import BaseSpatialField, GeometryField
|
from django.contrib.gis.db.models.fields import BaseSpatialField, GeometryField
|
||||||
from django.contrib.gis.db.models.sql import AreaField
|
from django.contrib.gis.db.models.sql import AreaField, DistanceField
|
||||||
from django.contrib.gis.geometry.backend import Geometry
|
from django.contrib.gis.geometry.backend import Geometry
|
||||||
from django.contrib.gis.measure import (
|
from django.contrib.gis.measure import (
|
||||||
Area as AreaMeasure, Distance as DistanceMeasure,
|
Area as AreaMeasure, Distance as DistanceMeasure,
|
||||||
|
@ -126,7 +126,7 @@ class OracleToleranceMixin:
|
||||||
|
|
||||||
def as_oracle(self, compiler, connection):
|
def as_oracle(self, compiler, connection):
|
||||||
tol = self.extra.get('tolerance', self.tolerance)
|
tol = self.extra.get('tolerance', self.tolerance)
|
||||||
return super().as_sql(compiler, connection, template="%%(function)s(%%(expressions)s, %s)" % tol)
|
return self.as_sql(compiler, connection, template="%%(function)s(%%(expressions)s, %s)" % tol)
|
||||||
|
|
||||||
|
|
||||||
class Area(OracleToleranceMixin, GeoFunc):
|
class Area(OracleToleranceMixin, GeoFunc):
|
||||||
|
@ -247,27 +247,31 @@ class Difference(OracleToleranceMixin, GeomOutputGeoFunc):
|
||||||
|
|
||||||
|
|
||||||
class DistanceResultMixin:
|
class DistanceResultMixin:
|
||||||
|
output_field_class = DistanceField
|
||||||
|
|
||||||
def source_is_geography(self):
|
def source_is_geography(self):
|
||||||
return self.get_source_fields()[0].geography and self.srid == 4326
|
return self.get_source_fields()[0].geography and self.srid == 4326
|
||||||
|
|
||||||
def convert_value(self, value, expression, connection, context):
|
def distance_att(self, connection):
|
||||||
if value is None:
|
|
||||||
return None
|
|
||||||
dist_att = None
|
dist_att = None
|
||||||
geo_field = self.geo_field
|
geo_field = self.geo_field
|
||||||
if geo_field.geodetic(connection):
|
if geo_field.geodetic(connection):
|
||||||
if connection.features.supports_distance_geodetic:
|
if connection.features.supports_distance_geodetic:
|
||||||
dist_att = 'm'
|
dist_att = 'm'
|
||||||
else:
|
else:
|
||||||
dist_att = geo_field.units_name(connection)
|
units = geo_field.units_name(connection)
|
||||||
if dist_att:
|
if units:
|
||||||
return DistanceMeasure(**{dist_att: value})
|
dist_att = DistanceMeasure.unit_attname(units)
|
||||||
return value
|
return dist_att
|
||||||
|
|
||||||
|
def as_sql(self, compiler, connection, **extra_context):
|
||||||
|
clone = self.copy()
|
||||||
|
clone.output_field.distance_att = self.distance_att(connection)
|
||||||
|
return super(DistanceResultMixin, clone).as_sql(compiler, connection, **extra_context)
|
||||||
|
|
||||||
|
|
||||||
class Distance(DistanceResultMixin, OracleToleranceMixin, GeoFunc):
|
class Distance(DistanceResultMixin, OracleToleranceMixin, GeoFunc):
|
||||||
geom_param_pos = (0, 1)
|
geom_param_pos = (0, 1)
|
||||||
output_field_class = FloatField
|
|
||||||
spheroid = None
|
spheroid = None
|
||||||
|
|
||||||
def __init__(self, expr1, expr2, spheroid=None, **extra):
|
def __init__(self, expr1, expr2, spheroid=None, **extra):
|
||||||
|
@ -358,17 +362,15 @@ class IsValid(OracleToleranceMixin, GeoFuncMixin, Transform):
|
||||||
|
|
||||||
|
|
||||||
class Length(DistanceResultMixin, OracleToleranceMixin, GeoFunc):
|
class Length(DistanceResultMixin, OracleToleranceMixin, GeoFunc):
|
||||||
output_field_class = FloatField
|
|
||||||
|
|
||||||
def __init__(self, expr1, spheroid=True, **extra):
|
def __init__(self, expr1, spheroid=True, **extra):
|
||||||
self.spheroid = spheroid
|
self.spheroid = spheroid
|
||||||
super().__init__(expr1, **extra)
|
super().__init__(expr1, **extra)
|
||||||
|
|
||||||
def as_sql(self, compiler, connection):
|
def as_sql(self, compiler, connection, **extra_context):
|
||||||
geo_field = GeometryField(srid=self.srid) # Fake field to get SRID info
|
geo_field = GeometryField(srid=self.srid) # Fake field to get SRID info
|
||||||
if geo_field.geodetic(connection) and not connection.features.supports_length_geodetic:
|
if geo_field.geodetic(connection) and not connection.features.supports_length_geodetic:
|
||||||
raise NotImplementedError("This backend doesn't support Length on geodetic fields")
|
raise NotImplementedError("This backend doesn't support Length on geodetic fields")
|
||||||
return super().as_sql(compiler, connection)
|
return super().as_sql(compiler, connection, **extra_context)
|
||||||
|
|
||||||
def as_postgresql(self, compiler, connection):
|
def as_postgresql(self, compiler, connection):
|
||||||
function = None
|
function = None
|
||||||
|
@ -425,7 +427,6 @@ class NumPoints(GeoFunc):
|
||||||
|
|
||||||
|
|
||||||
class Perimeter(DistanceResultMixin, OracleToleranceMixin, GeoFunc):
|
class Perimeter(DistanceResultMixin, OracleToleranceMixin, GeoFunc):
|
||||||
output_field_class = FloatField
|
|
||||||
arity = 1
|
arity = 1
|
||||||
|
|
||||||
def as_postgresql(self, compiler, connection):
|
def as_postgresql(self, compiler, connection):
|
||||||
|
|
|
@ -49,15 +49,27 @@ class AreaField(models.FloatField):
|
||||||
return 'AreaField'
|
return 'AreaField'
|
||||||
|
|
||||||
|
|
||||||
class DistanceField(BaseField):
|
class DistanceField(models.FloatField):
|
||||||
"Wrapper for Distance values."
|
"Wrapper for Distance values."
|
||||||
def __init__(self, distance_att):
|
def __init__(self, distance_att=None):
|
||||||
self.distance_att = distance_att
|
self.distance_att = distance_att
|
||||||
|
|
||||||
def from_db_value(self, value, expression, connection, context):
|
def get_prep_value(self, value):
|
||||||
if value is not None:
|
if isinstance(value, Distance):
|
||||||
value = Distance(**{self.distance_att: value})
|
|
||||||
return value
|
return value
|
||||||
|
return super().get_prep_value(value)
|
||||||
|
|
||||||
|
def get_db_prep_value(self, value, connection, prepared=False):
|
||||||
|
if not isinstance(value, Distance):
|
||||||
|
return value
|
||||||
|
if not self.distance_att:
|
||||||
|
raise ValueError('Distance measure is supplied, but units are unknown for result.')
|
||||||
|
return getattr(value, self.distance_att)
|
||||||
|
|
||||||
|
def from_db_value(self, value, expression, connection, context):
|
||||||
|
if value is None or not self.distance_att:
|
||||||
|
return value
|
||||||
|
return Distance(**{self.distance_att: value})
|
||||||
|
|
||||||
def get_internal_type(self):
|
def get_internal_type(self):
|
||||||
return 'DistanceField'
|
return 'DistanceField'
|
||||||
|
|
|
@ -368,6 +368,23 @@ class DistanceFunctionsTests(TestCase):
|
||||||
).first().d
|
).first().d
|
||||||
self.assertEqual(distance, 1)
|
self.assertEqual(distance, 1)
|
||||||
|
|
||||||
|
@skipUnlessDBFeature("has_Distance_function")
|
||||||
|
def test_distance_function_d_lookup(self):
|
||||||
|
qs = Interstate.objects.annotate(
|
||||||
|
d=Distance(Point(0, 0, srid=3857), Point(0, 1, srid=3857)),
|
||||||
|
).filter(d=D(m=1))
|
||||||
|
self.assertTrue(qs.exists())
|
||||||
|
|
||||||
|
@skipIfDBFeature("supports_distance_geodetic")
|
||||||
|
@skipUnlessDBFeature("has_Distance_function")
|
||||||
|
def test_distance_function_raw_result_d_lookup(self):
|
||||||
|
qs = Interstate.objects.annotate(
|
||||||
|
d=Distance(Point(0, 0, srid=4326), Point(0, 1, srid=4326)),
|
||||||
|
).filter(d=D(m=1))
|
||||||
|
msg = 'Distance measure is supplied, but units are unknown for result.'
|
||||||
|
with self.assertRaisesMessage(ValueError, msg):
|
||||||
|
list(qs)
|
||||||
|
|
||||||
@no_oracle # Oracle already handles geographic distance calculation.
|
@no_oracle # Oracle already handles geographic distance calculation.
|
||||||
@skipUnlessDBFeature("has_Distance_function", 'has_Transform_function')
|
@skipUnlessDBFeature("has_Distance_function", 'has_Transform_function')
|
||||||
def test_distance_transform(self):
|
def test_distance_transform(self):
|
||||||
|
|
Loading…
Reference in New Issue