Added MySQL support to GIS functions

This commit is contained in:
Claude Paroz 2015-01-25 00:03:02 +01:00
parent 44bdbbc316
commit 71e20814fc
6 changed files with 68 additions and 19 deletions

View File

@ -6,6 +6,8 @@ from django.db.backends.mysql.features import \
class DatabaseFeatures(BaseSpatialFeatures, MySQLDatabaseFeatures):
has_spatialrefsys_table = False
supports_add_srs_entry = False
supports_distance_geodetic = False
supports_length_geodetic = False
supports_distances_lookups = False
supports_transform = False
supports_real_shape_operations = False

View File

@ -4,6 +4,7 @@ from django.contrib.gis.db.backends.base.operations import \
from django.contrib.gis.db.backends.utils import SpatialOperator
from django.contrib.gis.db.models import aggregates
from django.db.backends.mysql.operations import DatabaseOperations
from django.utils.functional import cached_property
class MySQLOperations(BaseSpatialOperations, DatabaseOperations):
@ -32,7 +33,28 @@ class MySQLOperations(BaseSpatialOperations, DatabaseOperations):
'within': SpatialOperator(func='MBRWithin'),
}
disallowed_aggregates = (aggregates.Collect, aggregates.Extent, aggregates.Extent3D, aggregates.MakeLine, aggregates.Union)
function_names = {
'Distance': 'ST_Distance',
'Length': 'GLength',
'Union': 'ST_Union',
}
disallowed_aggregates = (
aggregates.Collect, aggregates.Extent, aggregates.Extent3D,
aggregates.MakeLine, aggregates.Union,
)
@cached_property
def unsupported_functions(self):
unsupported = {
'AsGeoJSON', 'AsGML', 'AsKML', 'AsSVG', 'BoundingCircle',
'Difference', 'ForceRHR', 'GeoHash', 'Intersection', 'MemSize',
'Perimeter', 'PointOnSurface', 'Reverse', 'Scale', 'SnapToGrid',
'SymDifference', 'Transform', 'Translate',
}
if self.connection.mysql_version < (5, 6, 1):
unsupported.update({'Distance', 'Union'})
return unsupported
def geo_db_type(self, f):
return f.geom_type

View File

@ -169,7 +169,10 @@ class GeometryField(GeoSelectFormatMixin, Field):
Returns true if this field's SRID corresponds with a coordinate
system that uses non-projected units (e.g., latitude/longitude).
"""
return self.units_name(connection).lower() in self.geodetic_units
units_name = self.units_name(connection)
# Some backends like MySQL cannot determine units name. In that case,
# test if srid is 4326 (WGS84), even if this is over-simplification.
return units_name.lower() in self.geodetic_units if units_name else self.srid == 4326
def get_distance(self, value, lookup_type, connection):
"""

View File

@ -79,6 +79,9 @@ class GeomValue(Value):
self.value = connection.ops.Adapter(self.value)
return super(GeomValue, self).as_sql(compiler, connection)
def as_mysql(self, compiler, connection):
return 'GeomFromText(%%s, %s)' % self.srid, [connection.ops.Adapter(self.value)]
def as_sqlite(self, compiler, connection):
return 'GeomFromText(%%s, %s)' % self.srid, [connection.ops.Adapter(self.value)]
@ -119,8 +122,12 @@ class Area(GeoFunc):
self.output_field = AreaField('sq_m')
elif not self.output_field.geodetic(connection):
# Getting the area units of the geographic field.
self.output_field = AreaField(
AreaMeasure.unit_attname(self.output_field.units_name(connection)))
units = self.output_field.units_name(connection)
if units:
self.output_field = AreaField(
AreaMeasure.unit_attname(self.output_field.units_name(connection)))
else:
self.output_field = FloatField()
else:
# TODO: Do we want to support raw number areas for geodetic fields?
raise NotImplementedError('Area on geodetic coordinate systems not supported.')
@ -198,8 +205,14 @@ class DistanceResultMixin(object):
if geo_field.geodetic(connection):
dist_att = 'm'
else:
dist_att = DistanceMeasure.unit_attname(geo_field.units_name(connection))
return DistanceMeasure(**{dist_att: value})
units = geo_field.units_name(connection)
if units:
dist_att = DistanceMeasure.unit_attname(units)
else:
dist_att = None
if dist_att:
return DistanceMeasure(**{dist_att: value})
return value
class Distance(DistanceResultMixin, GeoFuncWithGeoParam):
@ -263,6 +276,12 @@ class Length(DistanceResultMixin, GeoFunc):
self.spheroid = spheroid
super(Length, self).__init__(expr1, **extra)
def as_sql(self, compiler, connection):
geo_field = GeometryField(srid=self.srid) # Fake field to get SRID info
if geo_field.geodetic(connection) and not connection.features.supports_length_geodetic:
raise NotImplementedError("This backend doesn't support Length on geodetic fields")
return super(Length, self).as_sql(compiler, connection)
def as_postgresql(self, compiler, connection):
geo_field = GeometryField(srid=self.srid) # Fake field to get SRID info
src_field = self.get_source_fields()[0]

View File

@ -438,7 +438,8 @@ class DistanceFunctionsTests(TestCase):
# Tolerance has to be lower for Oracle
tol = 2
for i, z in enumerate(SouthTexasZipcode.objects.annotate(area=Area('poly')).order_by('name')):
self.assertAlmostEqual(area_sq_m[i], z.area.sq_m, tol)
# MySQL is returning a raw float value
self.assertAlmostEqual(area_sq_m[i], z.area.sq_m if hasattr(z.area, 'sq_m') else z.area, tol)
@skipUnlessDBFeature("has_Distance_function")
def test_distance_simple(self):
@ -624,12 +625,12 @@ class DistanceFunctionsTests(TestCase):
# TODO: test with spheroid argument (True and False)
else:
# Does not support geodetic coordinate systems.
with self.assertRaises(ValueError):
Interstate.objects.annotate(length=Length('path'))
with self.assertRaises(NotImplementedError):
list(Interstate.objects.annotate(length=Length('path')))
# Now doing length on a projected coordinate system.
i10 = SouthTexasInterstate.objects.annotate(length=Length('path')).get(name='I-10')
self.assertAlmostEqual(len_m2, i10.length.m, 2)
self.assertAlmostEqual(len_m2, i10.length.m if isinstance(i10.length, D) else i10.length, 2)
self.assertTrue(
SouthTexasInterstate.objects.annotate(length=Length('path')).filter(length__gt=4000).exists()
)
@ -652,7 +653,7 @@ class DistanceFunctionsTests(TestCase):
for city in qs:
self.assertEqual(0, city.perim.m)
@skipUnlessDBFeature("has_Area_function", "has_Distance_function")
@skipUnlessDBFeature("supports_null_geometries", "has_Area_function", "has_Distance_function")
def test_measurement_null_fields(self):
"""
Test the measurement functions on fields with NULL values.

View File

@ -9,7 +9,7 @@ from django.db import connection
from django.test import TestCase, skipUnlessDBFeature
from django.utils import six
from ..utils import oracle, postgis, spatialite
from ..utils import mysql, oracle, postgis, spatialite
if HAS_GEOS:
from django.contrib.gis.geos import LineString, Point, Polygon, fromstr
@ -165,8 +165,8 @@ class GISFunctionsTests(TestCase):
@skipUnlessDBFeature("has_Centroid_function")
def test_centroid(self):
qs = State.objects.exclude(poly__isnull=True).annotate(centroid=functions.Centroid('poly'))
tol = 1.8 if mysql else (0.1 if oracle else 0.00001)
for state in qs:
tol = 0.1 # High tolerance due to oracle
self.assertTrue(state.poly.centroid.equals_exact(state.centroid, tol))
@skipUnlessDBFeature("has_Difference_function")
@ -248,9 +248,9 @@ class GISFunctionsTests(TestCase):
qs = City.objects.filter(point__isnull=False).annotate(num_geom=functions.NumGeometries('point'))
for city in qs:
# Oracle and PostGIS 2.0+ will return 1 for the number of
# geometries on non-collections, whereas PostGIS < 2.0.0
# geometries on non-collections, whereas PostGIS < 2.0.0 and MySQL
# will return None.
if postgis and connection.ops.spatial_version < (2, 0, 0):
if (postgis and connection.ops.spatial_version < (2, 0, 0)) or mysql:
self.assertIsNone(city.num_geom)
else:
self.assertEqual(1, city.num_geom)
@ -261,8 +261,8 @@ class GISFunctionsTests(TestCase):
Track.objects.create(name='Foo', line=LineString(coords))
qs = Track.objects.annotate(num_points=functions.NumPoints('line'))
self.assertEqual(qs.first().num_points, 2)
if spatialite:
# Spatialite can only count points on LineStrings
if spatialite or mysql:
# Spatialite and MySQL can only count points on LineStrings
return
for c in Country.objects.annotate(num_points=functions.NumPoints('mpoly')):
@ -455,5 +455,7 @@ class GISFunctionsTests(TestCase):
geom = Point(-95.363151, 29.763374, srid=4326)
ptown = City.objects.annotate(union=functions.Union('point', geom)).get(name='Dallas')
tol = 0.00001
expected = fromstr('MULTIPOINT(-96.801611 32.782057,-95.363151 29.763374)', srid=4326)
self.assertTrue(expected.equals_exact(ptown.union, tol))
# Undefined ordering
expected1 = fromstr('MULTIPOINT(-96.801611 32.782057,-95.363151 29.763374)', srid=4326)
expected2 = fromstr('MULTIPOINT(-95.363151 29.763374,-96.801611 32.782057)', srid=4326)
self.assertTrue(expected1.equals_exact(ptown.union, tol) or expected2.equals_exact(ptown.union, tol))