diff --git a/django/contrib/gis/db/models/aggregates.py b/django/contrib/gis/db/models/aggregates.py index 131ef0e676..3fa57cdc07 100644 --- a/django/contrib/gis/db/models/aggregates.py +++ b/django/contrib/gis/db/models/aggregates.py @@ -8,20 +8,21 @@ class GeoAggregate(Aggregate): function = None is_extent = False - def as_sql(self, compiler, connection): + def as_sql(self, compiler, connection, function=None, **extra_context): # this will be called again in parent, but it's needed now - before # we get the spatial_aggregate_name connection.ops.check_expression_support(self) - self.function = connection.ops.spatial_aggregate_name(self.name) - return super().as_sql(compiler, connection) + return super().as_sql( + compiler, + connection, + function=function or connection.ops.spatial_aggregate_name(self.name), + **extra_context + ) def as_oracle(self, compiler, connection): - if not hasattr(self, 'tolerance'): - self.tolerance = 0.05 - self.extra['tolerance'] = self.tolerance - if not self.is_extent: - self.template = '%(function)s(SDOAGGRTYPE(%(expressions)s,%(tolerance)s))' - return self.as_sql(compiler, connection) + tolerance = self.extra.get('tolerance') or getattr(self, 'tolerance', 0.05) + template = None if self.is_extent else '%(function)s(SDOAGGRTYPE(%(expressions)s,%(tolerance)s))' + return self.as_sql(compiler, connection, template=template, tolerance=tolerance) def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False): c = super().resolve_expression(query, allow_joins, reuse, summarize, for_save) diff --git a/django/contrib/gis/db/models/functions.py b/django/contrib/gis/db/models/functions.py index fdcb562162..36d1644ddd 100644 --- a/django/contrib/gis/db/models/functions.py +++ b/django/contrib/gis/db/models/functions.py @@ -45,12 +45,12 @@ class GeoFuncMixin: def geo_field(self): return GeometryField(srid=self.srid) if self.srid else None - def as_sql(self, compiler, connection, **extra_context): - if self.function is None: - self.function = connection.ops.spatial_function_name(self.name) + def as_sql(self, compiler, connection, function=None, **extra_context): + if not self.function and not function: + function = connection.ops.spatial_function_name(self.name) if any(isinstance(field, RasterField) for field in self.get_source_fields()): raise TypeError("Geometry functions not supported for raster fields.") - return super().as_sql(compiler, connection, **extra_context) + return super().as_sql(compiler, connection, function=function, **extra_context) def resolve_expression(self, *args, **kwargs): res = super().resolve_expression(*args, **kwargs) @@ -125,8 +125,7 @@ class OracleToleranceMixin: def as_oracle(self, compiler, connection): tol = self.extra.get('tolerance', self.tolerance) - self.template = "%%(function)s(%%(expressions)s, %s)" % tol - return super().as_sql(compiler, connection) + return super().as_sql(compiler, connection, template="%%(function)s(%%(expressions)s, %s)" % tol) class Area(OracleToleranceMixin, GeoFunc): @@ -272,6 +271,7 @@ class Distance(DistanceResultMixin, OracleToleranceMixin, GeoFuncWithGeoParam): super().__init__(*expressions, **extra) def as_postgresql(self, compiler, connection): + function = None geo_field = GeometryField(srid=self.srid) # Fake field to get SRID info if self.source_is_geography(): # Set parameters as geography if base field is geography @@ -283,12 +283,12 @@ class Distance(DistanceResultMixin, OracleToleranceMixin, GeoFuncWithGeoParam): # Geometry fields with geodetic (lon/lat) coordinates need special distance functions if self.spheroid: # DistanceSpheroid is more accurate and resource intensive than DistanceSphere - self.function = connection.ops.spatial_function_name('DistanceSpheroid') + function = connection.ops.spatial_function_name('DistanceSpheroid') # Replace boolean param by the real spheroid of the base field self.source_expressions[2] = Value(geo_field._spheroid) else: - self.function = connection.ops.spatial_function_name('DistanceSphere') - return super().as_sql(compiler, connection) + function = connection.ops.spatial_function_name('DistanceSphere') + return super().as_sql(compiler, connection, function=function) def as_oracle(self, compiler, connection): if self.spheroid: @@ -351,27 +351,26 @@ class Length(DistanceResultMixin, OracleToleranceMixin, GeoFunc): return super().as_sql(compiler, connection) def as_postgresql(self, compiler, connection): + function = None geo_field = GeometryField(srid=self.srid) # Fake field to get SRID info if self.source_is_geography(): self.source_expressions.append(Value(self.spheroid)) elif geo_field.geodetic(connection): # Geometry fields with geodetic (lon/lat) coordinates need length_spheroid - self.function = connection.ops.spatial_function_name('LengthSpheroid') + function = connection.ops.spatial_function_name('LengthSpheroid') self.source_expressions.append(Value(geo_field._spheroid)) else: dim = min(f.dim for f in self.get_source_fields() if f) if dim > 2: - self.function = connection.ops.length3d - return super().as_sql(compiler, connection) + function = connection.ops.length3d + return super().as_sql(compiler, connection, function=function) def as_sqlite(self, compiler, connection): + function = None geo_field = GeometryField(srid=self.srid) if geo_field.geodetic(connection): - if self.spheroid: - self.function = 'GeodesicLength' - else: - self.function = 'GreatCircleLength' - return super().as_sql(compiler, connection) + function = 'GeodesicLength' if self.spheroid else 'GreatCircleLength' + return super().as_sql(compiler, connection, function=function) class MakeValid(GeoFunc): @@ -404,13 +403,14 @@ class Perimeter(DistanceResultMixin, OracleToleranceMixin, GeoFunc): arity = 1 def as_postgresql(self, compiler, connection): + function = None geo_field = GeometryField(srid=self.srid) # Fake field to get SRID info if geo_field.geodetic(connection) and not self.source_is_geography(): raise NotImplementedError("ST_Perimeter cannot use a non-projected non-geography field.") dim = min(f.dim for f in self.get_source_fields()) if dim > 2: - self.function = connection.ops.perimeter3d - return super().as_sql(compiler, connection) + function = connection.ops.perimeter3d + return super().as_sql(compiler, connection, function=function) def as_sqlite(self, compiler, connection): geo_field = GeometryField(srid=self.srid) # Fake field to get SRID info