Added MySQL support to GIS functions

This commit is contained in:
Claude Paroz 2015-01-25 00:03:02 +01:00
parent 44bdbbc316
commit 71e20814fc
6 changed files with 68 additions and 19 deletions

View File

@ -6,6 +6,8 @@ from django.db.backends.mysql.features import \
class DatabaseFeatures(BaseSpatialFeatures, MySQLDatabaseFeatures): class DatabaseFeatures(BaseSpatialFeatures, MySQLDatabaseFeatures):
has_spatialrefsys_table = False has_spatialrefsys_table = False
supports_add_srs_entry = False supports_add_srs_entry = False
supports_distance_geodetic = False
supports_length_geodetic = False
supports_distances_lookups = False supports_distances_lookups = False
supports_transform = False supports_transform = False
supports_real_shape_operations = False supports_real_shape_operations = False

View File

@ -4,6 +4,7 @@ from django.contrib.gis.db.backends.base.operations import \
from django.contrib.gis.db.backends.utils import SpatialOperator from django.contrib.gis.db.backends.utils import SpatialOperator
from django.contrib.gis.db.models import aggregates from django.contrib.gis.db.models import aggregates
from django.db.backends.mysql.operations import DatabaseOperations from django.db.backends.mysql.operations import DatabaseOperations
from django.utils.functional import cached_property
class MySQLOperations(BaseSpatialOperations, DatabaseOperations): class MySQLOperations(BaseSpatialOperations, DatabaseOperations):
@ -32,7 +33,28 @@ class MySQLOperations(BaseSpatialOperations, DatabaseOperations):
'within': SpatialOperator(func='MBRWithin'), 'within': SpatialOperator(func='MBRWithin'),
} }
disallowed_aggregates = (aggregates.Collect, aggregates.Extent, aggregates.Extent3D, aggregates.MakeLine, aggregates.Union) function_names = {
'Distance': 'ST_Distance',
'Length': 'GLength',
'Union': 'ST_Union',
}
disallowed_aggregates = (
aggregates.Collect, aggregates.Extent, aggregates.Extent3D,
aggregates.MakeLine, aggregates.Union,
)
@cached_property
def unsupported_functions(self):
unsupported = {
'AsGeoJSON', 'AsGML', 'AsKML', 'AsSVG', 'BoundingCircle',
'Difference', 'ForceRHR', 'GeoHash', 'Intersection', 'MemSize',
'Perimeter', 'PointOnSurface', 'Reverse', 'Scale', 'SnapToGrid',
'SymDifference', 'Transform', 'Translate',
}
if self.connection.mysql_version < (5, 6, 1):
unsupported.update({'Distance', 'Union'})
return unsupported
def geo_db_type(self, f): def geo_db_type(self, f):
return f.geom_type return f.geom_type

View File

@ -169,7 +169,10 @@ class GeometryField(GeoSelectFormatMixin, Field):
Returns true if this field's SRID corresponds with a coordinate Returns true if this field's SRID corresponds with a coordinate
system that uses non-projected units (e.g., latitude/longitude). system that uses non-projected units (e.g., latitude/longitude).
""" """
return self.units_name(connection).lower() in self.geodetic_units units_name = self.units_name(connection)
# Some backends like MySQL cannot determine units name. In that case,
# test if srid is 4326 (WGS84), even if this is over-simplification.
return units_name.lower() in self.geodetic_units if units_name else self.srid == 4326
def get_distance(self, value, lookup_type, connection): def get_distance(self, value, lookup_type, connection):
""" """

View File

@ -79,6 +79,9 @@ class GeomValue(Value):
self.value = connection.ops.Adapter(self.value) self.value = connection.ops.Adapter(self.value)
return super(GeomValue, self).as_sql(compiler, connection) return super(GeomValue, self).as_sql(compiler, connection)
def as_mysql(self, compiler, connection):
return 'GeomFromText(%%s, %s)' % self.srid, [connection.ops.Adapter(self.value)]
def as_sqlite(self, compiler, connection): def as_sqlite(self, compiler, connection):
return 'GeomFromText(%%s, %s)' % self.srid, [connection.ops.Adapter(self.value)] return 'GeomFromText(%%s, %s)' % self.srid, [connection.ops.Adapter(self.value)]
@ -119,8 +122,12 @@ class Area(GeoFunc):
self.output_field = AreaField('sq_m') self.output_field = AreaField('sq_m')
elif not self.output_field.geodetic(connection): elif not self.output_field.geodetic(connection):
# Getting the area units of the geographic field. # Getting the area units of the geographic field.
self.output_field = AreaField( units = self.output_field.units_name(connection)
AreaMeasure.unit_attname(self.output_field.units_name(connection))) if units:
self.output_field = AreaField(
AreaMeasure.unit_attname(self.output_field.units_name(connection)))
else:
self.output_field = FloatField()
else: else:
# TODO: Do we want to support raw number areas for geodetic fields? # TODO: Do we want to support raw number areas for geodetic fields?
raise NotImplementedError('Area on geodetic coordinate systems not supported.') raise NotImplementedError('Area on geodetic coordinate systems not supported.')
@ -198,8 +205,14 @@ class DistanceResultMixin(object):
if geo_field.geodetic(connection): if geo_field.geodetic(connection):
dist_att = 'm' dist_att = 'm'
else: else:
dist_att = DistanceMeasure.unit_attname(geo_field.units_name(connection)) units = geo_field.units_name(connection)
return DistanceMeasure(**{dist_att: value}) if units:
dist_att = DistanceMeasure.unit_attname(units)
else:
dist_att = None
if dist_att:
return DistanceMeasure(**{dist_att: value})
return value
class Distance(DistanceResultMixin, GeoFuncWithGeoParam): class Distance(DistanceResultMixin, GeoFuncWithGeoParam):
@ -263,6 +276,12 @@ class Length(DistanceResultMixin, GeoFunc):
self.spheroid = spheroid self.spheroid = spheroid
super(Length, self).__init__(expr1, **extra) super(Length, self).__init__(expr1, **extra)
def as_sql(self, compiler, connection):
geo_field = GeometryField(srid=self.srid) # Fake field to get SRID info
if geo_field.geodetic(connection) and not connection.features.supports_length_geodetic:
raise NotImplementedError("This backend doesn't support Length on geodetic fields")
return super(Length, self).as_sql(compiler, connection)
def as_postgresql(self, compiler, connection): def as_postgresql(self, compiler, connection):
geo_field = GeometryField(srid=self.srid) # Fake field to get SRID info geo_field = GeometryField(srid=self.srid) # Fake field to get SRID info
src_field = self.get_source_fields()[0] src_field = self.get_source_fields()[0]

View File

@ -438,7 +438,8 @@ class DistanceFunctionsTests(TestCase):
# Tolerance has to be lower for Oracle # Tolerance has to be lower for Oracle
tol = 2 tol = 2
for i, z in enumerate(SouthTexasZipcode.objects.annotate(area=Area('poly')).order_by('name')): for i, z in enumerate(SouthTexasZipcode.objects.annotate(area=Area('poly')).order_by('name')):
self.assertAlmostEqual(area_sq_m[i], z.area.sq_m, tol) # MySQL is returning a raw float value
self.assertAlmostEqual(area_sq_m[i], z.area.sq_m if hasattr(z.area, 'sq_m') else z.area, tol)
@skipUnlessDBFeature("has_Distance_function") @skipUnlessDBFeature("has_Distance_function")
def test_distance_simple(self): def test_distance_simple(self):
@ -624,12 +625,12 @@ class DistanceFunctionsTests(TestCase):
# TODO: test with spheroid argument (True and False) # TODO: test with spheroid argument (True and False)
else: else:
# Does not support geodetic coordinate systems. # Does not support geodetic coordinate systems.
with self.assertRaises(ValueError): with self.assertRaises(NotImplementedError):
Interstate.objects.annotate(length=Length('path')) list(Interstate.objects.annotate(length=Length('path')))
# Now doing length on a projected coordinate system. # Now doing length on a projected coordinate system.
i10 = SouthTexasInterstate.objects.annotate(length=Length('path')).get(name='I-10') i10 = SouthTexasInterstate.objects.annotate(length=Length('path')).get(name='I-10')
self.assertAlmostEqual(len_m2, i10.length.m, 2) self.assertAlmostEqual(len_m2, i10.length.m if isinstance(i10.length, D) else i10.length, 2)
self.assertTrue( self.assertTrue(
SouthTexasInterstate.objects.annotate(length=Length('path')).filter(length__gt=4000).exists() SouthTexasInterstate.objects.annotate(length=Length('path')).filter(length__gt=4000).exists()
) )
@ -652,7 +653,7 @@ class DistanceFunctionsTests(TestCase):
for city in qs: for city in qs:
self.assertEqual(0, city.perim.m) self.assertEqual(0, city.perim.m)
@skipUnlessDBFeature("has_Area_function", "has_Distance_function") @skipUnlessDBFeature("supports_null_geometries", "has_Area_function", "has_Distance_function")
def test_measurement_null_fields(self): def test_measurement_null_fields(self):
""" """
Test the measurement functions on fields with NULL values. Test the measurement functions on fields with NULL values.

View File

@ -9,7 +9,7 @@ from django.db import connection
from django.test import TestCase, skipUnlessDBFeature from django.test import TestCase, skipUnlessDBFeature
from django.utils import six from django.utils import six
from ..utils import oracle, postgis, spatialite from ..utils import mysql, oracle, postgis, spatialite
if HAS_GEOS: if HAS_GEOS:
from django.contrib.gis.geos import LineString, Point, Polygon, fromstr from django.contrib.gis.geos import LineString, Point, Polygon, fromstr
@ -165,8 +165,8 @@ class GISFunctionsTests(TestCase):
@skipUnlessDBFeature("has_Centroid_function") @skipUnlessDBFeature("has_Centroid_function")
def test_centroid(self): def test_centroid(self):
qs = State.objects.exclude(poly__isnull=True).annotate(centroid=functions.Centroid('poly')) qs = State.objects.exclude(poly__isnull=True).annotate(centroid=functions.Centroid('poly'))
tol = 1.8 if mysql else (0.1 if oracle else 0.00001)
for state in qs: for state in qs:
tol = 0.1 # High tolerance due to oracle
self.assertTrue(state.poly.centroid.equals_exact(state.centroid, tol)) self.assertTrue(state.poly.centroid.equals_exact(state.centroid, tol))
@skipUnlessDBFeature("has_Difference_function") @skipUnlessDBFeature("has_Difference_function")
@ -248,9 +248,9 @@ class GISFunctionsTests(TestCase):
qs = City.objects.filter(point__isnull=False).annotate(num_geom=functions.NumGeometries('point')) qs = City.objects.filter(point__isnull=False).annotate(num_geom=functions.NumGeometries('point'))
for city in qs: for city in qs:
# Oracle and PostGIS 2.0+ will return 1 for the number of # Oracle and PostGIS 2.0+ will return 1 for the number of
# geometries on non-collections, whereas PostGIS < 2.0.0 # geometries on non-collections, whereas PostGIS < 2.0.0 and MySQL
# will return None. # will return None.
if postgis and connection.ops.spatial_version < (2, 0, 0): if (postgis and connection.ops.spatial_version < (2, 0, 0)) or mysql:
self.assertIsNone(city.num_geom) self.assertIsNone(city.num_geom)
else: else:
self.assertEqual(1, city.num_geom) self.assertEqual(1, city.num_geom)
@ -261,8 +261,8 @@ class GISFunctionsTests(TestCase):
Track.objects.create(name='Foo', line=LineString(coords)) Track.objects.create(name='Foo', line=LineString(coords))
qs = Track.objects.annotate(num_points=functions.NumPoints('line')) qs = Track.objects.annotate(num_points=functions.NumPoints('line'))
self.assertEqual(qs.first().num_points, 2) self.assertEqual(qs.first().num_points, 2)
if spatialite: if spatialite or mysql:
# Spatialite can only count points on LineStrings # Spatialite and MySQL can only count points on LineStrings
return return
for c in Country.objects.annotate(num_points=functions.NumPoints('mpoly')): for c in Country.objects.annotate(num_points=functions.NumPoints('mpoly')):
@ -455,5 +455,7 @@ class GISFunctionsTests(TestCase):
geom = Point(-95.363151, 29.763374, srid=4326) geom = Point(-95.363151, 29.763374, srid=4326)
ptown = City.objects.annotate(union=functions.Union('point', geom)).get(name='Dallas') ptown = City.objects.annotate(union=functions.Union('point', geom)).get(name='Dallas')
tol = 0.00001 tol = 0.00001
expected = fromstr('MULTIPOINT(-96.801611 32.782057,-95.363151 29.763374)', srid=4326) # Undefined ordering
self.assertTrue(expected.equals_exact(ptown.union, tol)) expected1 = fromstr('MULTIPOINT(-96.801611 32.782057,-95.363151 29.763374)', srid=4326)
expected2 = fromstr('MULTIPOINT(-95.363151 29.763374,-96.801611 32.782057)', srid=4326)
self.assertTrue(expected1.equals_exact(ptown.union, tol) or expected2.equals_exact(ptown.union, tol))