From a08d2463d2674b95f5a995f77cd9596168378a4f Mon Sep 17 00:00:00 2001 From: Daniel Wiesmann Date: Fri, 22 Jan 2016 11:03:05 +0000 Subject: [PATCH] Fixed #26112 -- Error when computing aggregate of GIS areas. Thanks Simon Charette and Claude Paroz for the reviews. --- django/contrib/gis/db/models/functions.py | 25 +++++++++---------- .../contrib/gis/db/models/sql/conversion.py | 5 ++-- tests/gis_tests/geoapp/models.py | 4 +++ tests/gis_tests/geoapp/test_functions.py | 18 ++++++++++++- 4 files changed, 36 insertions(+), 16 deletions(-) diff --git a/django/contrib/gis/db/models/functions.py b/django/contrib/gis/db/models/functions.py index 17847bdc5c..4f76c155f2 100644 --- a/django/contrib/gis/db/models/functions.py +++ b/django/contrib/gis/db/models/functions.py @@ -117,24 +117,23 @@ class OracleToleranceMixin(object): class Area(OracleToleranceMixin, GeoFunc): + output_field_class = AreaField arity = 1 def as_sql(self, compiler, connection): if connection.ops.geography: - # Geography fields support area calculation, returns square meters. - self.output_field = AreaField('sq_m') - elif not self.output_field.geodetic(connection): - # Getting the area units of the geographic field. - units = self.output_field.units_name(connection) - if units: - self.output_field = AreaField( - AreaMeasure.unit_attname(self.output_field.units_name(connection)) - ) - else: - self.output_field = FloatField() + self.output_field.area_att = 'sq_m' else: - # TODO: Do we want to support raw number areas for geodetic fields? - raise NotImplementedError('Area on geodetic coordinate systems not supported.') + # Getting the area units of the geographic field. + source_fields = self.get_source_fields() + if len(source_fields): + source_field = source_fields[0] + if source_field.geodetic(connection): + # TODO: Do we want to support raw number areas for geodetic fields? + raise NotImplementedError('Area on geodetic coordinate systems not supported.') + units_name = source_field.units_name(connection) + if units_name: + self.output_field.area_att = AreaMeasure.unit_attname(units_name) return super(Area, self).as_sql(compiler, connection) def as_oracle(self, compiler, connection): diff --git a/django/contrib/gis/db/models/sql/conversion.py b/django/contrib/gis/db/models/sql/conversion.py index dbbdb8b338..afc1053790 100644 --- a/django/contrib/gis/db/models/sql/conversion.py +++ b/django/contrib/gis/db/models/sql/conversion.py @@ -21,13 +21,14 @@ class BaseField(object): class AreaField(BaseField): "Wrapper for Area values." - def __init__(self, area_att): + def __init__(self, area_att=None): self.area_att = area_att def from_db_value(self, value, expression, connection, context): if connection.features.interprets_empty_strings_as_nulls and value == '': value = None - if value is not None: + # If the units are known, convert value into area measure. + if value is not None and self.area_att: value = Area(**{self.area_att: value}) return value diff --git a/tests/gis_tests/geoapp/models.py b/tests/gis_tests/geoapp/models.py index d6ca4f5010..42a93356b6 100644 --- a/tests/gis_tests/geoapp/models.py +++ b/tests/gis_tests/geoapp/models.py @@ -22,6 +22,10 @@ class Country(NamedModel): mpoly = models.MultiPolygonField() # SRID, by default, is 4326 +class CountryWebMercator(NamedModel): + mpoly = models.MultiPolygonField(srid=3857) + + class City(NamedModel): point = models.PointField() diff --git a/tests/gis_tests/geoapp/test_functions.py b/tests/gis_tests/geoapp/test_functions.py index 663386a3b2..fb88513eab 100644 --- a/tests/gis_tests/geoapp/test_functions.py +++ b/tests/gis_tests/geoapp/test_functions.py @@ -5,12 +5,14 @@ from decimal import Decimal from django.contrib.gis.db.models import functions from django.contrib.gis.geos import LineString, Point, Polygon, fromstr +from django.contrib.gis.measure import Area from django.db import connection +from django.db.models import Sum from django.test import TestCase, skipUnlessDBFeature from django.utils import six from ..utils import mysql, oracle, postgis, spatialite -from .models import City, Country, State, Track +from .models import City, Country, CountryWebMercator, State, Track @skipUnlessDBFeature("gis_enabled") @@ -231,6 +233,20 @@ class GISFunctionsTests(TestCase): expected = c.mpoly.intersection(geom) self.assertEqual(c.inter, expected) + @skipUnlessDBFeature("has_Area_function") + def test_area_with_regular_aggregate(self): + # Create projected country objects, for this test to work on all backends. + for c in Country.objects.all(): + CountryWebMercator.objects.create(name=c.name, mpoly=c.mpoly) + # Test in projected coordinate system + qs = CountryWebMercator.objects.annotate(area_sum=Sum(functions.Area('mpoly'))) + for c in qs: + result = c.area_sum + # If the result is a measure object, get value. + if isinstance(result, Area): + result = result.sq_m + self.assertAlmostEqual((result - c.mpoly.area) / c.mpoly.area, 0) + @skipUnlessDBFeature("has_MemSize_function") def test_memsize(self): ptown = City.objects.annotate(size=functions.MemSize('point')).get(name='Pueblo')