Simplified GIS Funcs by using GeoFuncMixin.geo_field.

This commit is contained in:
Sergey Fedoseev 2017-07-15 14:40:27 +05:00
parent 504ce3914f
commit 9290f15bb5
1 changed files with 20 additions and 42 deletions

View File

@ -12,6 +12,7 @@ from django.db.models import (
)
from django.db.models.expressions import Func, Value
from django.db.models.functions import Cast
from django.utils.functional import cached_property
NUMERIC_TYPES = (int, float, Decimal)
@ -47,13 +48,9 @@ class GeoFuncMixin:
def name(self):
return self.__class__.__name__
@property
def srid(self):
return self.source_expressions[self.geom_param_pos[0]].field.srid
@property
@cached_property
def geo_field(self):
return GeometryField(srid=self.srid) if self.srid else None
return self.source_expressions[self.geom_param_pos[0]].field
def as_sql(self, compiler, connection, function=None, **extra_context):
if not self.function and not function:
@ -74,7 +71,7 @@ class GeoFuncMixin:
)
)
base_srid = res.srid
base_srid = res.geo_field.srid
for pos in self.geom_param_pos[1:]:
expr = res.source_expressions[pos]
expr_srid = expr.output_field.srid
@ -98,15 +95,9 @@ class GeoFunc(GeoFuncMixin, Func):
class GeomOutputGeoFunc(GeoFunc):
def __init__(self, *expressions, **extra):
if 'output_field' not in extra:
extra['output_field'] = GeometryField()
super().__init__(*expressions, **extra)
def resolve_expression(self, *args, **kwargs):
res = super().resolve_expression(*args, **kwargs)
res.output_field.srid = res.srid
return res
@cached_property
def output_field(self):
return self.geo_field
class SQLiteDecimalToFloatMixin:
@ -138,15 +129,14 @@ class Area(OracleToleranceMixin, GeoFunc):
self.output_field.area_att = 'sq_m'
else:
# Getting the area units of the geographic field.
geo_field = self.geo_field
if geo_field.geodetic(connection):
if self.geo_field.geodetic(connection):
if connection.features.supports_area_geodetic:
self.output_field.area_att = 'sq_m'
else:
# TODO: Do we want to support raw number areas for geodetic fields?
raise NotImplementedError('Area on geodetic coordinate systems not supported.')
else:
units_name = geo_field.units_name(connection)
units_name = self.geo_field.units_name(connection)
if units_name:
self.output_field.area_att = AreaMeasure.unit_attname(units_name)
return super().as_sql(compiler, connection, **extra_context)
@ -250,16 +240,15 @@ class DistanceResultMixin:
output_field_class = DistanceField
def source_is_geography(self):
return self.get_source_fields()[0].geography and self.srid == 4326
return self.geo_field.geography and self.geo_field.srid == 4326
def distance_att(self, connection):
dist_att = None
geo_field = self.geo_field
if geo_field.geodetic(connection):
if self.geo_field.geodetic(connection):
if connection.features.supports_distance_geodetic:
dist_att = 'm'
else:
units = geo_field.units_name(connection)
units = self.geo_field.units_name(connection)
if units:
dist_att = DistanceMeasure.unit_attname(units)
return dist_att
@ -283,7 +272,6 @@ class Distance(DistanceResultMixin, OracleToleranceMixin, GeoFunc):
def as_postgresql(self, compiler, connection):
function = None
geo_field = GeometryField(srid=self.srid) # Fake field to get SRID info
expr2 = self.source_expressions[1]
geography = self.source_is_geography()
if expr2.output_field.geography != geography:
@ -295,13 +283,13 @@ class Distance(DistanceResultMixin, OracleToleranceMixin, GeoFunc):
GeometryField(srid=expr2.output_field.srid, geography=geography),
)
if not geography and geo_field.geodetic(connection):
if not geography and self.geo_field.geodetic(connection):
# Geometry fields with geodetic (lon/lat) coordinates need special distance functions
if self.spheroid:
# DistanceSpheroid is more accurate and resource intensive than DistanceSphere
function = connection.ops.spatial_function_name('DistanceSpheroid')
# Replace boolean param by the real spheroid of the base field
self.source_expressions[2] = Value(geo_field.spheroid(connection))
self.source_expressions[2] = Value(self.geo_field.spheroid(connection))
else:
function = connection.ops.spatial_function_name('DistanceSphere')
return super().as_sql(compiler, connection, function=function)
@ -367,20 +355,18 @@ class Length(DistanceResultMixin, OracleToleranceMixin, GeoFunc):
super().__init__(expr1, **extra)
def as_sql(self, compiler, connection, **extra_context):
geo_field = GeometryField(srid=self.srid) # Fake field to get SRID info
if geo_field.geodetic(connection) and not connection.features.supports_length_geodetic:
if self.geo_field.geodetic(connection) and not connection.features.supports_length_geodetic:
raise NotImplementedError("This backend doesn't support Length on geodetic fields")
return super().as_sql(compiler, connection, **extra_context)
def as_postgresql(self, compiler, connection):
function = None
geo_field = GeometryField(srid=self.srid) # Fake field to get SRID info
if self.source_is_geography():
self.source_expressions.append(Value(self.spheroid))
elif geo_field.geodetic(connection):
elif self.geo_field.geodetic(connection):
# Geometry fields with geodetic (lon/lat) coordinates need length_spheroid
function = connection.ops.spatial_function_name('LengthSpheroid')
self.source_expressions.append(Value(geo_field.spheroid(connection)))
self.source_expressions.append(Value(self.geo_field.spheroid(connection)))
else:
dim = min(f.dim for f in self.get_source_fields() if f)
if dim > 2:
@ -389,8 +375,7 @@ class Length(DistanceResultMixin, OracleToleranceMixin, GeoFunc):
def as_sqlite(self, compiler, connection):
function = None
geo_field = GeometryField(srid=self.srid)
if geo_field.geodetic(connection):
if self.geo_field.geodetic(connection):
function = 'GeodesicLength' if self.spheroid else 'GreatCircleLength'
return super().as_sql(compiler, connection, function=function)
@ -425,8 +410,7 @@ class Perimeter(DistanceResultMixin, OracleToleranceMixin, GeoFunc):
def as_postgresql(self, compiler, connection):
function = None
geo_field = GeometryField(srid=self.srid) # Fake field to get SRID info
if geo_field.geodetic(connection) and not self.source_is_geography():
if self.geo_field.geodetic(connection) and not self.source_is_geography():
raise NotImplementedError("ST_Perimeter cannot use a non-projected non-geography field.")
dim = min(f.dim for f in self.get_source_fields())
if dim > 2:
@ -434,8 +418,7 @@ class Perimeter(DistanceResultMixin, OracleToleranceMixin, GeoFunc):
return super().as_sql(compiler, connection, function=function)
def as_sqlite(self, compiler, connection):
geo_field = GeometryField(srid=self.srid) # Fake field to get SRID info
if geo_field.geodetic(connection):
if self.geo_field.geodetic(connection):
raise NotImplementedError("Perimeter cannot use a non-projected field.")
return super().as_sql(compiler, connection)
@ -496,11 +479,6 @@ class Transform(GeomOutputGeoFunc):
extra['output_field'] = GeometryField(srid=srid)
super().__init__(*expressions, **extra)
@property
def srid(self):
# Make srid the resulting srid of the transformation
return self.source_expressions[1].value
class Translate(Scale):
def as_sqlite(self, compiler, connection):