Simplified GIS Funcs by using GeoFuncMixin.geo_field.

This commit is contained in:
Sergey Fedoseev 2017-07-15 14:40:27 +05:00
parent 504ce3914f
commit 9290f15bb5
1 changed files with 20 additions and 42 deletions

View File

@ -12,6 +12,7 @@ from django.db.models import (
) )
from django.db.models.expressions import Func, Value from django.db.models.expressions import Func, Value
from django.db.models.functions import Cast from django.db.models.functions import Cast
from django.utils.functional import cached_property
NUMERIC_TYPES = (int, float, Decimal) NUMERIC_TYPES = (int, float, Decimal)
@ -47,13 +48,9 @@ class GeoFuncMixin:
def name(self): def name(self):
return self.__class__.__name__ return self.__class__.__name__
@property @cached_property
def srid(self):
return self.source_expressions[self.geom_param_pos[0]].field.srid
@property
def geo_field(self): def geo_field(self):
return GeometryField(srid=self.srid) if self.srid else None return self.source_expressions[self.geom_param_pos[0]].field
def as_sql(self, compiler, connection, function=None, **extra_context): def as_sql(self, compiler, connection, function=None, **extra_context):
if not self.function and not function: if not self.function and not function:
@ -74,7 +71,7 @@ class GeoFuncMixin:
) )
) )
base_srid = res.srid base_srid = res.geo_field.srid
for pos in self.geom_param_pos[1:]: for pos in self.geom_param_pos[1:]:
expr = res.source_expressions[pos] expr = res.source_expressions[pos]
expr_srid = expr.output_field.srid expr_srid = expr.output_field.srid
@ -98,15 +95,9 @@ class GeoFunc(GeoFuncMixin, Func):
class GeomOutputGeoFunc(GeoFunc): class GeomOutputGeoFunc(GeoFunc):
def __init__(self, *expressions, **extra): @cached_property
if 'output_field' not in extra: def output_field(self):
extra['output_field'] = GeometryField() return self.geo_field
super().__init__(*expressions, **extra)
def resolve_expression(self, *args, **kwargs):
res = super().resolve_expression(*args, **kwargs)
res.output_field.srid = res.srid
return res
class SQLiteDecimalToFloatMixin: class SQLiteDecimalToFloatMixin:
@ -138,15 +129,14 @@ class Area(OracleToleranceMixin, GeoFunc):
self.output_field.area_att = 'sq_m' self.output_field.area_att = 'sq_m'
else: else:
# Getting the area units of the geographic field. # Getting the area units of the geographic field.
geo_field = self.geo_field if self.geo_field.geodetic(connection):
if geo_field.geodetic(connection):
if connection.features.supports_area_geodetic: if connection.features.supports_area_geodetic:
self.output_field.area_att = 'sq_m' self.output_field.area_att = 'sq_m'
else: else:
# TODO: Do we want to support raw number areas for geodetic fields? # TODO: Do we want to support raw number areas for geodetic fields?
raise NotImplementedError('Area on geodetic coordinate systems not supported.') raise NotImplementedError('Area on geodetic coordinate systems not supported.')
else: else:
units_name = geo_field.units_name(connection) units_name = self.geo_field.units_name(connection)
if units_name: if units_name:
self.output_field.area_att = AreaMeasure.unit_attname(units_name) self.output_field.area_att = AreaMeasure.unit_attname(units_name)
return super().as_sql(compiler, connection, **extra_context) return super().as_sql(compiler, connection, **extra_context)
@ -250,16 +240,15 @@ class DistanceResultMixin:
output_field_class = DistanceField output_field_class = DistanceField
def source_is_geography(self): def source_is_geography(self):
return self.get_source_fields()[0].geography and self.srid == 4326 return self.geo_field.geography and self.geo_field.srid == 4326
def distance_att(self, connection): def distance_att(self, connection):
dist_att = None dist_att = None
geo_field = self.geo_field if self.geo_field.geodetic(connection):
if geo_field.geodetic(connection):
if connection.features.supports_distance_geodetic: if connection.features.supports_distance_geodetic:
dist_att = 'm' dist_att = 'm'
else: else:
units = geo_field.units_name(connection) units = self.geo_field.units_name(connection)
if units: if units:
dist_att = DistanceMeasure.unit_attname(units) dist_att = DistanceMeasure.unit_attname(units)
return dist_att return dist_att
@ -283,7 +272,6 @@ class Distance(DistanceResultMixin, OracleToleranceMixin, GeoFunc):
def as_postgresql(self, compiler, connection): def as_postgresql(self, compiler, connection):
function = None function = None
geo_field = GeometryField(srid=self.srid) # Fake field to get SRID info
expr2 = self.source_expressions[1] expr2 = self.source_expressions[1]
geography = self.source_is_geography() geography = self.source_is_geography()
if expr2.output_field.geography != geography: if expr2.output_field.geography != geography:
@ -295,13 +283,13 @@ class Distance(DistanceResultMixin, OracleToleranceMixin, GeoFunc):
GeometryField(srid=expr2.output_field.srid, geography=geography), GeometryField(srid=expr2.output_field.srid, geography=geography),
) )
if not geography and geo_field.geodetic(connection): if not geography and self.geo_field.geodetic(connection):
# Geometry fields with geodetic (lon/lat) coordinates need special distance functions # Geometry fields with geodetic (lon/lat) coordinates need special distance functions
if self.spheroid: if self.spheroid:
# DistanceSpheroid is more accurate and resource intensive than DistanceSphere # DistanceSpheroid is more accurate and resource intensive than DistanceSphere
function = connection.ops.spatial_function_name('DistanceSpheroid') function = connection.ops.spatial_function_name('DistanceSpheroid')
# Replace boolean param by the real spheroid of the base field # Replace boolean param by the real spheroid of the base field
self.source_expressions[2] = Value(geo_field.spheroid(connection)) self.source_expressions[2] = Value(self.geo_field.spheroid(connection))
else: else:
function = connection.ops.spatial_function_name('DistanceSphere') function = connection.ops.spatial_function_name('DistanceSphere')
return super().as_sql(compiler, connection, function=function) return super().as_sql(compiler, connection, function=function)
@ -367,20 +355,18 @@ class Length(DistanceResultMixin, OracleToleranceMixin, GeoFunc):
super().__init__(expr1, **extra) super().__init__(expr1, **extra)
def as_sql(self, compiler, connection, **extra_context): def as_sql(self, compiler, connection, **extra_context):
geo_field = GeometryField(srid=self.srid) # Fake field to get SRID info if self.geo_field.geodetic(connection) and not connection.features.supports_length_geodetic:
if geo_field.geodetic(connection) and not connection.features.supports_length_geodetic:
raise NotImplementedError("This backend doesn't support Length on geodetic fields") raise NotImplementedError("This backend doesn't support Length on geodetic fields")
return super().as_sql(compiler, connection, **extra_context) return super().as_sql(compiler, connection, **extra_context)
def as_postgresql(self, compiler, connection): def as_postgresql(self, compiler, connection):
function = None function = None
geo_field = GeometryField(srid=self.srid) # Fake field to get SRID info
if self.source_is_geography(): if self.source_is_geography():
self.source_expressions.append(Value(self.spheroid)) self.source_expressions.append(Value(self.spheroid))
elif geo_field.geodetic(connection): elif self.geo_field.geodetic(connection):
# Geometry fields with geodetic (lon/lat) coordinates need length_spheroid # Geometry fields with geodetic (lon/lat) coordinates need length_spheroid
function = connection.ops.spatial_function_name('LengthSpheroid') function = connection.ops.spatial_function_name('LengthSpheroid')
self.source_expressions.append(Value(geo_field.spheroid(connection))) self.source_expressions.append(Value(self.geo_field.spheroid(connection)))
else: else:
dim = min(f.dim for f in self.get_source_fields() if f) dim = min(f.dim for f in self.get_source_fields() if f)
if dim > 2: if dim > 2:
@ -389,8 +375,7 @@ class Length(DistanceResultMixin, OracleToleranceMixin, GeoFunc):
def as_sqlite(self, compiler, connection): def as_sqlite(self, compiler, connection):
function = None function = None
geo_field = GeometryField(srid=self.srid) if self.geo_field.geodetic(connection):
if geo_field.geodetic(connection):
function = 'GeodesicLength' if self.spheroid else 'GreatCircleLength' function = 'GeodesicLength' if self.spheroid else 'GreatCircleLength'
return super().as_sql(compiler, connection, function=function) return super().as_sql(compiler, connection, function=function)
@ -425,8 +410,7 @@ class Perimeter(DistanceResultMixin, OracleToleranceMixin, GeoFunc):
def as_postgresql(self, compiler, connection): def as_postgresql(self, compiler, connection):
function = None function = None
geo_field = GeometryField(srid=self.srid) # Fake field to get SRID info if self.geo_field.geodetic(connection) and not self.source_is_geography():
if geo_field.geodetic(connection) and not self.source_is_geography():
raise NotImplementedError("ST_Perimeter cannot use a non-projected non-geography field.") raise NotImplementedError("ST_Perimeter cannot use a non-projected non-geography field.")
dim = min(f.dim for f in self.get_source_fields()) dim = min(f.dim for f in self.get_source_fields())
if dim > 2: if dim > 2:
@ -434,8 +418,7 @@ class Perimeter(DistanceResultMixin, OracleToleranceMixin, GeoFunc):
return super().as_sql(compiler, connection, function=function) return super().as_sql(compiler, connection, function=function)
def as_sqlite(self, compiler, connection): def as_sqlite(self, compiler, connection):
geo_field = GeometryField(srid=self.srid) # Fake field to get SRID info if self.geo_field.geodetic(connection):
if geo_field.geodetic(connection):
raise NotImplementedError("Perimeter cannot use a non-projected field.") raise NotImplementedError("Perimeter cannot use a non-projected field.")
return super().as_sql(compiler, connection) return super().as_sql(compiler, connection)
@ -496,11 +479,6 @@ class Transform(GeomOutputGeoFunc):
extra['output_field'] = GeometryField(srid=srid) extra['output_field'] = GeometryField(srid=srid)
super().__init__(*expressions, **extra) super().__init__(*expressions, **extra)
@property
def srid(self):
# Make srid the resulting srid of the transformation
return self.source_expressions[1].value
class Translate(Scale): class Translate(Scale):
def as_sqlite(self, compiler, connection): def as_sqlite(self, compiler, connection):