From 9290f15bb525c7c0d2d06bee055d177ec947a1f6 Mon Sep 17 00:00:00 2001 From: Sergey Fedoseev Date: Sat, 15 Jul 2017 14:40:27 +0500 Subject: [PATCH] Simplified GIS Funcs by using GeoFuncMixin.geo_field. --- django/contrib/gis/db/models/functions.py | 62 ++++++++--------------- 1 file changed, 20 insertions(+), 42 deletions(-) diff --git a/django/contrib/gis/db/models/functions.py b/django/contrib/gis/db/models/functions.py index 907fe8e456..510d54408c 100644 --- a/django/contrib/gis/db/models/functions.py +++ b/django/contrib/gis/db/models/functions.py @@ -12,6 +12,7 @@ from django.db.models import ( ) from django.db.models.expressions import Func, Value from django.db.models.functions import Cast +from django.utils.functional import cached_property NUMERIC_TYPES = (int, float, Decimal) @@ -47,13 +48,9 @@ class GeoFuncMixin: def name(self): return self.__class__.__name__ - @property - def srid(self): - return self.source_expressions[self.geom_param_pos[0]].field.srid - - @property + @cached_property def geo_field(self): - return GeometryField(srid=self.srid) if self.srid else None + return self.source_expressions[self.geom_param_pos[0]].field def as_sql(self, compiler, connection, function=None, **extra_context): if not self.function and not function: @@ -74,7 +71,7 @@ class GeoFuncMixin: ) ) - base_srid = res.srid + base_srid = res.geo_field.srid for pos in self.geom_param_pos[1:]: expr = res.source_expressions[pos] expr_srid = expr.output_field.srid @@ -98,15 +95,9 @@ class GeoFunc(GeoFuncMixin, Func): class GeomOutputGeoFunc(GeoFunc): - def __init__(self, *expressions, **extra): - if 'output_field' not in extra: - extra['output_field'] = GeometryField() - super().__init__(*expressions, **extra) - - def resolve_expression(self, *args, **kwargs): - res = super().resolve_expression(*args, **kwargs) - res.output_field.srid = res.srid - return res + @cached_property + def output_field(self): + return self.geo_field class SQLiteDecimalToFloatMixin: @@ -138,15 +129,14 @@ class Area(OracleToleranceMixin, GeoFunc): self.output_field.area_att = 'sq_m' else: # Getting the area units of the geographic field. - geo_field = self.geo_field - if geo_field.geodetic(connection): + if self.geo_field.geodetic(connection): if connection.features.supports_area_geodetic: self.output_field.area_att = 'sq_m' else: # TODO: Do we want to support raw number areas for geodetic fields? raise NotImplementedError('Area on geodetic coordinate systems not supported.') else: - units_name = geo_field.units_name(connection) + units_name = self.geo_field.units_name(connection) if units_name: self.output_field.area_att = AreaMeasure.unit_attname(units_name) return super().as_sql(compiler, connection, **extra_context) @@ -250,16 +240,15 @@ class DistanceResultMixin: output_field_class = DistanceField def source_is_geography(self): - return self.get_source_fields()[0].geography and self.srid == 4326 + return self.geo_field.geography and self.geo_field.srid == 4326 def distance_att(self, connection): dist_att = None - geo_field = self.geo_field - if geo_field.geodetic(connection): + if self.geo_field.geodetic(connection): if connection.features.supports_distance_geodetic: dist_att = 'm' else: - units = geo_field.units_name(connection) + units = self.geo_field.units_name(connection) if units: dist_att = DistanceMeasure.unit_attname(units) return dist_att @@ -283,7 +272,6 @@ class Distance(DistanceResultMixin, OracleToleranceMixin, GeoFunc): def as_postgresql(self, compiler, connection): function = None - geo_field = GeometryField(srid=self.srid) # Fake field to get SRID info expr2 = self.source_expressions[1] geography = self.source_is_geography() if expr2.output_field.geography != geography: @@ -295,13 +283,13 @@ class Distance(DistanceResultMixin, OracleToleranceMixin, GeoFunc): GeometryField(srid=expr2.output_field.srid, geography=geography), ) - if not geography and geo_field.geodetic(connection): + if not geography and self.geo_field.geodetic(connection): # Geometry fields with geodetic (lon/lat) coordinates need special distance functions if self.spheroid: # DistanceSpheroid is more accurate and resource intensive than DistanceSphere function = connection.ops.spatial_function_name('DistanceSpheroid') # Replace boolean param by the real spheroid of the base field - self.source_expressions[2] = Value(geo_field.spheroid(connection)) + self.source_expressions[2] = Value(self.geo_field.spheroid(connection)) else: function = connection.ops.spatial_function_name('DistanceSphere') return super().as_sql(compiler, connection, function=function) @@ -367,20 +355,18 @@ class Length(DistanceResultMixin, OracleToleranceMixin, GeoFunc): super().__init__(expr1, **extra) def as_sql(self, compiler, connection, **extra_context): - geo_field = GeometryField(srid=self.srid) # Fake field to get SRID info - if geo_field.geodetic(connection) and not connection.features.supports_length_geodetic: + if self.geo_field.geodetic(connection) and not connection.features.supports_length_geodetic: raise NotImplementedError("This backend doesn't support Length on geodetic fields") return super().as_sql(compiler, connection, **extra_context) def as_postgresql(self, compiler, connection): function = None - geo_field = GeometryField(srid=self.srid) # Fake field to get SRID info if self.source_is_geography(): self.source_expressions.append(Value(self.spheroid)) - elif geo_field.geodetic(connection): + elif self.geo_field.geodetic(connection): # Geometry fields with geodetic (lon/lat) coordinates need length_spheroid function = connection.ops.spatial_function_name('LengthSpheroid') - self.source_expressions.append(Value(geo_field.spheroid(connection))) + self.source_expressions.append(Value(self.geo_field.spheroid(connection))) else: dim = min(f.dim for f in self.get_source_fields() if f) if dim > 2: @@ -389,8 +375,7 @@ class Length(DistanceResultMixin, OracleToleranceMixin, GeoFunc): def as_sqlite(self, compiler, connection): function = None - geo_field = GeometryField(srid=self.srid) - if geo_field.geodetic(connection): + if self.geo_field.geodetic(connection): function = 'GeodesicLength' if self.spheroid else 'GreatCircleLength' return super().as_sql(compiler, connection, function=function) @@ -425,8 +410,7 @@ class Perimeter(DistanceResultMixin, OracleToleranceMixin, GeoFunc): def as_postgresql(self, compiler, connection): function = None - geo_field = GeometryField(srid=self.srid) # Fake field to get SRID info - if geo_field.geodetic(connection) and not self.source_is_geography(): + if self.geo_field.geodetic(connection) and not self.source_is_geography(): raise NotImplementedError("ST_Perimeter cannot use a non-projected non-geography field.") dim = min(f.dim for f in self.get_source_fields()) if dim > 2: @@ -434,8 +418,7 @@ class Perimeter(DistanceResultMixin, OracleToleranceMixin, GeoFunc): return super().as_sql(compiler, connection, function=function) def as_sqlite(self, compiler, connection): - geo_field = GeometryField(srid=self.srid) # Fake field to get SRID info - if geo_field.geodetic(connection): + if self.geo_field.geodetic(connection): raise NotImplementedError("Perimeter cannot use a non-projected field.") return super().as_sql(compiler, connection) @@ -496,11 +479,6 @@ class Transform(GeomOutputGeoFunc): extra['output_field'] = GeometryField(srid=srid) super().__init__(*expressions, **extra) - @property - def srid(self): - # Make srid the resulting srid of the transformation - return self.source_expressions[1].value - class Translate(Scale): def as_sqlite(self, compiler, connection):