Fixed #26112 -- Error when computing aggregate of GIS areas.
Thanks Simon Charette and Claude Paroz for the reviews.
This commit is contained in:
parent
16baec5c8a
commit
a08d2463d2
|
@ -117,24 +117,23 @@ class OracleToleranceMixin(object):
|
|||
|
||||
|
||||
class Area(OracleToleranceMixin, GeoFunc):
|
||||
output_field_class = AreaField
|
||||
arity = 1
|
||||
|
||||
def as_sql(self, compiler, connection):
|
||||
if connection.ops.geography:
|
||||
# Geography fields support area calculation, returns square meters.
|
||||
self.output_field = AreaField('sq_m')
|
||||
elif not self.output_field.geodetic(connection):
|
||||
self.output_field.area_att = 'sq_m'
|
||||
else:
|
||||
# Getting the area units of the geographic field.
|
||||
units = self.output_field.units_name(connection)
|
||||
if units:
|
||||
self.output_field = AreaField(
|
||||
AreaMeasure.unit_attname(self.output_field.units_name(connection))
|
||||
)
|
||||
else:
|
||||
self.output_field = FloatField()
|
||||
else:
|
||||
source_fields = self.get_source_fields()
|
||||
if len(source_fields):
|
||||
source_field = source_fields[0]
|
||||
if source_field.geodetic(connection):
|
||||
# TODO: Do we want to support raw number areas for geodetic fields?
|
||||
raise NotImplementedError('Area on geodetic coordinate systems not supported.')
|
||||
units_name = source_field.units_name(connection)
|
||||
if units_name:
|
||||
self.output_field.area_att = AreaMeasure.unit_attname(units_name)
|
||||
return super(Area, self).as_sql(compiler, connection)
|
||||
|
||||
def as_oracle(self, compiler, connection):
|
||||
|
|
|
@ -21,13 +21,14 @@ class BaseField(object):
|
|||
|
||||
class AreaField(BaseField):
|
||||
"Wrapper for Area values."
|
||||
def __init__(self, area_att):
|
||||
def __init__(self, area_att=None):
|
||||
self.area_att = area_att
|
||||
|
||||
def from_db_value(self, value, expression, connection, context):
|
||||
if connection.features.interprets_empty_strings_as_nulls and value == '':
|
||||
value = None
|
||||
if value is not None:
|
||||
# If the units are known, convert value into area measure.
|
||||
if value is not None and self.area_att:
|
||||
value = Area(**{self.area_att: value})
|
||||
return value
|
||||
|
||||
|
|
|
@ -22,6 +22,10 @@ class Country(NamedModel):
|
|||
mpoly = models.MultiPolygonField() # SRID, by default, is 4326
|
||||
|
||||
|
||||
class CountryWebMercator(NamedModel):
|
||||
mpoly = models.MultiPolygonField(srid=3857)
|
||||
|
||||
|
||||
class City(NamedModel):
|
||||
point = models.PointField()
|
||||
|
||||
|
|
|
@ -5,12 +5,14 @@ from decimal import Decimal
|
|||
|
||||
from django.contrib.gis.db.models import functions
|
||||
from django.contrib.gis.geos import LineString, Point, Polygon, fromstr
|
||||
from django.contrib.gis.measure import Area
|
||||
from django.db import connection
|
||||
from django.db.models import Sum
|
||||
from django.test import TestCase, skipUnlessDBFeature
|
||||
from django.utils import six
|
||||
|
||||
from ..utils import mysql, oracle, postgis, spatialite
|
||||
from .models import City, Country, State, Track
|
||||
from .models import City, Country, CountryWebMercator, State, Track
|
||||
|
||||
|
||||
@skipUnlessDBFeature("gis_enabled")
|
||||
|
@ -231,6 +233,20 @@ class GISFunctionsTests(TestCase):
|
|||
expected = c.mpoly.intersection(geom)
|
||||
self.assertEqual(c.inter, expected)
|
||||
|
||||
@skipUnlessDBFeature("has_Area_function")
|
||||
def test_area_with_regular_aggregate(self):
|
||||
# Create projected country objects, for this test to work on all backends.
|
||||
for c in Country.objects.all():
|
||||
CountryWebMercator.objects.create(name=c.name, mpoly=c.mpoly)
|
||||
# Test in projected coordinate system
|
||||
qs = CountryWebMercator.objects.annotate(area_sum=Sum(functions.Area('mpoly')))
|
||||
for c in qs:
|
||||
result = c.area_sum
|
||||
# If the result is a measure object, get value.
|
||||
if isinstance(result, Area):
|
||||
result = result.sq_m
|
||||
self.assertAlmostEqual((result - c.mpoly.area) / c.mpoly.area, 0)
|
||||
|
||||
@skipUnlessDBFeature("has_MemSize_function")
|
||||
def test_memsize(self):
|
||||
ptown = City.objects.annotate(size=functions.MemSize('point')).get(name='Pueblo')
|
||||
|
|
Loading…
Reference in New Issue