diff --git a/django/contrib/gis/db/backends/base/operations.py b/django/contrib/gis/db/backends/base/operations.py index 170775f063..af85b83df6 100644 --- a/django/contrib/gis/db/backends/base/operations.py +++ b/django/contrib/gis/db/backends/base/operations.py @@ -1,5 +1,8 @@ from django.contrib.gis.db.models import GeometryField from django.contrib.gis.db.models.functions import Distance +from django.contrib.gis.measure import ( + Area as AreaMeasure, Distance as DistanceMeasure, +) from django.utils.functional import cached_property @@ -135,3 +138,24 @@ class BaseSpatialOperations: 'Subclasses of BaseSpatialOperations must provide a ' 'get_geometry_converter() method.' ) + + def get_area_att_for_field(self, field): + if field.geodetic(self.connection): + if self.connection.features.supports_area_geodetic: + return 'sq_m' + raise NotImplementedError('Area on geodetic coordinate systems not supported.') + else: + units_name = field.units_name(self.connection) + if units_name: + return AreaMeasure.unit_attname(units_name) + + def get_distance_att_for_field(self, field): + dist_att = None + if field.geodetic(self.connection): + if self.connection.features.supports_distance_geodetic: + dist_att = 'm' + else: + units = field.units_name(self.connection) + if units: + dist_att = DistanceMeasure.unit_attname(units) + return dist_att diff --git a/django/contrib/gis/db/backends/oracle/operations.py b/django/contrib/gis/db/backends/oracle/operations.py index e69b8b8c07..19cf2eaa06 100644 --- a/django/contrib/gis/db/backends/oracle/operations.py +++ b/django/contrib/gis/db/backends/oracle/operations.py @@ -212,3 +212,6 @@ class OracleOperations(BaseSpatialOperations, DatabaseOperations): geom.srid = srid return geom return converter + + def get_area_att_for_field(self, field): + return 'sq_m' diff --git a/django/contrib/gis/db/backends/postgis/operations.py b/django/contrib/gis/db/backends/postgis/operations.py index a535b49849..51c8d5006e 100644 --- a/django/contrib/gis/db/backends/postgis/operations.py +++ b/django/contrib/gis/db/backends/postgis/operations.py @@ -389,3 +389,6 @@ class PostGISOperations(BaseSpatialOperations, DatabaseOperations): def converter(value, expression, connection): return None if value is None else GEOSGeometryBase(read(value), geom_class) return converter + + def get_area_att_for_field(self, field): + return 'sq_m' diff --git a/django/contrib/gis/db/models/functions.py b/django/contrib/gis/db/models/functions.py index e166c21e29..9d245b93c4 100644 --- a/django/contrib/gis/db/models/functions.py +++ b/django/contrib/gis/db/models/functions.py @@ -3,9 +3,6 @@ from decimal import Decimal from django.contrib.gis.db.models.fields import BaseSpatialField, GeometryField from django.contrib.gis.db.models.sql import AreaField, DistanceField from django.contrib.gis.geometry.backend import Geometry -from django.contrib.gis.measure import ( - Area as AreaMeasure, Distance as DistanceMeasure, -) from django.core.exceptions import FieldError from django.db.models import ( BooleanField, FloatField, IntegerField, TextField, Transform, @@ -121,29 +118,16 @@ class OracleToleranceMixin: class Area(OracleToleranceMixin, GeoFunc): - output_field_class = AreaField arity = 1 - def as_sql(self, compiler, connection, **extra_context): - if connection.ops.geography: - self.output_field.area_att = 'sq_m' - else: - # Getting the area units of the geographic field. - if self.geo_field.geodetic(connection): - if connection.features.supports_area_geodetic: - self.output_field.area_att = 'sq_m' - else: - # TODO: Do we want to support raw number areas for geodetic fields? - raise NotImplementedError('Area on geodetic coordinate systems not supported.') - else: - units_name = self.geo_field.units_name(connection) - if units_name: - self.output_field.area_att = AreaMeasure.unit_attname(units_name) - return super().as_sql(compiler, connection, **extra_context) + @cached_property + def output_field(self): + return AreaField(self.geo_field) - def as_oracle(self, compiler, connection): - self.output_field = AreaField('sq_m') # Oracle returns area in units of meters. - return super().as_oracle(compiler, connection) + def as_sql(self, compiler, connection, **extra_context): + if not connection.features.supports_area_geodetic and self.geo_field.geodetic(connection): + raise NotImplementedError('Area on geodetic coordinate systems not supported.') + return super().as_sql(compiler, connection, **extra_context) def as_sqlite(self, compiler, connection, **extra_context): if self.geo_field.geodetic(connection): @@ -237,27 +221,13 @@ class Difference(OracleToleranceMixin, GeomOutputGeoFunc): class DistanceResultMixin: - output_field_class = DistanceField + @cached_property + def output_field(self): + return DistanceField(self.geo_field) def source_is_geography(self): return self.geo_field.geography and self.geo_field.srid == 4326 - def distance_att(self, connection): - dist_att = None - if self.geo_field.geodetic(connection): - if connection.features.supports_distance_geodetic: - dist_att = 'm' - else: - units = self.geo_field.units_name(connection) - if units: - dist_att = DistanceMeasure.unit_attname(units) - return dist_att - - def as_sql(self, compiler, connection, **extra_context): - clone = self.copy() - clone.output_field.distance_att = self.distance_att(connection) - return super(DistanceResultMixin, clone).as_sql(compiler, connection, **extra_context) - class Distance(DistanceResultMixin, OracleToleranceMixin, GeoFunc): geom_param_pos = (0, 1) @@ -266,19 +236,19 @@ class Distance(DistanceResultMixin, OracleToleranceMixin, GeoFunc): def __init__(self, expr1, expr2, spheroid=None, **extra): expressions = [expr1, expr2] if spheroid is not None: - self.spheroid = spheroid - expressions += (self._handle_param(spheroid, 'spheroid', bool),) + self.spheroid = self._handle_param(spheroid, 'spheroid', bool) super().__init__(*expressions, **extra) def as_postgresql(self, compiler, connection): + clone = self.copy() function = None - expr2 = self.source_expressions[1] + expr2 = clone.source_expressions[1] geography = self.source_is_geography() if expr2.output_field.geography != geography: if isinstance(expr2, Value): expr2.output_field.geography = geography else: - self.source_expressions[1] = Cast( + clone.source_expressions[1] = Cast( expr2, GeometryField(srid=expr2.output_field.srid, geography=geography), ) @@ -289,19 +259,12 @@ class Distance(DistanceResultMixin, OracleToleranceMixin, GeoFunc): # DistanceSpheroid is more accurate and resource intensive than DistanceSphere function = connection.ops.spatial_function_name('DistanceSpheroid') # Replace boolean param by the real spheroid of the base field - self.source_expressions[2] = Value(self.geo_field.spheroid(connection)) + clone.source_expressions.append(Value(self.geo_field.spheroid(connection))) else: function = connection.ops.spatial_function_name('DistanceSphere') - return super().as_sql(compiler, connection, function=function) - - def as_oracle(self, compiler, connection): - if self.spheroid: - self.source_expressions.pop(2) - return super().as_oracle(compiler, connection) + return super(Distance, clone).as_sql(compiler, connection, function=function) def as_sqlite(self, compiler, connection, **extra_context): - if self.spheroid: - self.source_expressions.pop(2) if self.geo_field.geodetic(connection): # SpatiaLite returns NULL instead of zero on geodetic coordinates extra_context['template'] = 'COALESCE(%(function)s(%(expressions)s, %(spheroid)s), 0)' @@ -360,18 +323,19 @@ class Length(DistanceResultMixin, OracleToleranceMixin, GeoFunc): return super().as_sql(compiler, connection, **extra_context) def as_postgresql(self, compiler, connection): + clone = self.copy() function = None if self.source_is_geography(): - self.source_expressions.append(Value(self.spheroid)) + clone.source_expressions.append(Value(self.spheroid)) elif self.geo_field.geodetic(connection): # Geometry fields with geodetic (lon/lat) coordinates need length_spheroid function = connection.ops.spatial_function_name('LengthSpheroid') - self.source_expressions.append(Value(self.geo_field.spheroid(connection))) + clone.source_expressions.append(Value(self.geo_field.spheroid(connection))) else: dim = min(f.dim for f in self.get_source_fields() if f) if dim > 2: function = connection.ops.length3d - return super().as_sql(compiler, connection, function=function) + return super(Length, clone).as_sql(compiler, connection, function=function) def as_sqlite(self, compiler, connection): function = None @@ -482,10 +446,11 @@ class Transform(GeomOutputGeoFunc): class Translate(Scale): def as_sqlite(self, compiler, connection): + clone = self.copy() if len(self.source_expressions) < 4: # Always provide the z parameter for ST_Translate - self.source_expressions.append(Value(0)) - return super().as_sqlite(compiler, connection) + clone.source_expressions.append(Value(0)) + return super(Translate, clone).as_sqlite(compiler, connection) class Union(OracleToleranceMixin, GeomOutputGeoFunc): diff --git a/django/contrib/gis/db/models/sql/conversion.py b/django/contrib/gis/db/models/sql/conversion.py index eafa7b5d17..99ab51e239 100644 --- a/django/contrib/gis/db/models/sql/conversion.py +++ b/django/contrib/gis/db/models/sql/conversion.py @@ -10,9 +10,9 @@ from django.db import models class AreaField(models.FloatField): "Wrapper for Area values." - def __init__(self, area_att=None): + def __init__(self, geo_field): super().__init__() - self.area_att = area_att + self.geo_field = geo_field def get_prep_value(self, value): if not isinstance(value, Area): @@ -20,19 +20,21 @@ class AreaField(models.FloatField): return value def get_db_prep_value(self, value, connection, prepared=False): - if value is None or not self.area_att: - return value - return getattr(value, self.area_att) + if value is None: + return + area_att = connection.ops.get_area_att_for_field(self.geo_field) + return getattr(value, area_att) if area_att else value def from_db_value(self, value, expression, connection): + if value is None: + return # If the database returns a Decimal, convert it to a float as expected # by the Python geometric objects. if isinstance(value, Decimal): value = float(value) # If the units are known, convert value into area measure. - if value is not None and self.area_att: - value = Area(**{self.area_att: value}) - return value + area_att = connection.ops.get_area_att_for_field(self.geo_field) + return Area(**{area_att: value}) if area_att else value def get_internal_type(self): return 'AreaField' @@ -40,9 +42,9 @@ class AreaField(models.FloatField): class DistanceField(models.FloatField): "Wrapper for Distance values." - def __init__(self, distance_att=None): + def __init__(self, geo_field): super().__init__() - self.distance_att = distance_att + self.geo_field = geo_field def get_prep_value(self, value): if isinstance(value, Distance): @@ -52,14 +54,16 @@ class DistanceField(models.FloatField): def get_db_prep_value(self, value, connection, prepared=False): if not isinstance(value, Distance): return value - if not self.distance_att: + distance_att = connection.ops.get_distance_att_for_field(self.geo_field) + if not distance_att: raise ValueError('Distance measure is supplied, but units are unknown for result.') - return getattr(value, self.distance_att) + return getattr(value, distance_att) def from_db_value(self, value, expression, connection): - if value is None or not self.distance_att: - return value - return Distance(**{self.distance_att: value}) + if value is None: + return + distance_att = connection.ops.get_distance_att_for_field(self.geo_field) + return Distance(**{distance_att: value}) if distance_att else value def get_internal_type(self): return 'DistanceField' diff --git a/tests/gis_tests/distapp/tests.py b/tests/gis_tests/distapp/tests.py index 395f7226ef..d162759513 100644 --- a/tests/gis_tests/distapp/tests.py +++ b/tests/gis_tests/distapp/tests.py @@ -9,7 +9,9 @@ from django.db import connection from django.db.models import F, Q from django.test import TestCase, skipIfDBFeature, skipUnlessDBFeature -from ..utils import mysql, no_oracle, oracle, postgis, spatialite +from ..utils import ( + FuncTestMixin, mysql, no_oracle, oracle, postgis, spatialite, +) from .models import ( AustraliaCity, CensusZipcode, Interstate, SouthTexasCity, SouthTexasCityFt, SouthTexasInterstate, SouthTexasZipcode, @@ -262,7 +264,7 @@ Perimeter(geom1) | OK | :-( ''' # NOQA -class DistanceFunctionsTests(TestCase): +class DistanceFunctionsTests(FuncTestMixin, TestCase): fixtures = ['initial'] @skipUnlessDBFeature("has_Area_function") diff --git a/tests/gis_tests/geo3d/tests.py b/tests/gis_tests/geo3d/tests.py index 39603d1249..d2e85f0607 100644 --- a/tests/gis_tests/geo3d/tests.py +++ b/tests/gis_tests/geo3d/tests.py @@ -8,6 +8,7 @@ from django.contrib.gis.db.models.functions import ( from django.contrib.gis.geos import GEOSGeometry, LineString, Point, Polygon from django.test import TestCase, skipUnlessDBFeature +from ..utils import FuncTestMixin from .models import ( City3D, Interstate2D, Interstate3D, InterstateProj2D, InterstateProj3D, MultiPoint3D, Point2D, Point3D, Polygon2D, Polygon3D, @@ -205,7 +206,7 @@ class Geo3DTest(Geo3DLoadingHelper, TestCase): @skipUnlessDBFeature("supports_3d_functions") -class Geo3DFunctionsTests(Geo3DLoadingHelper, TestCase): +class Geo3DFunctionsTests(FuncTestMixin, Geo3DLoadingHelper, TestCase): def test_kml(self): """ Test KML() function with Z values. diff --git a/tests/gis_tests/geoapp/test_functions.py b/tests/gis_tests/geoapp/test_functions.py index bb13d9e37f..cdd05d78ff 100644 --- a/tests/gis_tests/geoapp/test_functions.py +++ b/tests/gis_tests/geoapp/test_functions.py @@ -12,11 +12,11 @@ from django.db import connection from django.db.models import Sum from django.test import TestCase, skipUnlessDBFeature -from ..utils import mysql, oracle, postgis, spatialite +from ..utils import FuncTestMixin, mysql, oracle, postgis, spatialite from .models import City, Country, CountryWebMercator, State, Track -class GISFunctionsTests(TestCase): +class GISFunctionsTests(FuncTestMixin, TestCase): """ Testing functions from django/contrib/gis/db/models/functions.py. Area/Distance/Length/Perimeter are tested in distapp/tests. @@ -127,11 +127,8 @@ class GISFunctionsTests(TestCase): City.objects.annotate(kml=functions.AsKML('name')) # Ensuring the KML is as expected. - qs = City.objects.annotate(kml=functions.AsKML('point', precision=9)) - ptown = qs.get(name='Pueblo') + ptown = City.objects.annotate(kml=functions.AsKML('point', precision=9)).get(name='Pueblo') self.assertEqual('-104.609252,38.255001', ptown.kml) - # Same result if the queryset is evaluated again. - self.assertEqual(qs.get(name='Pueblo').kml, ptown.kml) @skipUnlessDBFeature("has_AsSVG_function") def test_assvg(self): diff --git a/tests/gis_tests/geogapp/tests.py b/tests/gis_tests/geogapp/tests.py index 2969ca1cc6..c9986fd78b 100644 --- a/tests/gis_tests/geogapp/tests.py +++ b/tests/gis_tests/geogapp/tests.py @@ -11,7 +11,7 @@ from django.db import connection from django.db.models.functions import Cast from django.test import TestCase, skipIfDBFeature, skipUnlessDBFeature -from ..utils import oracle, postgis, spatialite +from ..utils import FuncTestMixin, oracle, postgis, spatialite from .models import City, County, Zipcode @@ -86,7 +86,7 @@ class GeographyTest(TestCase): self.assertEqual(state, c.state) -class GeographyFunctionTests(TestCase): +class GeographyFunctionTests(FuncTestMixin, TestCase): fixtures = ['initial'] @skipUnlessDBFeature("supports_extent_aggr") diff --git a/tests/gis_tests/test_fields.py b/tests/gis_tests/test_fields.py index fb0c953f21..27db3e1dfa 100644 --- a/tests/gis_tests/test_fields.py +++ b/tests/gis_tests/test_fields.py @@ -7,9 +7,9 @@ from django.test import SimpleTestCase class FieldsTests(SimpleTestCase): def test_area_field_deepcopy(self): - field = AreaField() + field = AreaField(None) self.assertEqual(copy.deepcopy(field), field) def test_distance_field_deepcopy(self): - field = DistanceField() + field = DistanceField(None) self.assertEqual(copy.deepcopy(field), field) diff --git a/tests/gis_tests/test_gis_tests_utils.py b/tests/gis_tests/test_gis_tests_utils.py new file mode 100644 index 0000000000..32d072fd9b --- /dev/null +++ b/tests/gis_tests/test_gis_tests_utils.py @@ -0,0 +1,52 @@ +from django.db import connection, models +from django.db.models.expressions import Func +from django.test import SimpleTestCase + +from .utils import FuncTestMixin + + +def test_mutation(raises=True): + def wrapper(mutation_func): + def test(test_case_instance, *args, **kwargs): + class TestFunc(Func): + output_field = models.IntegerField() + + def __init__(self): + self.attribute = 'initial' + super().__init__('initial', ['initial']) + + def as_sql(self, *args, **kwargs): + mutation_func(self) + return '', () + + if raises: + msg = 'TestFunc Func was mutated during compilation.' + with test_case_instance.assertRaisesMessage(AssertionError, msg): + getattr(TestFunc(), 'as_' + connection.vendor)(None, None) + else: + getattr(TestFunc(), 'as_' + connection.vendor)(None, None) + + return test + return wrapper + + +class FuncTestMixinTests(FuncTestMixin, SimpleTestCase): + @test_mutation() + def test_mutated_attribute(func): + func.attribute = 'mutated' + + @test_mutation() + def test_mutated_expressions(func): + func.source_expressions.clear() + + @test_mutation() + def test_mutated_expression(func): + func.source_expressions[0].name = 'mutated' + + @test_mutation() + def test_mutated_expression_deep(func): + func.source_expressions[1].value[0] = 'mutated' + + @test_mutation(raises=False) + def test_not_mutated(func): + pass diff --git a/tests/gis_tests/utils.py b/tests/gis_tests/utils.py index 6eb029c1d5..b30da7e40d 100644 --- a/tests/gis_tests/utils.py +++ b/tests/gis_tests/utils.py @@ -1,8 +1,11 @@ +import copy import unittest from functools import wraps +from unittest import mock from django.conf import settings from django.db import DEFAULT_DB_ALIAS, connection +from django.db.models.expressions import Func def skipUnlessGISLookup(*gis_lookups): @@ -56,3 +59,39 @@ elif spatialite: from django.contrib.gis.db.backends.spatialite.models import SpatialiteSpatialRefSys as SpatialRefSys else: SpatialRefSys = None + + +class FuncTestMixin: + """Assert that Func expressions aren't mutated during their as_sql().""" + def setUp(self): + def as_sql_wrapper(original_as_sql): + def inner(*args, **kwargs): + func = original_as_sql.__self__ + # Resolve output_field before as_sql() so touching it in + # as_sql() won't change __dict__. + func.output_field + __dict__original = copy.deepcopy(func.__dict__) + result = original_as_sql(*args, **kwargs) + msg = '%s Func was mutated during compilation.' % func.__class__.__name__ + self.assertEqual(func.__dict__, __dict__original, msg) + return result + return inner + + def __getattribute__(self, name): + if name != vendor_impl: + return __getattribute__original(self, name) + try: + as_sql = __getattribute__original(self, vendor_impl) + except AttributeError: + as_sql = __getattribute__original(self, 'as_sql') + return as_sql_wrapper(as_sql) + + vendor_impl = 'as_' + connection.vendor + __getattribute__original = Func.__getattribute__ + self.func_patcher = mock.patch.object(Func, '__getattribute__', __getattribute__) + self.func_patcher.start() + super().setUp() + + def tearDown(self): + super().tearDown() + self.func_patcher.stop()