Simplified GIS Funcs by using GeoFuncMixin.geo_field.
This commit is contained in:
parent
504ce3914f
commit
9290f15bb5
|
@ -12,6 +12,7 @@ from django.db.models import (
|
||||||
)
|
)
|
||||||
from django.db.models.expressions import Func, Value
|
from django.db.models.expressions import Func, Value
|
||||||
from django.db.models.functions import Cast
|
from django.db.models.functions import Cast
|
||||||
|
from django.utils.functional import cached_property
|
||||||
|
|
||||||
NUMERIC_TYPES = (int, float, Decimal)
|
NUMERIC_TYPES = (int, float, Decimal)
|
||||||
|
|
||||||
|
@ -47,13 +48,9 @@ class GeoFuncMixin:
|
||||||
def name(self):
|
def name(self):
|
||||||
return self.__class__.__name__
|
return self.__class__.__name__
|
||||||
|
|
||||||
@property
|
@cached_property
|
||||||
def srid(self):
|
|
||||||
return self.source_expressions[self.geom_param_pos[0]].field.srid
|
|
||||||
|
|
||||||
@property
|
|
||||||
def geo_field(self):
|
def geo_field(self):
|
||||||
return GeometryField(srid=self.srid) if self.srid else None
|
return self.source_expressions[self.geom_param_pos[0]].field
|
||||||
|
|
||||||
def as_sql(self, compiler, connection, function=None, **extra_context):
|
def as_sql(self, compiler, connection, function=None, **extra_context):
|
||||||
if not self.function and not function:
|
if not self.function and not function:
|
||||||
|
@ -74,7 +71,7 @@ class GeoFuncMixin:
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
base_srid = res.srid
|
base_srid = res.geo_field.srid
|
||||||
for pos in self.geom_param_pos[1:]:
|
for pos in self.geom_param_pos[1:]:
|
||||||
expr = res.source_expressions[pos]
|
expr = res.source_expressions[pos]
|
||||||
expr_srid = expr.output_field.srid
|
expr_srid = expr.output_field.srid
|
||||||
|
@ -98,15 +95,9 @@ class GeoFunc(GeoFuncMixin, Func):
|
||||||
|
|
||||||
|
|
||||||
class GeomOutputGeoFunc(GeoFunc):
|
class GeomOutputGeoFunc(GeoFunc):
|
||||||
def __init__(self, *expressions, **extra):
|
@cached_property
|
||||||
if 'output_field' not in extra:
|
def output_field(self):
|
||||||
extra['output_field'] = GeometryField()
|
return self.geo_field
|
||||||
super().__init__(*expressions, **extra)
|
|
||||||
|
|
||||||
def resolve_expression(self, *args, **kwargs):
|
|
||||||
res = super().resolve_expression(*args, **kwargs)
|
|
||||||
res.output_field.srid = res.srid
|
|
||||||
return res
|
|
||||||
|
|
||||||
|
|
||||||
class SQLiteDecimalToFloatMixin:
|
class SQLiteDecimalToFloatMixin:
|
||||||
|
@ -138,15 +129,14 @@ class Area(OracleToleranceMixin, GeoFunc):
|
||||||
self.output_field.area_att = 'sq_m'
|
self.output_field.area_att = 'sq_m'
|
||||||
else:
|
else:
|
||||||
# Getting the area units of the geographic field.
|
# Getting the area units of the geographic field.
|
||||||
geo_field = self.geo_field
|
if self.geo_field.geodetic(connection):
|
||||||
if geo_field.geodetic(connection):
|
|
||||||
if connection.features.supports_area_geodetic:
|
if connection.features.supports_area_geodetic:
|
||||||
self.output_field.area_att = 'sq_m'
|
self.output_field.area_att = 'sq_m'
|
||||||
else:
|
else:
|
||||||
# TODO: Do we want to support raw number areas for geodetic fields?
|
# TODO: Do we want to support raw number areas for geodetic fields?
|
||||||
raise NotImplementedError('Area on geodetic coordinate systems not supported.')
|
raise NotImplementedError('Area on geodetic coordinate systems not supported.')
|
||||||
else:
|
else:
|
||||||
units_name = geo_field.units_name(connection)
|
units_name = self.geo_field.units_name(connection)
|
||||||
if units_name:
|
if units_name:
|
||||||
self.output_field.area_att = AreaMeasure.unit_attname(units_name)
|
self.output_field.area_att = AreaMeasure.unit_attname(units_name)
|
||||||
return super().as_sql(compiler, connection, **extra_context)
|
return super().as_sql(compiler, connection, **extra_context)
|
||||||
|
@ -250,16 +240,15 @@ class DistanceResultMixin:
|
||||||
output_field_class = DistanceField
|
output_field_class = DistanceField
|
||||||
|
|
||||||
def source_is_geography(self):
|
def source_is_geography(self):
|
||||||
return self.get_source_fields()[0].geography and self.srid == 4326
|
return self.geo_field.geography and self.geo_field.srid == 4326
|
||||||
|
|
||||||
def distance_att(self, connection):
|
def distance_att(self, connection):
|
||||||
dist_att = None
|
dist_att = None
|
||||||
geo_field = self.geo_field
|
if self.geo_field.geodetic(connection):
|
||||||
if geo_field.geodetic(connection):
|
|
||||||
if connection.features.supports_distance_geodetic:
|
if connection.features.supports_distance_geodetic:
|
||||||
dist_att = 'm'
|
dist_att = 'm'
|
||||||
else:
|
else:
|
||||||
units = geo_field.units_name(connection)
|
units = self.geo_field.units_name(connection)
|
||||||
if units:
|
if units:
|
||||||
dist_att = DistanceMeasure.unit_attname(units)
|
dist_att = DistanceMeasure.unit_attname(units)
|
||||||
return dist_att
|
return dist_att
|
||||||
|
@ -283,7 +272,6 @@ class Distance(DistanceResultMixin, OracleToleranceMixin, GeoFunc):
|
||||||
|
|
||||||
def as_postgresql(self, compiler, connection):
|
def as_postgresql(self, compiler, connection):
|
||||||
function = None
|
function = None
|
||||||
geo_field = GeometryField(srid=self.srid) # Fake field to get SRID info
|
|
||||||
expr2 = self.source_expressions[1]
|
expr2 = self.source_expressions[1]
|
||||||
geography = self.source_is_geography()
|
geography = self.source_is_geography()
|
||||||
if expr2.output_field.geography != geography:
|
if expr2.output_field.geography != geography:
|
||||||
|
@ -295,13 +283,13 @@ class Distance(DistanceResultMixin, OracleToleranceMixin, GeoFunc):
|
||||||
GeometryField(srid=expr2.output_field.srid, geography=geography),
|
GeometryField(srid=expr2.output_field.srid, geography=geography),
|
||||||
)
|
)
|
||||||
|
|
||||||
if not geography and geo_field.geodetic(connection):
|
if not geography and self.geo_field.geodetic(connection):
|
||||||
# Geometry fields with geodetic (lon/lat) coordinates need special distance functions
|
# Geometry fields with geodetic (lon/lat) coordinates need special distance functions
|
||||||
if self.spheroid:
|
if self.spheroid:
|
||||||
# DistanceSpheroid is more accurate and resource intensive than DistanceSphere
|
# DistanceSpheroid is more accurate and resource intensive than DistanceSphere
|
||||||
function = connection.ops.spatial_function_name('DistanceSpheroid')
|
function = connection.ops.spatial_function_name('DistanceSpheroid')
|
||||||
# Replace boolean param by the real spheroid of the base field
|
# Replace boolean param by the real spheroid of the base field
|
||||||
self.source_expressions[2] = Value(geo_field.spheroid(connection))
|
self.source_expressions[2] = Value(self.geo_field.spheroid(connection))
|
||||||
else:
|
else:
|
||||||
function = connection.ops.spatial_function_name('DistanceSphere')
|
function = connection.ops.spatial_function_name('DistanceSphere')
|
||||||
return super().as_sql(compiler, connection, function=function)
|
return super().as_sql(compiler, connection, function=function)
|
||||||
|
@ -367,20 +355,18 @@ class Length(DistanceResultMixin, OracleToleranceMixin, GeoFunc):
|
||||||
super().__init__(expr1, **extra)
|
super().__init__(expr1, **extra)
|
||||||
|
|
||||||
def as_sql(self, compiler, connection, **extra_context):
|
def as_sql(self, compiler, connection, **extra_context):
|
||||||
geo_field = GeometryField(srid=self.srid) # Fake field to get SRID info
|
if self.geo_field.geodetic(connection) and not connection.features.supports_length_geodetic:
|
||||||
if geo_field.geodetic(connection) and not connection.features.supports_length_geodetic:
|
|
||||||
raise NotImplementedError("This backend doesn't support Length on geodetic fields")
|
raise NotImplementedError("This backend doesn't support Length on geodetic fields")
|
||||||
return super().as_sql(compiler, connection, **extra_context)
|
return super().as_sql(compiler, connection, **extra_context)
|
||||||
|
|
||||||
def as_postgresql(self, compiler, connection):
|
def as_postgresql(self, compiler, connection):
|
||||||
function = None
|
function = None
|
||||||
geo_field = GeometryField(srid=self.srid) # Fake field to get SRID info
|
|
||||||
if self.source_is_geography():
|
if self.source_is_geography():
|
||||||
self.source_expressions.append(Value(self.spheroid))
|
self.source_expressions.append(Value(self.spheroid))
|
||||||
elif geo_field.geodetic(connection):
|
elif self.geo_field.geodetic(connection):
|
||||||
# Geometry fields with geodetic (lon/lat) coordinates need length_spheroid
|
# Geometry fields with geodetic (lon/lat) coordinates need length_spheroid
|
||||||
function = connection.ops.spatial_function_name('LengthSpheroid')
|
function = connection.ops.spatial_function_name('LengthSpheroid')
|
||||||
self.source_expressions.append(Value(geo_field.spheroid(connection)))
|
self.source_expressions.append(Value(self.geo_field.spheroid(connection)))
|
||||||
else:
|
else:
|
||||||
dim = min(f.dim for f in self.get_source_fields() if f)
|
dim = min(f.dim for f in self.get_source_fields() if f)
|
||||||
if dim > 2:
|
if dim > 2:
|
||||||
|
@ -389,8 +375,7 @@ class Length(DistanceResultMixin, OracleToleranceMixin, GeoFunc):
|
||||||
|
|
||||||
def as_sqlite(self, compiler, connection):
|
def as_sqlite(self, compiler, connection):
|
||||||
function = None
|
function = None
|
||||||
geo_field = GeometryField(srid=self.srid)
|
if self.geo_field.geodetic(connection):
|
||||||
if geo_field.geodetic(connection):
|
|
||||||
function = 'GeodesicLength' if self.spheroid else 'GreatCircleLength'
|
function = 'GeodesicLength' if self.spheroid else 'GreatCircleLength'
|
||||||
return super().as_sql(compiler, connection, function=function)
|
return super().as_sql(compiler, connection, function=function)
|
||||||
|
|
||||||
|
@ -425,8 +410,7 @@ class Perimeter(DistanceResultMixin, OracleToleranceMixin, GeoFunc):
|
||||||
|
|
||||||
def as_postgresql(self, compiler, connection):
|
def as_postgresql(self, compiler, connection):
|
||||||
function = None
|
function = None
|
||||||
geo_field = GeometryField(srid=self.srid) # Fake field to get SRID info
|
if self.geo_field.geodetic(connection) and not self.source_is_geography():
|
||||||
if geo_field.geodetic(connection) and not self.source_is_geography():
|
|
||||||
raise NotImplementedError("ST_Perimeter cannot use a non-projected non-geography field.")
|
raise NotImplementedError("ST_Perimeter cannot use a non-projected non-geography field.")
|
||||||
dim = min(f.dim for f in self.get_source_fields())
|
dim = min(f.dim for f in self.get_source_fields())
|
||||||
if dim > 2:
|
if dim > 2:
|
||||||
|
@ -434,8 +418,7 @@ class Perimeter(DistanceResultMixin, OracleToleranceMixin, GeoFunc):
|
||||||
return super().as_sql(compiler, connection, function=function)
|
return super().as_sql(compiler, connection, function=function)
|
||||||
|
|
||||||
def as_sqlite(self, compiler, connection):
|
def as_sqlite(self, compiler, connection):
|
||||||
geo_field = GeometryField(srid=self.srid) # Fake field to get SRID info
|
if self.geo_field.geodetic(connection):
|
||||||
if geo_field.geodetic(connection):
|
|
||||||
raise NotImplementedError("Perimeter cannot use a non-projected field.")
|
raise NotImplementedError("Perimeter cannot use a non-projected field.")
|
||||||
return super().as_sql(compiler, connection)
|
return super().as_sql(compiler, connection)
|
||||||
|
|
||||||
|
@ -496,11 +479,6 @@ class Transform(GeomOutputGeoFunc):
|
||||||
extra['output_field'] = GeometryField(srid=srid)
|
extra['output_field'] = GeometryField(srid=srid)
|
||||||
super().__init__(*expressions, **extra)
|
super().__init__(*expressions, **extra)
|
||||||
|
|
||||||
@property
|
|
||||||
def srid(self):
|
|
||||||
# Make srid the resulting srid of the transformation
|
|
||||||
return self.source_expressions[1].value
|
|
||||||
|
|
||||||
|
|
||||||
class Translate(Scale):
|
class Translate(Scale):
|
||||||
def as_sqlite(self, compiler, connection):
|
def as_sqlite(self, compiler, connection):
|
||||||
|
|
Loading…
Reference in New Issue