mirror of https://github.com/django/django.git
Fixed #28353 -- Fixed some GIS functions when queryset is evaluated more than once.
Reverted test for refs #27603 in favor of using FuncTestMixin.
This commit is contained in:
parent
99e65d6488
commit
3905cfa1a5
|
@ -1,5 +1,8 @@
|
|||
from django.contrib.gis.db.models import GeometryField
|
||||
from django.contrib.gis.db.models.functions import Distance
|
||||
from django.contrib.gis.measure import (
|
||||
Area as AreaMeasure, Distance as DistanceMeasure,
|
||||
)
|
||||
from django.utils.functional import cached_property
|
||||
|
||||
|
||||
|
@ -135,3 +138,24 @@ class BaseSpatialOperations:
|
|||
'Subclasses of BaseSpatialOperations must provide a '
|
||||
'get_geometry_converter() method.'
|
||||
)
|
||||
|
||||
def get_area_att_for_field(self, field):
|
||||
if field.geodetic(self.connection):
|
||||
if self.connection.features.supports_area_geodetic:
|
||||
return 'sq_m'
|
||||
raise NotImplementedError('Area on geodetic coordinate systems not supported.')
|
||||
else:
|
||||
units_name = field.units_name(self.connection)
|
||||
if units_name:
|
||||
return AreaMeasure.unit_attname(units_name)
|
||||
|
||||
def get_distance_att_for_field(self, field):
|
||||
dist_att = None
|
||||
if field.geodetic(self.connection):
|
||||
if self.connection.features.supports_distance_geodetic:
|
||||
dist_att = 'm'
|
||||
else:
|
||||
units = field.units_name(self.connection)
|
||||
if units:
|
||||
dist_att = DistanceMeasure.unit_attname(units)
|
||||
return dist_att
|
||||
|
|
|
@ -212,3 +212,6 @@ class OracleOperations(BaseSpatialOperations, DatabaseOperations):
|
|||
geom.srid = srid
|
||||
return geom
|
||||
return converter
|
||||
|
||||
def get_area_att_for_field(self, field):
|
||||
return 'sq_m'
|
||||
|
|
|
@ -389,3 +389,6 @@ class PostGISOperations(BaseSpatialOperations, DatabaseOperations):
|
|||
def converter(value, expression, connection):
|
||||
return None if value is None else GEOSGeometryBase(read(value), geom_class)
|
||||
return converter
|
||||
|
||||
def get_area_att_for_field(self, field):
|
||||
return 'sq_m'
|
||||
|
|
|
@ -3,9 +3,6 @@ from decimal import Decimal
|
|||
from django.contrib.gis.db.models.fields import BaseSpatialField, GeometryField
|
||||
from django.contrib.gis.db.models.sql import AreaField, DistanceField
|
||||
from django.contrib.gis.geometry.backend import Geometry
|
||||
from django.contrib.gis.measure import (
|
||||
Area as AreaMeasure, Distance as DistanceMeasure,
|
||||
)
|
||||
from django.core.exceptions import FieldError
|
||||
from django.db.models import (
|
||||
BooleanField, FloatField, IntegerField, TextField, Transform,
|
||||
|
@ -121,29 +118,16 @@ class OracleToleranceMixin:
|
|||
|
||||
|
||||
class Area(OracleToleranceMixin, GeoFunc):
|
||||
output_field_class = AreaField
|
||||
arity = 1
|
||||
|
||||
def as_sql(self, compiler, connection, **extra_context):
|
||||
if connection.ops.geography:
|
||||
self.output_field.area_att = 'sq_m'
|
||||
else:
|
||||
# Getting the area units of the geographic field.
|
||||
if self.geo_field.geodetic(connection):
|
||||
if connection.features.supports_area_geodetic:
|
||||
self.output_field.area_att = 'sq_m'
|
||||
else:
|
||||
# TODO: Do we want to support raw number areas for geodetic fields?
|
||||
raise NotImplementedError('Area on geodetic coordinate systems not supported.')
|
||||
else:
|
||||
units_name = self.geo_field.units_name(connection)
|
||||
if units_name:
|
||||
self.output_field.area_att = AreaMeasure.unit_attname(units_name)
|
||||
return super().as_sql(compiler, connection, **extra_context)
|
||||
@cached_property
|
||||
def output_field(self):
|
||||
return AreaField(self.geo_field)
|
||||
|
||||
def as_oracle(self, compiler, connection):
|
||||
self.output_field = AreaField('sq_m') # Oracle returns area in units of meters.
|
||||
return super().as_oracle(compiler, connection)
|
||||
def as_sql(self, compiler, connection, **extra_context):
|
||||
if not connection.features.supports_area_geodetic and self.geo_field.geodetic(connection):
|
||||
raise NotImplementedError('Area on geodetic coordinate systems not supported.')
|
||||
return super().as_sql(compiler, connection, **extra_context)
|
||||
|
||||
def as_sqlite(self, compiler, connection, **extra_context):
|
||||
if self.geo_field.geodetic(connection):
|
||||
|
@ -237,27 +221,13 @@ class Difference(OracleToleranceMixin, GeomOutputGeoFunc):
|
|||
|
||||
|
||||
class DistanceResultMixin:
|
||||
output_field_class = DistanceField
|
||||
@cached_property
|
||||
def output_field(self):
|
||||
return DistanceField(self.geo_field)
|
||||
|
||||
def source_is_geography(self):
|
||||
return self.geo_field.geography and self.geo_field.srid == 4326
|
||||
|
||||
def distance_att(self, connection):
|
||||
dist_att = None
|
||||
if self.geo_field.geodetic(connection):
|
||||
if connection.features.supports_distance_geodetic:
|
||||
dist_att = 'm'
|
||||
else:
|
||||
units = self.geo_field.units_name(connection)
|
||||
if units:
|
||||
dist_att = DistanceMeasure.unit_attname(units)
|
||||
return dist_att
|
||||
|
||||
def as_sql(self, compiler, connection, **extra_context):
|
||||
clone = self.copy()
|
||||
clone.output_field.distance_att = self.distance_att(connection)
|
||||
return super(DistanceResultMixin, clone).as_sql(compiler, connection, **extra_context)
|
||||
|
||||
|
||||
class Distance(DistanceResultMixin, OracleToleranceMixin, GeoFunc):
|
||||
geom_param_pos = (0, 1)
|
||||
|
@ -266,19 +236,19 @@ class Distance(DistanceResultMixin, OracleToleranceMixin, GeoFunc):
|
|||
def __init__(self, expr1, expr2, spheroid=None, **extra):
|
||||
expressions = [expr1, expr2]
|
||||
if spheroid is not None:
|
||||
self.spheroid = spheroid
|
||||
expressions += (self._handle_param(spheroid, 'spheroid', bool),)
|
||||
self.spheroid = self._handle_param(spheroid, 'spheroid', bool)
|
||||
super().__init__(*expressions, **extra)
|
||||
|
||||
def as_postgresql(self, compiler, connection):
|
||||
clone = self.copy()
|
||||
function = None
|
||||
expr2 = self.source_expressions[1]
|
||||
expr2 = clone.source_expressions[1]
|
||||
geography = self.source_is_geography()
|
||||
if expr2.output_field.geography != geography:
|
||||
if isinstance(expr2, Value):
|
||||
expr2.output_field.geography = geography
|
||||
else:
|
||||
self.source_expressions[1] = Cast(
|
||||
clone.source_expressions[1] = Cast(
|
||||
expr2,
|
||||
GeometryField(srid=expr2.output_field.srid, geography=geography),
|
||||
)
|
||||
|
@ -289,19 +259,12 @@ class Distance(DistanceResultMixin, OracleToleranceMixin, GeoFunc):
|
|||
# DistanceSpheroid is more accurate and resource intensive than DistanceSphere
|
||||
function = connection.ops.spatial_function_name('DistanceSpheroid')
|
||||
# Replace boolean param by the real spheroid of the base field
|
||||
self.source_expressions[2] = Value(self.geo_field.spheroid(connection))
|
||||
clone.source_expressions.append(Value(self.geo_field.spheroid(connection)))
|
||||
else:
|
||||
function = connection.ops.spatial_function_name('DistanceSphere')
|
||||
return super().as_sql(compiler, connection, function=function)
|
||||
|
||||
def as_oracle(self, compiler, connection):
|
||||
if self.spheroid:
|
||||
self.source_expressions.pop(2)
|
||||
return super().as_oracle(compiler, connection)
|
||||
return super(Distance, clone).as_sql(compiler, connection, function=function)
|
||||
|
||||
def as_sqlite(self, compiler, connection, **extra_context):
|
||||
if self.spheroid:
|
||||
self.source_expressions.pop(2)
|
||||
if self.geo_field.geodetic(connection):
|
||||
# SpatiaLite returns NULL instead of zero on geodetic coordinates
|
||||
extra_context['template'] = 'COALESCE(%(function)s(%(expressions)s, %(spheroid)s), 0)'
|
||||
|
@ -360,18 +323,19 @@ class Length(DistanceResultMixin, OracleToleranceMixin, GeoFunc):
|
|||
return super().as_sql(compiler, connection, **extra_context)
|
||||
|
||||
def as_postgresql(self, compiler, connection):
|
||||
clone = self.copy()
|
||||
function = None
|
||||
if self.source_is_geography():
|
||||
self.source_expressions.append(Value(self.spheroid))
|
||||
clone.source_expressions.append(Value(self.spheroid))
|
||||
elif self.geo_field.geodetic(connection):
|
||||
# Geometry fields with geodetic (lon/lat) coordinates need length_spheroid
|
||||
function = connection.ops.spatial_function_name('LengthSpheroid')
|
||||
self.source_expressions.append(Value(self.geo_field.spheroid(connection)))
|
||||
clone.source_expressions.append(Value(self.geo_field.spheroid(connection)))
|
||||
else:
|
||||
dim = min(f.dim for f in self.get_source_fields() if f)
|
||||
if dim > 2:
|
||||
function = connection.ops.length3d
|
||||
return super().as_sql(compiler, connection, function=function)
|
||||
return super(Length, clone).as_sql(compiler, connection, function=function)
|
||||
|
||||
def as_sqlite(self, compiler, connection):
|
||||
function = None
|
||||
|
@ -482,10 +446,11 @@ class Transform(GeomOutputGeoFunc):
|
|||
|
||||
class Translate(Scale):
|
||||
def as_sqlite(self, compiler, connection):
|
||||
clone = self.copy()
|
||||
if len(self.source_expressions) < 4:
|
||||
# Always provide the z parameter for ST_Translate
|
||||
self.source_expressions.append(Value(0))
|
||||
return super().as_sqlite(compiler, connection)
|
||||
clone.source_expressions.append(Value(0))
|
||||
return super(Translate, clone).as_sqlite(compiler, connection)
|
||||
|
||||
|
||||
class Union(OracleToleranceMixin, GeomOutputGeoFunc):
|
||||
|
|
|
@ -10,9 +10,9 @@ from django.db import models
|
|||
|
||||
class AreaField(models.FloatField):
|
||||
"Wrapper for Area values."
|
||||
def __init__(self, area_att=None):
|
||||
def __init__(self, geo_field):
|
||||
super().__init__()
|
||||
self.area_att = area_att
|
||||
self.geo_field = geo_field
|
||||
|
||||
def get_prep_value(self, value):
|
||||
if not isinstance(value, Area):
|
||||
|
@ -20,19 +20,21 @@ class AreaField(models.FloatField):
|
|||
return value
|
||||
|
||||
def get_db_prep_value(self, value, connection, prepared=False):
|
||||
if value is None or not self.area_att:
|
||||
return value
|
||||
return getattr(value, self.area_att)
|
||||
if value is None:
|
||||
return
|
||||
area_att = connection.ops.get_area_att_for_field(self.geo_field)
|
||||
return getattr(value, area_att) if area_att else value
|
||||
|
||||
def from_db_value(self, value, expression, connection):
|
||||
if value is None:
|
||||
return
|
||||
# If the database returns a Decimal, convert it to a float as expected
|
||||
# by the Python geometric objects.
|
||||
if isinstance(value, Decimal):
|
||||
value = float(value)
|
||||
# If the units are known, convert value into area measure.
|
||||
if value is not None and self.area_att:
|
||||
value = Area(**{self.area_att: value})
|
||||
return value
|
||||
area_att = connection.ops.get_area_att_for_field(self.geo_field)
|
||||
return Area(**{area_att: value}) if area_att else value
|
||||
|
||||
def get_internal_type(self):
|
||||
return 'AreaField'
|
||||
|
@ -40,9 +42,9 @@ class AreaField(models.FloatField):
|
|||
|
||||
class DistanceField(models.FloatField):
|
||||
"Wrapper for Distance values."
|
||||
def __init__(self, distance_att=None):
|
||||
def __init__(self, geo_field):
|
||||
super().__init__()
|
||||
self.distance_att = distance_att
|
||||
self.geo_field = geo_field
|
||||
|
||||
def get_prep_value(self, value):
|
||||
if isinstance(value, Distance):
|
||||
|
@ -52,14 +54,16 @@ class DistanceField(models.FloatField):
|
|||
def get_db_prep_value(self, value, connection, prepared=False):
|
||||
if not isinstance(value, Distance):
|
||||
return value
|
||||
if not self.distance_att:
|
||||
distance_att = connection.ops.get_distance_att_for_field(self.geo_field)
|
||||
if not distance_att:
|
||||
raise ValueError('Distance measure is supplied, but units are unknown for result.')
|
||||
return getattr(value, self.distance_att)
|
||||
return getattr(value, distance_att)
|
||||
|
||||
def from_db_value(self, value, expression, connection):
|
||||
if value is None or not self.distance_att:
|
||||
return value
|
||||
return Distance(**{self.distance_att: value})
|
||||
if value is None:
|
||||
return
|
||||
distance_att = connection.ops.get_distance_att_for_field(self.geo_field)
|
||||
return Distance(**{distance_att: value}) if distance_att else value
|
||||
|
||||
def get_internal_type(self):
|
||||
return 'DistanceField'
|
||||
|
|
|
@ -9,7 +9,9 @@ from django.db import connection
|
|||
from django.db.models import F, Q
|
||||
from django.test import TestCase, skipIfDBFeature, skipUnlessDBFeature
|
||||
|
||||
from ..utils import mysql, no_oracle, oracle, postgis, spatialite
|
||||
from ..utils import (
|
||||
FuncTestMixin, mysql, no_oracle, oracle, postgis, spatialite,
|
||||
)
|
||||
from .models import (
|
||||
AustraliaCity, CensusZipcode, Interstate, SouthTexasCity, SouthTexasCityFt,
|
||||
SouthTexasInterstate, SouthTexasZipcode,
|
||||
|
@ -262,7 +264,7 @@ Perimeter(geom1) | OK | :-(
|
|||
''' # NOQA
|
||||
|
||||
|
||||
class DistanceFunctionsTests(TestCase):
|
||||
class DistanceFunctionsTests(FuncTestMixin, TestCase):
|
||||
fixtures = ['initial']
|
||||
|
||||
@skipUnlessDBFeature("has_Area_function")
|
||||
|
|
|
@ -8,6 +8,7 @@ from django.contrib.gis.db.models.functions import (
|
|||
from django.contrib.gis.geos import GEOSGeometry, LineString, Point, Polygon
|
||||
from django.test import TestCase, skipUnlessDBFeature
|
||||
|
||||
from ..utils import FuncTestMixin
|
||||
from .models import (
|
||||
City3D, Interstate2D, Interstate3D, InterstateProj2D, InterstateProj3D,
|
||||
MultiPoint3D, Point2D, Point3D, Polygon2D, Polygon3D,
|
||||
|
@ -205,7 +206,7 @@ class Geo3DTest(Geo3DLoadingHelper, TestCase):
|
|||
|
||||
|
||||
@skipUnlessDBFeature("supports_3d_functions")
|
||||
class Geo3DFunctionsTests(Geo3DLoadingHelper, TestCase):
|
||||
class Geo3DFunctionsTests(FuncTestMixin, Geo3DLoadingHelper, TestCase):
|
||||
def test_kml(self):
|
||||
"""
|
||||
Test KML() function with Z values.
|
||||
|
|
|
@ -12,11 +12,11 @@ from django.db import connection
|
|||
from django.db.models import Sum
|
||||
from django.test import TestCase, skipUnlessDBFeature
|
||||
|
||||
from ..utils import mysql, oracle, postgis, spatialite
|
||||
from ..utils import FuncTestMixin, mysql, oracle, postgis, spatialite
|
||||
from .models import City, Country, CountryWebMercator, State, Track
|
||||
|
||||
|
||||
class GISFunctionsTests(TestCase):
|
||||
class GISFunctionsTests(FuncTestMixin, TestCase):
|
||||
"""
|
||||
Testing functions from django/contrib/gis/db/models/functions.py.
|
||||
Area/Distance/Length/Perimeter are tested in distapp/tests.
|
||||
|
@ -127,11 +127,8 @@ class GISFunctionsTests(TestCase):
|
|||
City.objects.annotate(kml=functions.AsKML('name'))
|
||||
|
||||
# Ensuring the KML is as expected.
|
||||
qs = City.objects.annotate(kml=functions.AsKML('point', precision=9))
|
||||
ptown = qs.get(name='Pueblo')
|
||||
ptown = City.objects.annotate(kml=functions.AsKML('point', precision=9)).get(name='Pueblo')
|
||||
self.assertEqual('<Point><coordinates>-104.609252,38.255001</coordinates></Point>', ptown.kml)
|
||||
# Same result if the queryset is evaluated again.
|
||||
self.assertEqual(qs.get(name='Pueblo').kml, ptown.kml)
|
||||
|
||||
@skipUnlessDBFeature("has_AsSVG_function")
|
||||
def test_assvg(self):
|
||||
|
|
|
@ -11,7 +11,7 @@ from django.db import connection
|
|||
from django.db.models.functions import Cast
|
||||
from django.test import TestCase, skipIfDBFeature, skipUnlessDBFeature
|
||||
|
||||
from ..utils import oracle, postgis, spatialite
|
||||
from ..utils import FuncTestMixin, oracle, postgis, spatialite
|
||||
from .models import City, County, Zipcode
|
||||
|
||||
|
||||
|
@ -86,7 +86,7 @@ class GeographyTest(TestCase):
|
|||
self.assertEqual(state, c.state)
|
||||
|
||||
|
||||
class GeographyFunctionTests(TestCase):
|
||||
class GeographyFunctionTests(FuncTestMixin, TestCase):
|
||||
fixtures = ['initial']
|
||||
|
||||
@skipUnlessDBFeature("supports_extent_aggr")
|
||||
|
|
|
@ -7,9 +7,9 @@ from django.test import SimpleTestCase
|
|||
class FieldsTests(SimpleTestCase):
|
||||
|
||||
def test_area_field_deepcopy(self):
|
||||
field = AreaField()
|
||||
field = AreaField(None)
|
||||
self.assertEqual(copy.deepcopy(field), field)
|
||||
|
||||
def test_distance_field_deepcopy(self):
|
||||
field = DistanceField()
|
||||
field = DistanceField(None)
|
||||
self.assertEqual(copy.deepcopy(field), field)
|
||||
|
|
|
@ -0,0 +1,52 @@
|
|||
from django.db import connection, models
|
||||
from django.db.models.expressions import Func
|
||||
from django.test import SimpleTestCase
|
||||
|
||||
from .utils import FuncTestMixin
|
||||
|
||||
|
||||
def test_mutation(raises=True):
|
||||
def wrapper(mutation_func):
|
||||
def test(test_case_instance, *args, **kwargs):
|
||||
class TestFunc(Func):
|
||||
output_field = models.IntegerField()
|
||||
|
||||
def __init__(self):
|
||||
self.attribute = 'initial'
|
||||
super().__init__('initial', ['initial'])
|
||||
|
||||
def as_sql(self, *args, **kwargs):
|
||||
mutation_func(self)
|
||||
return '', ()
|
||||
|
||||
if raises:
|
||||
msg = 'TestFunc Func was mutated during compilation.'
|
||||
with test_case_instance.assertRaisesMessage(AssertionError, msg):
|
||||
getattr(TestFunc(), 'as_' + connection.vendor)(None, None)
|
||||
else:
|
||||
getattr(TestFunc(), 'as_' + connection.vendor)(None, None)
|
||||
|
||||
return test
|
||||
return wrapper
|
||||
|
||||
|
||||
class FuncTestMixinTests(FuncTestMixin, SimpleTestCase):
|
||||
@test_mutation()
|
||||
def test_mutated_attribute(func):
|
||||
func.attribute = 'mutated'
|
||||
|
||||
@test_mutation()
|
||||
def test_mutated_expressions(func):
|
||||
func.source_expressions.clear()
|
||||
|
||||
@test_mutation()
|
||||
def test_mutated_expression(func):
|
||||
func.source_expressions[0].name = 'mutated'
|
||||
|
||||
@test_mutation()
|
||||
def test_mutated_expression_deep(func):
|
||||
func.source_expressions[1].value[0] = 'mutated'
|
||||
|
||||
@test_mutation(raises=False)
|
||||
def test_not_mutated(func):
|
||||
pass
|
|
@ -1,8 +1,11 @@
|
|||
import copy
|
||||
import unittest
|
||||
from functools import wraps
|
||||
from unittest import mock
|
||||
|
||||
from django.conf import settings
|
||||
from django.db import DEFAULT_DB_ALIAS, connection
|
||||
from django.db.models.expressions import Func
|
||||
|
||||
|
||||
def skipUnlessGISLookup(*gis_lookups):
|
||||
|
@ -56,3 +59,39 @@ elif spatialite:
|
|||
from django.contrib.gis.db.backends.spatialite.models import SpatialiteSpatialRefSys as SpatialRefSys
|
||||
else:
|
||||
SpatialRefSys = None
|
||||
|
||||
|
||||
class FuncTestMixin:
|
||||
"""Assert that Func expressions aren't mutated during their as_sql()."""
|
||||
def setUp(self):
|
||||
def as_sql_wrapper(original_as_sql):
|
||||
def inner(*args, **kwargs):
|
||||
func = original_as_sql.__self__
|
||||
# Resolve output_field before as_sql() so touching it in
|
||||
# as_sql() won't change __dict__.
|
||||
func.output_field
|
||||
__dict__original = copy.deepcopy(func.__dict__)
|
||||
result = original_as_sql(*args, **kwargs)
|
||||
msg = '%s Func was mutated during compilation.' % func.__class__.__name__
|
||||
self.assertEqual(func.__dict__, __dict__original, msg)
|
||||
return result
|
||||
return inner
|
||||
|
||||
def __getattribute__(self, name):
|
||||
if name != vendor_impl:
|
||||
return __getattribute__original(self, name)
|
||||
try:
|
||||
as_sql = __getattribute__original(self, vendor_impl)
|
||||
except AttributeError:
|
||||
as_sql = __getattribute__original(self, 'as_sql')
|
||||
return as_sql_wrapper(as_sql)
|
||||
|
||||
vendor_impl = 'as_' + connection.vendor
|
||||
__getattribute__original = Func.__getattribute__
|
||||
self.func_patcher = mock.patch.object(Func, '__getattribute__', __getattribute__)
|
||||
self.func_patcher.start()
|
||||
super().setUp()
|
||||
|
||||
def tearDown(self):
|
||||
super().tearDown()
|
||||
self.func_patcher.stop()
|
||||
|
|
Loading…
Reference in New Issue