Simplified GIS Funcs by using GeoFuncMixin.geo_field.
This commit is contained in:
parent
504ce3914f
commit
9290f15bb5
|
@ -12,6 +12,7 @@ from django.db.models import (
|
|||
)
|
||||
from django.db.models.expressions import Func, Value
|
||||
from django.db.models.functions import Cast
|
||||
from django.utils.functional import cached_property
|
||||
|
||||
NUMERIC_TYPES = (int, float, Decimal)
|
||||
|
||||
|
@ -47,13 +48,9 @@ class GeoFuncMixin:
|
|||
def name(self):
|
||||
return self.__class__.__name__
|
||||
|
||||
@property
|
||||
def srid(self):
|
||||
return self.source_expressions[self.geom_param_pos[0]].field.srid
|
||||
|
||||
@property
|
||||
@cached_property
|
||||
def geo_field(self):
|
||||
return GeometryField(srid=self.srid) if self.srid else None
|
||||
return self.source_expressions[self.geom_param_pos[0]].field
|
||||
|
||||
def as_sql(self, compiler, connection, function=None, **extra_context):
|
||||
if not self.function and not function:
|
||||
|
@ -74,7 +71,7 @@ class GeoFuncMixin:
|
|||
)
|
||||
)
|
||||
|
||||
base_srid = res.srid
|
||||
base_srid = res.geo_field.srid
|
||||
for pos in self.geom_param_pos[1:]:
|
||||
expr = res.source_expressions[pos]
|
||||
expr_srid = expr.output_field.srid
|
||||
|
@ -98,15 +95,9 @@ class GeoFunc(GeoFuncMixin, Func):
|
|||
|
||||
|
||||
class GeomOutputGeoFunc(GeoFunc):
|
||||
def __init__(self, *expressions, **extra):
|
||||
if 'output_field' not in extra:
|
||||
extra['output_field'] = GeometryField()
|
||||
super().__init__(*expressions, **extra)
|
||||
|
||||
def resolve_expression(self, *args, **kwargs):
|
||||
res = super().resolve_expression(*args, **kwargs)
|
||||
res.output_field.srid = res.srid
|
||||
return res
|
||||
@cached_property
|
||||
def output_field(self):
|
||||
return self.geo_field
|
||||
|
||||
|
||||
class SQLiteDecimalToFloatMixin:
|
||||
|
@ -138,15 +129,14 @@ class Area(OracleToleranceMixin, GeoFunc):
|
|||
self.output_field.area_att = 'sq_m'
|
||||
else:
|
||||
# Getting the area units of the geographic field.
|
||||
geo_field = self.geo_field
|
||||
if geo_field.geodetic(connection):
|
||||
if self.geo_field.geodetic(connection):
|
||||
if connection.features.supports_area_geodetic:
|
||||
self.output_field.area_att = 'sq_m'
|
||||
else:
|
||||
# TODO: Do we want to support raw number areas for geodetic fields?
|
||||
raise NotImplementedError('Area on geodetic coordinate systems not supported.')
|
||||
else:
|
||||
units_name = geo_field.units_name(connection)
|
||||
units_name = self.geo_field.units_name(connection)
|
||||
if units_name:
|
||||
self.output_field.area_att = AreaMeasure.unit_attname(units_name)
|
||||
return super().as_sql(compiler, connection, **extra_context)
|
||||
|
@ -250,16 +240,15 @@ class DistanceResultMixin:
|
|||
output_field_class = DistanceField
|
||||
|
||||
def source_is_geography(self):
|
||||
return self.get_source_fields()[0].geography and self.srid == 4326
|
||||
return self.geo_field.geography and self.geo_field.srid == 4326
|
||||
|
||||
def distance_att(self, connection):
|
||||
dist_att = None
|
||||
geo_field = self.geo_field
|
||||
if geo_field.geodetic(connection):
|
||||
if self.geo_field.geodetic(connection):
|
||||
if connection.features.supports_distance_geodetic:
|
||||
dist_att = 'm'
|
||||
else:
|
||||
units = geo_field.units_name(connection)
|
||||
units = self.geo_field.units_name(connection)
|
||||
if units:
|
||||
dist_att = DistanceMeasure.unit_attname(units)
|
||||
return dist_att
|
||||
|
@ -283,7 +272,6 @@ class Distance(DistanceResultMixin, OracleToleranceMixin, GeoFunc):
|
|||
|
||||
def as_postgresql(self, compiler, connection):
|
||||
function = None
|
||||
geo_field = GeometryField(srid=self.srid) # Fake field to get SRID info
|
||||
expr2 = self.source_expressions[1]
|
||||
geography = self.source_is_geography()
|
||||
if expr2.output_field.geography != geography:
|
||||
|
@ -295,13 +283,13 @@ class Distance(DistanceResultMixin, OracleToleranceMixin, GeoFunc):
|
|||
GeometryField(srid=expr2.output_field.srid, geography=geography),
|
||||
)
|
||||
|
||||
if not geography and geo_field.geodetic(connection):
|
||||
if not geography and self.geo_field.geodetic(connection):
|
||||
# Geometry fields with geodetic (lon/lat) coordinates need special distance functions
|
||||
if self.spheroid:
|
||||
# DistanceSpheroid is more accurate and resource intensive than DistanceSphere
|
||||
function = connection.ops.spatial_function_name('DistanceSpheroid')
|
||||
# Replace boolean param by the real spheroid of the base field
|
||||
self.source_expressions[2] = Value(geo_field.spheroid(connection))
|
||||
self.source_expressions[2] = Value(self.geo_field.spheroid(connection))
|
||||
else:
|
||||
function = connection.ops.spatial_function_name('DistanceSphere')
|
||||
return super().as_sql(compiler, connection, function=function)
|
||||
|
@ -367,20 +355,18 @@ class Length(DistanceResultMixin, OracleToleranceMixin, GeoFunc):
|
|||
super().__init__(expr1, **extra)
|
||||
|
||||
def as_sql(self, compiler, connection, **extra_context):
|
||||
geo_field = GeometryField(srid=self.srid) # Fake field to get SRID info
|
||||
if geo_field.geodetic(connection) and not connection.features.supports_length_geodetic:
|
||||
if self.geo_field.geodetic(connection) and not connection.features.supports_length_geodetic:
|
||||
raise NotImplementedError("This backend doesn't support Length on geodetic fields")
|
||||
return super().as_sql(compiler, connection, **extra_context)
|
||||
|
||||
def as_postgresql(self, compiler, connection):
|
||||
function = None
|
||||
geo_field = GeometryField(srid=self.srid) # Fake field to get SRID info
|
||||
if self.source_is_geography():
|
||||
self.source_expressions.append(Value(self.spheroid))
|
||||
elif geo_field.geodetic(connection):
|
||||
elif self.geo_field.geodetic(connection):
|
||||
# Geometry fields with geodetic (lon/lat) coordinates need length_spheroid
|
||||
function = connection.ops.spatial_function_name('LengthSpheroid')
|
||||
self.source_expressions.append(Value(geo_field.spheroid(connection)))
|
||||
self.source_expressions.append(Value(self.geo_field.spheroid(connection)))
|
||||
else:
|
||||
dim = min(f.dim for f in self.get_source_fields() if f)
|
||||
if dim > 2:
|
||||
|
@ -389,8 +375,7 @@ class Length(DistanceResultMixin, OracleToleranceMixin, GeoFunc):
|
|||
|
||||
def as_sqlite(self, compiler, connection):
|
||||
function = None
|
||||
geo_field = GeometryField(srid=self.srid)
|
||||
if geo_field.geodetic(connection):
|
||||
if self.geo_field.geodetic(connection):
|
||||
function = 'GeodesicLength' if self.spheroid else 'GreatCircleLength'
|
||||
return super().as_sql(compiler, connection, function=function)
|
||||
|
||||
|
@ -425,8 +410,7 @@ class Perimeter(DistanceResultMixin, OracleToleranceMixin, GeoFunc):
|
|||
|
||||
def as_postgresql(self, compiler, connection):
|
||||
function = None
|
||||
geo_field = GeometryField(srid=self.srid) # Fake field to get SRID info
|
||||
if geo_field.geodetic(connection) and not self.source_is_geography():
|
||||
if self.geo_field.geodetic(connection) and not self.source_is_geography():
|
||||
raise NotImplementedError("ST_Perimeter cannot use a non-projected non-geography field.")
|
||||
dim = min(f.dim for f in self.get_source_fields())
|
||||
if dim > 2:
|
||||
|
@ -434,8 +418,7 @@ class Perimeter(DistanceResultMixin, OracleToleranceMixin, GeoFunc):
|
|||
return super().as_sql(compiler, connection, function=function)
|
||||
|
||||
def as_sqlite(self, compiler, connection):
|
||||
geo_field = GeometryField(srid=self.srid) # Fake field to get SRID info
|
||||
if geo_field.geodetic(connection):
|
||||
if self.geo_field.geodetic(connection):
|
||||
raise NotImplementedError("Perimeter cannot use a non-projected field.")
|
||||
return super().as_sql(compiler, connection)
|
||||
|
||||
|
@ -496,11 +479,6 @@ class Transform(GeomOutputGeoFunc):
|
|||
extra['output_field'] = GeometryField(srid=srid)
|
||||
super().__init__(*expressions, **extra)
|
||||
|
||||
@property
|
||||
def srid(self):
|
||||
# Make srid the resulting srid of the transformation
|
||||
return self.source_expressions[1].value
|
||||
|
||||
|
||||
class Translate(Scale):
|
||||
def as_sqlite(self, compiler, connection):
|
||||
|
|
Loading…
Reference in New Issue