Fixed #28353 -- Fixed some GIS functions when queryset is evaluated more than once.

Reverted test for refs #27603 in favor of using FuncTestMixin.
This commit is contained in:
Sergey Fedoseev 2017-09-11 20:56:39 +05:00 committed by Tim Graham
parent 99e65d6488
commit 3905cfa1a5
12 changed files with 176 additions and 86 deletions

View File

@ -1,5 +1,8 @@
from django.contrib.gis.db.models import GeometryField
from django.contrib.gis.db.models.functions import Distance
from django.contrib.gis.measure import (
Area as AreaMeasure, Distance as DistanceMeasure,
)
from django.utils.functional import cached_property
@ -135,3 +138,24 @@ class BaseSpatialOperations:
'Subclasses of BaseSpatialOperations must provide a '
'get_geometry_converter() method.'
)
def get_area_att_for_field(self, field):
if field.geodetic(self.connection):
if self.connection.features.supports_area_geodetic:
return 'sq_m'
raise NotImplementedError('Area on geodetic coordinate systems not supported.')
else:
units_name = field.units_name(self.connection)
if units_name:
return AreaMeasure.unit_attname(units_name)
def get_distance_att_for_field(self, field):
dist_att = None
if field.geodetic(self.connection):
if self.connection.features.supports_distance_geodetic:
dist_att = 'm'
else:
units = field.units_name(self.connection)
if units:
dist_att = DistanceMeasure.unit_attname(units)
return dist_att

View File

@ -212,3 +212,6 @@ class OracleOperations(BaseSpatialOperations, DatabaseOperations):
geom.srid = srid
return geom
return converter
def get_area_att_for_field(self, field):
return 'sq_m'

View File

@ -389,3 +389,6 @@ class PostGISOperations(BaseSpatialOperations, DatabaseOperations):
def converter(value, expression, connection):
return None if value is None else GEOSGeometryBase(read(value), geom_class)
return converter
def get_area_att_for_field(self, field):
return 'sq_m'

View File

@ -3,9 +3,6 @@ from decimal import Decimal
from django.contrib.gis.db.models.fields import BaseSpatialField, GeometryField
from django.contrib.gis.db.models.sql import AreaField, DistanceField
from django.contrib.gis.geometry.backend import Geometry
from django.contrib.gis.measure import (
Area as AreaMeasure, Distance as DistanceMeasure,
)
from django.core.exceptions import FieldError
from django.db.models import (
BooleanField, FloatField, IntegerField, TextField, Transform,
@ -121,29 +118,16 @@ class OracleToleranceMixin:
class Area(OracleToleranceMixin, GeoFunc):
output_field_class = AreaField
arity = 1
def as_sql(self, compiler, connection, **extra_context):
if connection.ops.geography:
self.output_field.area_att = 'sq_m'
else:
# Getting the area units of the geographic field.
if self.geo_field.geodetic(connection):
if connection.features.supports_area_geodetic:
self.output_field.area_att = 'sq_m'
else:
# TODO: Do we want to support raw number areas for geodetic fields?
raise NotImplementedError('Area on geodetic coordinate systems not supported.')
else:
units_name = self.geo_field.units_name(connection)
if units_name:
self.output_field.area_att = AreaMeasure.unit_attname(units_name)
return super().as_sql(compiler, connection, **extra_context)
@cached_property
def output_field(self):
return AreaField(self.geo_field)
def as_oracle(self, compiler, connection):
self.output_field = AreaField('sq_m') # Oracle returns area in units of meters.
return super().as_oracle(compiler, connection)
def as_sql(self, compiler, connection, **extra_context):
if not connection.features.supports_area_geodetic and self.geo_field.geodetic(connection):
raise NotImplementedError('Area on geodetic coordinate systems not supported.')
return super().as_sql(compiler, connection, **extra_context)
def as_sqlite(self, compiler, connection, **extra_context):
if self.geo_field.geodetic(connection):
@ -237,27 +221,13 @@ class Difference(OracleToleranceMixin, GeomOutputGeoFunc):
class DistanceResultMixin:
output_field_class = DistanceField
@cached_property
def output_field(self):
return DistanceField(self.geo_field)
def source_is_geography(self):
return self.geo_field.geography and self.geo_field.srid == 4326
def distance_att(self, connection):
dist_att = None
if self.geo_field.geodetic(connection):
if connection.features.supports_distance_geodetic:
dist_att = 'm'
else:
units = self.geo_field.units_name(connection)
if units:
dist_att = DistanceMeasure.unit_attname(units)
return dist_att
def as_sql(self, compiler, connection, **extra_context):
clone = self.copy()
clone.output_field.distance_att = self.distance_att(connection)
return super(DistanceResultMixin, clone).as_sql(compiler, connection, **extra_context)
class Distance(DistanceResultMixin, OracleToleranceMixin, GeoFunc):
geom_param_pos = (0, 1)
@ -266,19 +236,19 @@ class Distance(DistanceResultMixin, OracleToleranceMixin, GeoFunc):
def __init__(self, expr1, expr2, spheroid=None, **extra):
expressions = [expr1, expr2]
if spheroid is not None:
self.spheroid = spheroid
expressions += (self._handle_param(spheroid, 'spheroid', bool),)
self.spheroid = self._handle_param(spheroid, 'spheroid', bool)
super().__init__(*expressions, **extra)
def as_postgresql(self, compiler, connection):
clone = self.copy()
function = None
expr2 = self.source_expressions[1]
expr2 = clone.source_expressions[1]
geography = self.source_is_geography()
if expr2.output_field.geography != geography:
if isinstance(expr2, Value):
expr2.output_field.geography = geography
else:
self.source_expressions[1] = Cast(
clone.source_expressions[1] = Cast(
expr2,
GeometryField(srid=expr2.output_field.srid, geography=geography),
)
@ -289,19 +259,12 @@ class Distance(DistanceResultMixin, OracleToleranceMixin, GeoFunc):
# DistanceSpheroid is more accurate and resource intensive than DistanceSphere
function = connection.ops.spatial_function_name('DistanceSpheroid')
# Replace boolean param by the real spheroid of the base field
self.source_expressions[2] = Value(self.geo_field.spheroid(connection))
clone.source_expressions.append(Value(self.geo_field.spheroid(connection)))
else:
function = connection.ops.spatial_function_name('DistanceSphere')
return super().as_sql(compiler, connection, function=function)
def as_oracle(self, compiler, connection):
if self.spheroid:
self.source_expressions.pop(2)
return super().as_oracle(compiler, connection)
return super(Distance, clone).as_sql(compiler, connection, function=function)
def as_sqlite(self, compiler, connection, **extra_context):
if self.spheroid:
self.source_expressions.pop(2)
if self.geo_field.geodetic(connection):
# SpatiaLite returns NULL instead of zero on geodetic coordinates
extra_context['template'] = 'COALESCE(%(function)s(%(expressions)s, %(spheroid)s), 0)'
@ -360,18 +323,19 @@ class Length(DistanceResultMixin, OracleToleranceMixin, GeoFunc):
return super().as_sql(compiler, connection, **extra_context)
def as_postgresql(self, compiler, connection):
clone = self.copy()
function = None
if self.source_is_geography():
self.source_expressions.append(Value(self.spheroid))
clone.source_expressions.append(Value(self.spheroid))
elif self.geo_field.geodetic(connection):
# Geometry fields with geodetic (lon/lat) coordinates need length_spheroid
function = connection.ops.spatial_function_name('LengthSpheroid')
self.source_expressions.append(Value(self.geo_field.spheroid(connection)))
clone.source_expressions.append(Value(self.geo_field.spheroid(connection)))
else:
dim = min(f.dim for f in self.get_source_fields() if f)
if dim > 2:
function = connection.ops.length3d
return super().as_sql(compiler, connection, function=function)
return super(Length, clone).as_sql(compiler, connection, function=function)
def as_sqlite(self, compiler, connection):
function = None
@ -482,10 +446,11 @@ class Transform(GeomOutputGeoFunc):
class Translate(Scale):
def as_sqlite(self, compiler, connection):
clone = self.copy()
if len(self.source_expressions) < 4:
# Always provide the z parameter for ST_Translate
self.source_expressions.append(Value(0))
return super().as_sqlite(compiler, connection)
clone.source_expressions.append(Value(0))
return super(Translate, clone).as_sqlite(compiler, connection)
class Union(OracleToleranceMixin, GeomOutputGeoFunc):

View File

@ -10,9 +10,9 @@ from django.db import models
class AreaField(models.FloatField):
"Wrapper for Area values."
def __init__(self, area_att=None):
def __init__(self, geo_field):
super().__init__()
self.area_att = area_att
self.geo_field = geo_field
def get_prep_value(self, value):
if not isinstance(value, Area):
@ -20,19 +20,21 @@ class AreaField(models.FloatField):
return value
def get_db_prep_value(self, value, connection, prepared=False):
if value is None or not self.area_att:
return value
return getattr(value, self.area_att)
if value is None:
return
area_att = connection.ops.get_area_att_for_field(self.geo_field)
return getattr(value, area_att) if area_att else value
def from_db_value(self, value, expression, connection):
if value is None:
return
# If the database returns a Decimal, convert it to a float as expected
# by the Python geometric objects.
if isinstance(value, Decimal):
value = float(value)
# If the units are known, convert value into area measure.
if value is not None and self.area_att:
value = Area(**{self.area_att: value})
return value
area_att = connection.ops.get_area_att_for_field(self.geo_field)
return Area(**{area_att: value}) if area_att else value
def get_internal_type(self):
return 'AreaField'
@ -40,9 +42,9 @@ class AreaField(models.FloatField):
class DistanceField(models.FloatField):
"Wrapper for Distance values."
def __init__(self, distance_att=None):
def __init__(self, geo_field):
super().__init__()
self.distance_att = distance_att
self.geo_field = geo_field
def get_prep_value(self, value):
if isinstance(value, Distance):
@ -52,14 +54,16 @@ class DistanceField(models.FloatField):
def get_db_prep_value(self, value, connection, prepared=False):
if not isinstance(value, Distance):
return value
if not self.distance_att:
distance_att = connection.ops.get_distance_att_for_field(self.geo_field)
if not distance_att:
raise ValueError('Distance measure is supplied, but units are unknown for result.')
return getattr(value, self.distance_att)
return getattr(value, distance_att)
def from_db_value(self, value, expression, connection):
if value is None or not self.distance_att:
return value
return Distance(**{self.distance_att: value})
if value is None:
return
distance_att = connection.ops.get_distance_att_for_field(self.geo_field)
return Distance(**{distance_att: value}) if distance_att else value
def get_internal_type(self):
return 'DistanceField'

View File

@ -9,7 +9,9 @@ from django.db import connection
from django.db.models import F, Q
from django.test import TestCase, skipIfDBFeature, skipUnlessDBFeature
from ..utils import mysql, no_oracle, oracle, postgis, spatialite
from ..utils import (
FuncTestMixin, mysql, no_oracle, oracle, postgis, spatialite,
)
from .models import (
AustraliaCity, CensusZipcode, Interstate, SouthTexasCity, SouthTexasCityFt,
SouthTexasInterstate, SouthTexasZipcode,
@ -262,7 +264,7 @@ Perimeter(geom1) | OK | :-(
''' # NOQA
class DistanceFunctionsTests(TestCase):
class DistanceFunctionsTests(FuncTestMixin, TestCase):
fixtures = ['initial']
@skipUnlessDBFeature("has_Area_function")

View File

@ -8,6 +8,7 @@ from django.contrib.gis.db.models.functions import (
from django.contrib.gis.geos import GEOSGeometry, LineString, Point, Polygon
from django.test import TestCase, skipUnlessDBFeature
from ..utils import FuncTestMixin
from .models import (
City3D, Interstate2D, Interstate3D, InterstateProj2D, InterstateProj3D,
MultiPoint3D, Point2D, Point3D, Polygon2D, Polygon3D,
@ -205,7 +206,7 @@ class Geo3DTest(Geo3DLoadingHelper, TestCase):
@skipUnlessDBFeature("supports_3d_functions")
class Geo3DFunctionsTests(Geo3DLoadingHelper, TestCase):
class Geo3DFunctionsTests(FuncTestMixin, Geo3DLoadingHelper, TestCase):
def test_kml(self):
"""
Test KML() function with Z values.

View File

@ -12,11 +12,11 @@ from django.db import connection
from django.db.models import Sum
from django.test import TestCase, skipUnlessDBFeature
from ..utils import mysql, oracle, postgis, spatialite
from ..utils import FuncTestMixin, mysql, oracle, postgis, spatialite
from .models import City, Country, CountryWebMercator, State, Track
class GISFunctionsTests(TestCase):
class GISFunctionsTests(FuncTestMixin, TestCase):
"""
Testing functions from django/contrib/gis/db/models/functions.py.
Area/Distance/Length/Perimeter are tested in distapp/tests.
@ -127,11 +127,8 @@ class GISFunctionsTests(TestCase):
City.objects.annotate(kml=functions.AsKML('name'))
# Ensuring the KML is as expected.
qs = City.objects.annotate(kml=functions.AsKML('point', precision=9))
ptown = qs.get(name='Pueblo')
ptown = City.objects.annotate(kml=functions.AsKML('point', precision=9)).get(name='Pueblo')
self.assertEqual('<Point><coordinates>-104.609252,38.255001</coordinates></Point>', ptown.kml)
# Same result if the queryset is evaluated again.
self.assertEqual(qs.get(name='Pueblo').kml, ptown.kml)
@skipUnlessDBFeature("has_AsSVG_function")
def test_assvg(self):

View File

@ -11,7 +11,7 @@ from django.db import connection
from django.db.models.functions import Cast
from django.test import TestCase, skipIfDBFeature, skipUnlessDBFeature
from ..utils import oracle, postgis, spatialite
from ..utils import FuncTestMixin, oracle, postgis, spatialite
from .models import City, County, Zipcode
@ -86,7 +86,7 @@ class GeographyTest(TestCase):
self.assertEqual(state, c.state)
class GeographyFunctionTests(TestCase):
class GeographyFunctionTests(FuncTestMixin, TestCase):
fixtures = ['initial']
@skipUnlessDBFeature("supports_extent_aggr")

View File

@ -7,9 +7,9 @@ from django.test import SimpleTestCase
class FieldsTests(SimpleTestCase):
def test_area_field_deepcopy(self):
field = AreaField()
field = AreaField(None)
self.assertEqual(copy.deepcopy(field), field)
def test_distance_field_deepcopy(self):
field = DistanceField()
field = DistanceField(None)
self.assertEqual(copy.deepcopy(field), field)

View File

@ -0,0 +1,52 @@
from django.db import connection, models
from django.db.models.expressions import Func
from django.test import SimpleTestCase
from .utils import FuncTestMixin
def test_mutation(raises=True):
def wrapper(mutation_func):
def test(test_case_instance, *args, **kwargs):
class TestFunc(Func):
output_field = models.IntegerField()
def __init__(self):
self.attribute = 'initial'
super().__init__('initial', ['initial'])
def as_sql(self, *args, **kwargs):
mutation_func(self)
return '', ()
if raises:
msg = 'TestFunc Func was mutated during compilation.'
with test_case_instance.assertRaisesMessage(AssertionError, msg):
getattr(TestFunc(), 'as_' + connection.vendor)(None, None)
else:
getattr(TestFunc(), 'as_' + connection.vendor)(None, None)
return test
return wrapper
class FuncTestMixinTests(FuncTestMixin, SimpleTestCase):
@test_mutation()
def test_mutated_attribute(func):
func.attribute = 'mutated'
@test_mutation()
def test_mutated_expressions(func):
func.source_expressions.clear()
@test_mutation()
def test_mutated_expression(func):
func.source_expressions[0].name = 'mutated'
@test_mutation()
def test_mutated_expression_deep(func):
func.source_expressions[1].value[0] = 'mutated'
@test_mutation(raises=False)
def test_not_mutated(func):
pass

View File

@ -1,8 +1,11 @@
import copy
import unittest
from functools import wraps
from unittest import mock
from django.conf import settings
from django.db import DEFAULT_DB_ALIAS, connection
from django.db.models.expressions import Func
def skipUnlessGISLookup(*gis_lookups):
@ -56,3 +59,39 @@ elif spatialite:
from django.contrib.gis.db.backends.spatialite.models import SpatialiteSpatialRefSys as SpatialRefSys
else:
SpatialRefSys = None
class FuncTestMixin:
"""Assert that Func expressions aren't mutated during their as_sql()."""
def setUp(self):
def as_sql_wrapper(original_as_sql):
def inner(*args, **kwargs):
func = original_as_sql.__self__
# Resolve output_field before as_sql() so touching it in
# as_sql() won't change __dict__.
func.output_field
__dict__original = copy.deepcopy(func.__dict__)
result = original_as_sql(*args, **kwargs)
msg = '%s Func was mutated during compilation.' % func.__class__.__name__
self.assertEqual(func.__dict__, __dict__original, msg)
return result
return inner
def __getattribute__(self, name):
if name != vendor_impl:
return __getattribute__original(self, name)
try:
as_sql = __getattribute__original(self, vendor_impl)
except AttributeError:
as_sql = __getattribute__original(self, 'as_sql')
return as_sql_wrapper(as_sql)
vendor_impl = 'as_' + connection.vendor
__getattribute__original = Func.__getattribute__
self.func_patcher = mock.patch.object(Func, '__getattribute__', __getattribute__)
self.func_patcher.start()
super().setUp()
def tearDown(self):
super().tearDown()
self.func_patcher.stop()