Fixed #25605 -- Made GIS DB functions accept geometric expressions, not only values, in all positions.
This commit is contained in:
parent
e487ffd3f0
commit
bde86ce9ae
|
@ -1,8 +1,6 @@
|
|||
from decimal import Decimal
|
||||
|
||||
from django.contrib.gis.db.models.fields import (
|
||||
BaseSpatialField, GeometryField, RasterField,
|
||||
)
|
||||
from django.contrib.gis.db.models.fields import BaseSpatialField, GeometryField
|
||||
from django.contrib.gis.db.models.sql import AreaField
|
||||
from django.contrib.gis.geometry.backend import Geometry
|
||||
from django.contrib.gis.measure import (
|
||||
|
@ -13,6 +11,7 @@ from django.db.models import (
|
|||
BooleanField, FloatField, IntegerField, TextField, Transform,
|
||||
)
|
||||
from django.db.models.expressions import Func, Value
|
||||
from django.db.models.functions import Cast
|
||||
|
||||
NUMERIC_TYPES = (int, float, Decimal)
|
||||
|
||||
|
@ -20,26 +19,37 @@ NUMERIC_TYPES = (int, float, Decimal)
|
|||
class GeoFuncMixin:
|
||||
function = None
|
||||
output_field_class = None
|
||||
geom_param_pos = 0
|
||||
geom_param_pos = (0,)
|
||||
|
||||
def __init__(self, *expressions, **extra):
|
||||
if 'output_field' not in extra and self.output_field_class:
|
||||
extra['output_field'] = self.output_field_class()
|
||||
super().__init__(*expressions, **extra)
|
||||
|
||||
# Ensure that value expressions are geometric.
|
||||
for pos in self.geom_param_pos:
|
||||
expr = self.source_expressions[pos]
|
||||
if not isinstance(expr, Value):
|
||||
continue
|
||||
try:
|
||||
output_field = expr.output_field
|
||||
except FieldError:
|
||||
output_field = None
|
||||
geom = expr.value
|
||||
if not isinstance(geom, Geometry) or output_field and not isinstance(output_field, GeometryField):
|
||||
raise TypeError("%s function requires a geometric argument in position %d." % (self.name, pos + 1))
|
||||
if not geom.srid and not output_field:
|
||||
raise ValueError("SRID is required for all geometries.")
|
||||
if not output_field:
|
||||
self.source_expressions[pos] = Value(geom, output_field=GeometryField(srid=geom.srid))
|
||||
|
||||
@property
|
||||
def name(self):
|
||||
return self.__class__.__name__
|
||||
|
||||
@property
|
||||
def srid(self):
|
||||
expr = self.source_expressions[self.geom_param_pos]
|
||||
if hasattr(expr, 'srid'):
|
||||
return expr.srid
|
||||
try:
|
||||
return expr.field.srid
|
||||
except (AttributeError, FieldError):
|
||||
return None
|
||||
return self.source_expressions[self.geom_param_pos[0]].field.srid
|
||||
|
||||
@property
|
||||
def geo_field(self):
|
||||
|
@ -48,19 +58,28 @@ class GeoFuncMixin:
|
|||
def as_sql(self, compiler, connection, function=None, **extra_context):
|
||||
if not self.function and not function:
|
||||
function = connection.ops.spatial_function_name(self.name)
|
||||
if any(isinstance(field, RasterField) for field in self.get_source_fields()):
|
||||
raise TypeError("Geometry functions not supported for raster fields.")
|
||||
return super().as_sql(compiler, connection, function=function, **extra_context)
|
||||
|
||||
def resolve_expression(self, *args, **kwargs):
|
||||
res = super().resolve_expression(*args, **kwargs)
|
||||
base_srid = res.srid
|
||||
if not base_srid:
|
||||
raise TypeError("Geometry functions can only operate on geometric content.")
|
||||
|
||||
for pos, expr in enumerate(res.source_expressions[1:], start=1):
|
||||
if isinstance(expr, GeomValue) and expr.srid != base_srid:
|
||||
# Automatic SRID conversion so objects are comparable
|
||||
# Ensure that expressions are geometric.
|
||||
source_fields = res.get_source_fields()
|
||||
for pos in self.geom_param_pos:
|
||||
field = source_fields[pos]
|
||||
if not isinstance(field, GeometryField):
|
||||
raise TypeError(
|
||||
"%s function requires a GeometryField in position %s, got %s." % (
|
||||
self.name, pos + 1, type(field).__name__,
|
||||
)
|
||||
)
|
||||
|
||||
base_srid = res.srid
|
||||
for pos in self.geom_param_pos[1:]:
|
||||
expr = res.source_expressions[pos]
|
||||
expr_srid = expr.output_field.srid
|
||||
if expr_srid != base_srid:
|
||||
# Automatic SRID conversion so objects are comparable.
|
||||
res.source_expressions[pos] = Transform(expr, base_srid).resolve_expression(*args, **kwargs)
|
||||
return res
|
||||
|
||||
|
@ -78,34 +97,16 @@ class GeoFunc(GeoFuncMixin, Func):
|
|||
pass
|
||||
|
||||
|
||||
class GeomValue(Value):
|
||||
geography = False
|
||||
class GeomOutputGeoFunc(GeoFunc):
|
||||
def __init__(self, *expressions, **extra):
|
||||
if 'output_field' not in extra:
|
||||
extra['output_field'] = GeometryField()
|
||||
super(GeomOutputGeoFunc, self).__init__(*expressions, **extra)
|
||||
|
||||
@property
|
||||
def srid(self):
|
||||
return self.value.srid
|
||||
|
||||
def as_sql(self, compiler, connection):
|
||||
return '%s(%%s, %s)' % (connection.ops.from_text, self.srid), [connection.ops.Adapter(self.value)]
|
||||
|
||||
def as_mysql(self, compiler, connection):
|
||||
return '%s(%%s)' % (connection.ops.from_text), [connection.ops.Adapter(self.value)]
|
||||
|
||||
def as_postgresql(self, compiler, connection):
|
||||
if self.geography:
|
||||
self.value = connection.ops.Adapter(self.value, geography=self.geography)
|
||||
else:
|
||||
self.value = connection.ops.Adapter(self.value)
|
||||
return super().as_sql(compiler, connection)
|
||||
|
||||
|
||||
class GeoFuncWithGeoParam(GeoFunc):
|
||||
def __init__(self, expression, geom, *expressions, **extra):
|
||||
if not isinstance(geom, Geometry):
|
||||
raise TypeError("Please provide a geometry object.")
|
||||
if not hasattr(geom, 'srid') or not geom.srid:
|
||||
raise ValueError("Please provide a geometry attribute with a defined SRID.")
|
||||
super().__init__(expression, GeomValue(geom), *expressions, **extra)
|
||||
def resolve_expression(self, *args, **kwargs):
|
||||
res = super().resolve_expression(*args, **kwargs)
|
||||
res.output_field.srid = res.srid
|
||||
return res
|
||||
|
||||
|
||||
class SQLiteDecimalToFloatMixin:
|
||||
|
@ -181,7 +182,7 @@ class AsGeoJSON(GeoFunc):
|
|||
|
||||
|
||||
class AsGML(GeoFunc):
|
||||
geom_param_pos = 1
|
||||
geom_param_pos = (1,)
|
||||
output_field_class = TextField
|
||||
|
||||
def __init__(self, expression, version=2, precision=8, **extra):
|
||||
|
@ -230,12 +231,13 @@ class BoundingCircle(OracleToleranceMixin, GeoFunc):
|
|||
return super(BoundingCircle, clone).as_oracle(compiler, connection)
|
||||
|
||||
|
||||
class Centroid(OracleToleranceMixin, GeoFunc):
|
||||
class Centroid(OracleToleranceMixin, GeomOutputGeoFunc):
|
||||
arity = 1
|
||||
|
||||
|
||||
class Difference(OracleToleranceMixin, GeoFuncWithGeoParam):
|
||||
class Difference(OracleToleranceMixin, GeomOutputGeoFunc):
|
||||
arity = 2
|
||||
geom_param_pos = (0, 1)
|
||||
|
||||
|
||||
class DistanceResultMixin:
|
||||
|
@ -259,7 +261,8 @@ class DistanceResultMixin:
|
|||
return value
|
||||
|
||||
|
||||
class Distance(DistanceResultMixin, OracleToleranceMixin, GeoFuncWithGeoParam):
|
||||
class Distance(DistanceResultMixin, OracleToleranceMixin, GeoFunc):
|
||||
geom_param_pos = (0, 1)
|
||||
output_field_class = FloatField
|
||||
spheroid = None
|
||||
|
||||
|
@ -273,13 +276,18 @@ class Distance(DistanceResultMixin, OracleToleranceMixin, GeoFuncWithGeoParam):
|
|||
def as_postgresql(self, compiler, connection):
|
||||
function = None
|
||||
geo_field = GeometryField(srid=self.srid) # Fake field to get SRID info
|
||||
if self.source_is_geography():
|
||||
# Set parameters as geography if base field is geography
|
||||
for pos, expr in enumerate(
|
||||
self.source_expressions[self.geom_param_pos + 1:], start=self.geom_param_pos + 1):
|
||||
if isinstance(expr, GeomValue):
|
||||
expr.geography = True
|
||||
elif geo_field.geodetic(connection):
|
||||
expr2 = self.source_expressions[1]
|
||||
geography = self.source_is_geography()
|
||||
if expr2.output_field.geography != geography:
|
||||
if isinstance(expr2, Value):
|
||||
expr2.output_field.geography = geography
|
||||
else:
|
||||
self.source_expressions[1] = Cast(
|
||||
expr2,
|
||||
GeometryField(srid=expr2.output_field.srid, geography=geography),
|
||||
)
|
||||
|
||||
if not geography and geo_field.geodetic(connection):
|
||||
# Geometry fields with geodetic (lon/lat) coordinates need special distance functions
|
||||
if self.spheroid:
|
||||
# DistanceSpheroid is more accurate and resource intensive than DistanceSphere
|
||||
|
@ -305,11 +313,11 @@ class Distance(DistanceResultMixin, OracleToleranceMixin, GeoFuncWithGeoParam):
|
|||
return super().as_sql(compiler, connection, **extra_context)
|
||||
|
||||
|
||||
class Envelope(GeoFunc):
|
||||
class Envelope(GeomOutputGeoFunc):
|
||||
arity = 1
|
||||
|
||||
|
||||
class ForceRHR(GeoFunc):
|
||||
class ForceRHR(GeomOutputGeoFunc):
|
||||
arity = 1
|
||||
|
||||
|
||||
|
@ -323,8 +331,9 @@ class GeoHash(GeoFunc):
|
|||
super().__init__(*expressions, **extra)
|
||||
|
||||
|
||||
class Intersection(OracleToleranceMixin, GeoFuncWithGeoParam):
|
||||
class Intersection(OracleToleranceMixin, GeomOutputGeoFunc):
|
||||
arity = 2
|
||||
geom_param_pos = (0, 1)
|
||||
|
||||
|
||||
@BaseSpatialField.register_lookup
|
||||
|
@ -392,7 +401,7 @@ class NumPoints(GeoFunc):
|
|||
arity = 1
|
||||
|
||||
def as_sql(self, compiler, connection):
|
||||
if self.source_expressions[self.geom_param_pos].output_field.geom_type != 'LINESTRING':
|
||||
if self.source_expressions[self.geom_param_pos[0]].output_field.geom_type != 'LINESTRING':
|
||||
if not connection.features.supports_num_points_poly:
|
||||
raise TypeError('NumPoints can only operate on LineString content on this database.')
|
||||
return super().as_sql(compiler, connection)
|
||||
|
@ -419,7 +428,7 @@ class Perimeter(DistanceResultMixin, OracleToleranceMixin, GeoFunc):
|
|||
return super().as_sql(compiler, connection)
|
||||
|
||||
|
||||
class PointOnSurface(OracleToleranceMixin, GeoFunc):
|
||||
class PointOnSurface(OracleToleranceMixin, GeomOutputGeoFunc):
|
||||
arity = 1
|
||||
|
||||
|
||||
|
@ -427,7 +436,7 @@ class Reverse(GeoFunc):
|
|||
arity = 1
|
||||
|
||||
|
||||
class Scale(SQLiteDecimalToFloatMixin, GeoFunc):
|
||||
class Scale(SQLiteDecimalToFloatMixin, GeomOutputGeoFunc):
|
||||
def __init__(self, expression, x, y, z=0.0, **extra):
|
||||
expressions = [
|
||||
expression,
|
||||
|
@ -439,7 +448,7 @@ class Scale(SQLiteDecimalToFloatMixin, GeoFunc):
|
|||
super().__init__(*expressions, **extra)
|
||||
|
||||
|
||||
class SnapToGrid(SQLiteDecimalToFloatMixin, GeoFunc):
|
||||
class SnapToGrid(SQLiteDecimalToFloatMixin, GeomOutputGeoFunc):
|
||||
def __init__(self, expression, *args, **extra):
|
||||
nargs = len(args)
|
||||
expressions = [expression]
|
||||
|
@ -460,11 +469,12 @@ class SnapToGrid(SQLiteDecimalToFloatMixin, GeoFunc):
|
|||
super().__init__(*expressions, **extra)
|
||||
|
||||
|
||||
class SymDifference(OracleToleranceMixin, GeoFuncWithGeoParam):
|
||||
class SymDifference(OracleToleranceMixin, GeomOutputGeoFunc):
|
||||
arity = 2
|
||||
geom_param_pos = (0, 1)
|
||||
|
||||
|
||||
class Transform(GeoFunc):
|
||||
class Transform(GeomOutputGeoFunc):
|
||||
def __init__(self, expression, srid, **extra):
|
||||
expressions = [
|
||||
expression,
|
||||
|
@ -477,7 +487,7 @@ class Transform(GeoFunc):
|
|||
@property
|
||||
def srid(self):
|
||||
# Make srid the resulting srid of the transformation
|
||||
return self.source_expressions[self.geom_param_pos + 1].value
|
||||
return self.source_expressions[1].value
|
||||
|
||||
|
||||
class Translate(Scale):
|
||||
|
@ -488,5 +498,6 @@ class Translate(Scale):
|
|||
return super().as_sqlite(compiler, connection)
|
||||
|
||||
|
||||
class Union(OracleToleranceMixin, GeoFuncWithGeoParam):
|
||||
class Union(OracleToleranceMixin, GeomOutputGeoFunc):
|
||||
arity = 2
|
||||
geom_param_pos = (0, 1)
|
||||
|
|
|
@ -76,6 +76,8 @@ class Lookup:
|
|||
|
||||
def process_lhs(self, compiler, connection, lhs=None):
|
||||
lhs = lhs or self.lhs
|
||||
if hasattr(lhs, 'resolve_expression'):
|
||||
lhs = lhs.resolve_expression(compiler.query)
|
||||
return compiler.compile(lhs)
|
||||
|
||||
def process_rhs(self, compiler, connection):
|
||||
|
|
|
@ -429,6 +429,9 @@ class DistanceFunctionsTests(TestCase):
|
|||
self.assertTrue(
|
||||
SouthTexasInterstate.objects.annotate(length=Length('path')).filter(length__gt=4000).exists()
|
||||
)
|
||||
# Length with an explicit geometry value.
|
||||
qs = Interstate.objects.annotate(length=Length(i10.path))
|
||||
self.assertAlmostEqual(qs.first().length.m, len_m2, 2)
|
||||
|
||||
@skipUnlessDBFeature("has_Perimeter_function")
|
||||
def test_perimeter(self):
|
||||
|
|
|
@ -2,7 +2,9 @@ import re
|
|||
from decimal import Decimal
|
||||
|
||||
from django.contrib.gis.db.models import functions
|
||||
from django.contrib.gis.geos import LineString, Point, Polygon, fromstr
|
||||
from django.contrib.gis.geos import (
|
||||
GEOSGeometry, LineString, Point, Polygon, fromstr,
|
||||
)
|
||||
from django.contrib.gis.measure import Area
|
||||
from django.db import connection
|
||||
from django.db.models import Sum
|
||||
|
@ -494,7 +496,48 @@ class GISFunctionsTests(TestCase):
|
|||
|
||||
@skipUnlessDBFeature("has_Union_function")
|
||||
def test_union(self):
|
||||
"""Union with all combinations of geometries/geometry fields."""
|
||||
geom = Point(-95.363151, 29.763374, srid=4326)
|
||||
ptown = City.objects.annotate(union=functions.Union('point', geom)).get(name='Dallas')
|
||||
|
||||
union = City.objects.annotate(union=functions.Union('point', geom)).get(name='Dallas').union
|
||||
expected = fromstr('MULTIPOINT(-96.801611 32.782057,-95.363151 29.763374)', srid=4326)
|
||||
self.assertTrue(expected.equals(ptown.union))
|
||||
self.assertTrue(expected.equals(union))
|
||||
|
||||
union = City.objects.annotate(union=functions.Union(geom, 'point')).get(name='Dallas').union
|
||||
self.assertTrue(expected.equals(union))
|
||||
|
||||
union = City.objects.annotate(union=functions.Union('point', 'point')).get(name='Dallas').union
|
||||
expected = GEOSGeometry('POINT(-96.801611 32.782057)', srid=4326)
|
||||
self.assertTrue(expected.equals(union))
|
||||
|
||||
union = City.objects.annotate(union=functions.Union(geom, geom)).get(name='Dallas').union
|
||||
self.assertTrue(geom.equals(union))
|
||||
|
||||
@skipUnlessDBFeature("has_Union_function", "has_Transform_function")
|
||||
def test_union_mixed_srid(self):
|
||||
"""The result SRID depends on the order of parameters."""
|
||||
geom = Point(61.42915, 55.15402, srid=4326)
|
||||
geom_3857 = geom.transform(3857, clone=True)
|
||||
tol = 0.001
|
||||
|
||||
for city in City.objects.annotate(union=functions.Union('point', geom_3857)):
|
||||
expected = city.point | geom
|
||||
self.assertTrue(city.union.equals_exact(expected, tol))
|
||||
self.assertEqual(city.union.srid, 4326)
|
||||
|
||||
for city in City.objects.annotate(union=functions.Union(geom_3857, 'point')):
|
||||
expected = geom_3857 | city.point.transform(3857, clone=True)
|
||||
self.assertTrue(expected.equals_exact(city.union, tol))
|
||||
self.assertEqual(city.union.srid, 3857)
|
||||
|
||||
def test_argument_validation(self):
|
||||
with self.assertRaisesMessage(ValueError, 'SRID is required for all geometries.'):
|
||||
City.objects.annotate(geo=functions.GeoFunc(Point(1, 1)))
|
||||
|
||||
msg = 'GeoFunc function requires a GeometryField in position 1, got CharField.'
|
||||
with self.assertRaisesMessage(TypeError, msg):
|
||||
City.objects.annotate(geo=functions.GeoFunc('name'))
|
||||
|
||||
msg = 'GeoFunc function requires a geometric argument in position 1.'
|
||||
with self.assertRaisesMessage(TypeError, msg):
|
||||
City.objects.annotate(union=functions.GeoFunc(1, 'point')).get(name='Dallas')
|
||||
|
|
|
@ -120,9 +120,19 @@ class GeographyFunctionTests(TestCase):
|
|||
else:
|
||||
ref_dists = [0, 4891.20, 8071.64, 9123.95]
|
||||
htown = City.objects.get(name='Houston')
|
||||
qs = Zipcode.objects.annotate(distance=Distance('poly', htown.point))
|
||||
qs = Zipcode.objects.annotate(
|
||||
distance=Distance('poly', htown.point),
|
||||
distance2=Distance(htown.point, 'poly'),
|
||||
)
|
||||
for z, ref in zip(qs, ref_dists):
|
||||
self.assertAlmostEqual(z.distance.m, ref, 2)
|
||||
|
||||
if postgis:
|
||||
# PostGIS casts geography to geometry when distance2 is calculated.
|
||||
ref_dists = [0, 4899.68, 8081.30, 9115.15]
|
||||
for z, ref in zip(qs, ref_dists):
|
||||
self.assertAlmostEqual(z.distance2.m, ref, 2)
|
||||
|
||||
if not spatialite:
|
||||
# Distance function combined with a lookup.
|
||||
hzip = Zipcode.objects.get(code='77002')
|
||||
|
|
|
@ -271,7 +271,7 @@ class RasterFieldTest(TransactionTestCase):
|
|||
|
||||
def test_isvalid_lookup_with_raster_error(self):
|
||||
qs = RasterModel.objects.filter(rast__isvalid=True)
|
||||
msg = 'Geometry functions not supported for raster fields.'
|
||||
msg = 'IsValid function requires a GeometryField in position 1, got RasterField.'
|
||||
with self.assertRaisesMessage(TypeError, msg):
|
||||
qs.count()
|
||||
|
||||
|
@ -336,11 +336,11 @@ class RasterFieldTest(TransactionTestCase):
|
|||
"""
|
||||
point = GEOSGeometry("SRID=3086;POINT (-697024.9213808845 683729.1705516104)")
|
||||
rast = GDALRaster(json.loads(JSON_RASTER))
|
||||
msg = "Please provide a geometry object."
|
||||
msg = "Distance function requires a geometric argument in position 2."
|
||||
with self.assertRaisesMessage(TypeError, msg):
|
||||
RasterModel.objects.annotate(distance_from_point=Distance("geom", rast))
|
||||
with self.assertRaisesMessage(TypeError, msg):
|
||||
RasterModel.objects.annotate(distance_from_point=Distance("rastprojected", rast))
|
||||
msg = "Geometry functions not supported for raster fields."
|
||||
msg = "Distance function requires a GeometryField in position 1, got RasterField."
|
||||
with self.assertRaisesMessage(TypeError, msg):
|
||||
RasterModel.objects.annotate(distance_from_point=Distance("rastprojected", point)).count()
|
||||
|
|
Loading…
Reference in New Issue