Fixed #26112 -- Error when computing aggregate of GIS areas.
Thanks Simon Charette and Claude Paroz for the reviews.
This commit is contained in:
parent
16baec5c8a
commit
a08d2463d2
|
@ -117,24 +117,23 @@ class OracleToleranceMixin(object):
|
||||||
|
|
||||||
|
|
||||||
class Area(OracleToleranceMixin, GeoFunc):
|
class Area(OracleToleranceMixin, GeoFunc):
|
||||||
|
output_field_class = AreaField
|
||||||
arity = 1
|
arity = 1
|
||||||
|
|
||||||
def as_sql(self, compiler, connection):
|
def as_sql(self, compiler, connection):
|
||||||
if connection.ops.geography:
|
if connection.ops.geography:
|
||||||
# Geography fields support area calculation, returns square meters.
|
self.output_field.area_att = 'sq_m'
|
||||||
self.output_field = AreaField('sq_m')
|
else:
|
||||||
elif not self.output_field.geodetic(connection):
|
|
||||||
# Getting the area units of the geographic field.
|
# Getting the area units of the geographic field.
|
||||||
units = self.output_field.units_name(connection)
|
source_fields = self.get_source_fields()
|
||||||
if units:
|
if len(source_fields):
|
||||||
self.output_field = AreaField(
|
source_field = source_fields[0]
|
||||||
AreaMeasure.unit_attname(self.output_field.units_name(connection))
|
if source_field.geodetic(connection):
|
||||||
)
|
|
||||||
else:
|
|
||||||
self.output_field = FloatField()
|
|
||||||
else:
|
|
||||||
# TODO: Do we want to support raw number areas for geodetic fields?
|
# TODO: Do we want to support raw number areas for geodetic fields?
|
||||||
raise NotImplementedError('Area on geodetic coordinate systems not supported.')
|
raise NotImplementedError('Area on geodetic coordinate systems not supported.')
|
||||||
|
units_name = source_field.units_name(connection)
|
||||||
|
if units_name:
|
||||||
|
self.output_field.area_att = AreaMeasure.unit_attname(units_name)
|
||||||
return super(Area, self).as_sql(compiler, connection)
|
return super(Area, self).as_sql(compiler, connection)
|
||||||
|
|
||||||
def as_oracle(self, compiler, connection):
|
def as_oracle(self, compiler, connection):
|
||||||
|
|
|
@ -21,13 +21,14 @@ class BaseField(object):
|
||||||
|
|
||||||
class AreaField(BaseField):
|
class AreaField(BaseField):
|
||||||
"Wrapper for Area values."
|
"Wrapper for Area values."
|
||||||
def __init__(self, area_att):
|
def __init__(self, area_att=None):
|
||||||
self.area_att = area_att
|
self.area_att = area_att
|
||||||
|
|
||||||
def from_db_value(self, value, expression, connection, context):
|
def from_db_value(self, value, expression, connection, context):
|
||||||
if connection.features.interprets_empty_strings_as_nulls and value == '':
|
if connection.features.interprets_empty_strings_as_nulls and value == '':
|
||||||
value = None
|
value = None
|
||||||
if value is not None:
|
# If the units are known, convert value into area measure.
|
||||||
|
if value is not None and self.area_att:
|
||||||
value = Area(**{self.area_att: value})
|
value = Area(**{self.area_att: value})
|
||||||
return value
|
return value
|
||||||
|
|
||||||
|
|
|
@ -22,6 +22,10 @@ class Country(NamedModel):
|
||||||
mpoly = models.MultiPolygonField() # SRID, by default, is 4326
|
mpoly = models.MultiPolygonField() # SRID, by default, is 4326
|
||||||
|
|
||||||
|
|
||||||
|
class CountryWebMercator(NamedModel):
|
||||||
|
mpoly = models.MultiPolygonField(srid=3857)
|
||||||
|
|
||||||
|
|
||||||
class City(NamedModel):
|
class City(NamedModel):
|
||||||
point = models.PointField()
|
point = models.PointField()
|
||||||
|
|
||||||
|
|
|
@ -5,12 +5,14 @@ from decimal import Decimal
|
||||||
|
|
||||||
from django.contrib.gis.db.models import functions
|
from django.contrib.gis.db.models import functions
|
||||||
from django.contrib.gis.geos import LineString, Point, Polygon, fromstr
|
from django.contrib.gis.geos import LineString, Point, Polygon, fromstr
|
||||||
|
from django.contrib.gis.measure import Area
|
||||||
from django.db import connection
|
from django.db import connection
|
||||||
|
from django.db.models import Sum
|
||||||
from django.test import TestCase, skipUnlessDBFeature
|
from django.test import TestCase, skipUnlessDBFeature
|
||||||
from django.utils import six
|
from django.utils import six
|
||||||
|
|
||||||
from ..utils import mysql, oracle, postgis, spatialite
|
from ..utils import mysql, oracle, postgis, spatialite
|
||||||
from .models import City, Country, State, Track
|
from .models import City, Country, CountryWebMercator, State, Track
|
||||||
|
|
||||||
|
|
||||||
@skipUnlessDBFeature("gis_enabled")
|
@skipUnlessDBFeature("gis_enabled")
|
||||||
|
@ -231,6 +233,20 @@ class GISFunctionsTests(TestCase):
|
||||||
expected = c.mpoly.intersection(geom)
|
expected = c.mpoly.intersection(geom)
|
||||||
self.assertEqual(c.inter, expected)
|
self.assertEqual(c.inter, expected)
|
||||||
|
|
||||||
|
@skipUnlessDBFeature("has_Area_function")
|
||||||
|
def test_area_with_regular_aggregate(self):
|
||||||
|
# Create projected country objects, for this test to work on all backends.
|
||||||
|
for c in Country.objects.all():
|
||||||
|
CountryWebMercator.objects.create(name=c.name, mpoly=c.mpoly)
|
||||||
|
# Test in projected coordinate system
|
||||||
|
qs = CountryWebMercator.objects.annotate(area_sum=Sum(functions.Area('mpoly')))
|
||||||
|
for c in qs:
|
||||||
|
result = c.area_sum
|
||||||
|
# If the result is a measure object, get value.
|
||||||
|
if isinstance(result, Area):
|
||||||
|
result = result.sq_m
|
||||||
|
self.assertAlmostEqual((result - c.mpoly.area) / c.mpoly.area, 0)
|
||||||
|
|
||||||
@skipUnlessDBFeature("has_MemSize_function")
|
@skipUnlessDBFeature("has_MemSize_function")
|
||||||
def test_memsize(self):
|
def test_memsize(self):
|
||||||
ptown = City.objects.annotate(size=functions.MemSize('point')).get(name='Pueblo')
|
ptown = City.objects.annotate(size=functions.MemSize('point')).get(name='Pueblo')
|
||||||
|
|
Loading…
Reference in New Issue