Fixed #26112 -- Error when computing aggregate of GIS areas.

Thanks Simon Charette and Claude Paroz for the reviews.
This commit is contained in:
Daniel Wiesmann 2016-01-22 11:03:05 +00:00 committed by Claude Paroz
parent 16baec5c8a
commit a08d2463d2
4 changed files with 36 additions and 16 deletions

View File

@ -117,24 +117,23 @@ class OracleToleranceMixin(object):
class Area(OracleToleranceMixin, GeoFunc): class Area(OracleToleranceMixin, GeoFunc):
output_field_class = AreaField
arity = 1 arity = 1
def as_sql(self, compiler, connection): def as_sql(self, compiler, connection):
if connection.ops.geography: if connection.ops.geography:
# Geography fields support area calculation, returns square meters. self.output_field.area_att = 'sq_m'
self.output_field = AreaField('sq_m')
elif not self.output_field.geodetic(connection):
# Getting the area units of the geographic field.
units = self.output_field.units_name(connection)
if units:
self.output_field = AreaField(
AreaMeasure.unit_attname(self.output_field.units_name(connection))
)
else:
self.output_field = FloatField()
else: else:
# TODO: Do we want to support raw number areas for geodetic fields? # Getting the area units of the geographic field.
raise NotImplementedError('Area on geodetic coordinate systems not supported.') source_fields = self.get_source_fields()
if len(source_fields):
source_field = source_fields[0]
if source_field.geodetic(connection):
# TODO: Do we want to support raw number areas for geodetic fields?
raise NotImplementedError('Area on geodetic coordinate systems not supported.')
units_name = source_field.units_name(connection)
if units_name:
self.output_field.area_att = AreaMeasure.unit_attname(units_name)
return super(Area, self).as_sql(compiler, connection) return super(Area, self).as_sql(compiler, connection)
def as_oracle(self, compiler, connection): def as_oracle(self, compiler, connection):

View File

@ -21,13 +21,14 @@ class BaseField(object):
class AreaField(BaseField): class AreaField(BaseField):
"Wrapper for Area values." "Wrapper for Area values."
def __init__(self, area_att): def __init__(self, area_att=None):
self.area_att = area_att self.area_att = area_att
def from_db_value(self, value, expression, connection, context): def from_db_value(self, value, expression, connection, context):
if connection.features.interprets_empty_strings_as_nulls and value == '': if connection.features.interprets_empty_strings_as_nulls and value == '':
value = None value = None
if value is not None: # If the units are known, convert value into area measure.
if value is not None and self.area_att:
value = Area(**{self.area_att: value}) value = Area(**{self.area_att: value})
return value return value

View File

@ -22,6 +22,10 @@ class Country(NamedModel):
mpoly = models.MultiPolygonField() # SRID, by default, is 4326 mpoly = models.MultiPolygonField() # SRID, by default, is 4326
class CountryWebMercator(NamedModel):
mpoly = models.MultiPolygonField(srid=3857)
class City(NamedModel): class City(NamedModel):
point = models.PointField() point = models.PointField()

View File

@ -5,12 +5,14 @@ from decimal import Decimal
from django.contrib.gis.db.models import functions from django.contrib.gis.db.models import functions
from django.contrib.gis.geos import LineString, Point, Polygon, fromstr from django.contrib.gis.geos import LineString, Point, Polygon, fromstr
from django.contrib.gis.measure import Area
from django.db import connection from django.db import connection
from django.db.models import Sum
from django.test import TestCase, skipUnlessDBFeature from django.test import TestCase, skipUnlessDBFeature
from django.utils import six from django.utils import six
from ..utils import mysql, oracle, postgis, spatialite from ..utils import mysql, oracle, postgis, spatialite
from .models import City, Country, State, Track from .models import City, Country, CountryWebMercator, State, Track
@skipUnlessDBFeature("gis_enabled") @skipUnlessDBFeature("gis_enabled")
@ -231,6 +233,20 @@ class GISFunctionsTests(TestCase):
expected = c.mpoly.intersection(geom) expected = c.mpoly.intersection(geom)
self.assertEqual(c.inter, expected) self.assertEqual(c.inter, expected)
@skipUnlessDBFeature("has_Area_function")
def test_area_with_regular_aggregate(self):
# Create projected country objects, for this test to work on all backends.
for c in Country.objects.all():
CountryWebMercator.objects.create(name=c.name, mpoly=c.mpoly)
# Test in projected coordinate system
qs = CountryWebMercator.objects.annotate(area_sum=Sum(functions.Area('mpoly')))
for c in qs:
result = c.area_sum
# If the result is a measure object, get value.
if isinstance(result, Area):
result = result.sq_m
self.assertAlmostEqual((result - c.mpoly.area) / c.mpoly.area, 0)
@skipUnlessDBFeature("has_MemSize_function") @skipUnlessDBFeature("has_MemSize_function")
def test_memsize(self): def test_memsize(self):
ptown = City.objects.annotate(size=functions.MemSize('point')).get(name='Pueblo') ptown = City.objects.annotate(size=functions.MemSize('point')).get(name='Pueblo')