From 71e20814fcb5983bdc96a6b15765b8f6abd11542 Mon Sep 17 00:00:00 2001 From: Claude Paroz Date: Sun, 25 Jan 2015 00:03:02 +0100 Subject: [PATCH] Added MySQL support to GIS functions --- .../contrib/gis/db/backends/mysql/features.py | 2 ++ .../gis/db/backends/mysql/operations.py | 24 ++++++++++++++++- django/contrib/gis/db/models/fields.py | 5 +++- django/contrib/gis/db/models/functions.py | 27 ++++++++++++++++--- tests/gis_tests/distapp/tests.py | 11 ++++---- tests/gis_tests/geoapp/test_functions.py | 18 +++++++------ 6 files changed, 68 insertions(+), 19 deletions(-) diff --git a/django/contrib/gis/db/backends/mysql/features.py b/django/contrib/gis/db/backends/mysql/features.py index a547ec967a..ce5893f0a8 100644 --- a/django/contrib/gis/db/backends/mysql/features.py +++ b/django/contrib/gis/db/backends/mysql/features.py @@ -6,6 +6,8 @@ from django.db.backends.mysql.features import \ class DatabaseFeatures(BaseSpatialFeatures, MySQLDatabaseFeatures): has_spatialrefsys_table = False supports_add_srs_entry = False + supports_distance_geodetic = False + supports_length_geodetic = False supports_distances_lookups = False supports_transform = False supports_real_shape_operations = False diff --git a/django/contrib/gis/db/backends/mysql/operations.py b/django/contrib/gis/db/backends/mysql/operations.py index 130c103ac2..cfd780022f 100644 --- a/django/contrib/gis/db/backends/mysql/operations.py +++ b/django/contrib/gis/db/backends/mysql/operations.py @@ -4,6 +4,7 @@ from django.contrib.gis.db.backends.base.operations import \ from django.contrib.gis.db.backends.utils import SpatialOperator from django.contrib.gis.db.models import aggregates from django.db.backends.mysql.operations import DatabaseOperations +from django.utils.functional import cached_property class MySQLOperations(BaseSpatialOperations, DatabaseOperations): @@ -32,7 +33,28 @@ class MySQLOperations(BaseSpatialOperations, DatabaseOperations): 'within': SpatialOperator(func='MBRWithin'), } - disallowed_aggregates = (aggregates.Collect, aggregates.Extent, aggregates.Extent3D, aggregates.MakeLine, aggregates.Union) + function_names = { + 'Distance': 'ST_Distance', + 'Length': 'GLength', + 'Union': 'ST_Union', + } + + disallowed_aggregates = ( + aggregates.Collect, aggregates.Extent, aggregates.Extent3D, + aggregates.MakeLine, aggregates.Union, + ) + + @cached_property + def unsupported_functions(self): + unsupported = { + 'AsGeoJSON', 'AsGML', 'AsKML', 'AsSVG', 'BoundingCircle', + 'Difference', 'ForceRHR', 'GeoHash', 'Intersection', 'MemSize', + 'Perimeter', 'PointOnSurface', 'Reverse', 'Scale', 'SnapToGrid', + 'SymDifference', 'Transform', 'Translate', + } + if self.connection.mysql_version < (5, 6, 1): + unsupported.update({'Distance', 'Union'}) + return unsupported def geo_db_type(self, f): return f.geom_type diff --git a/django/contrib/gis/db/models/fields.py b/django/contrib/gis/db/models/fields.py index ad39de1dbe..6d5c2b2b6a 100644 --- a/django/contrib/gis/db/models/fields.py +++ b/django/contrib/gis/db/models/fields.py @@ -169,7 +169,10 @@ class GeometryField(GeoSelectFormatMixin, Field): Returns true if this field's SRID corresponds with a coordinate system that uses non-projected units (e.g., latitude/longitude). """ - return self.units_name(connection).lower() in self.geodetic_units + units_name = self.units_name(connection) + # Some backends like MySQL cannot determine units name. In that case, + # test if srid is 4326 (WGS84), even if this is over-simplification. + return units_name.lower() in self.geodetic_units if units_name else self.srid == 4326 def get_distance(self, value, lookup_type, connection): """ diff --git a/django/contrib/gis/db/models/functions.py b/django/contrib/gis/db/models/functions.py index b23ad10d20..4902520b86 100644 --- a/django/contrib/gis/db/models/functions.py +++ b/django/contrib/gis/db/models/functions.py @@ -79,6 +79,9 @@ class GeomValue(Value): self.value = connection.ops.Adapter(self.value) return super(GeomValue, self).as_sql(compiler, connection) + def as_mysql(self, compiler, connection): + return 'GeomFromText(%%s, %s)' % self.srid, [connection.ops.Adapter(self.value)] + def as_sqlite(self, compiler, connection): return 'GeomFromText(%%s, %s)' % self.srid, [connection.ops.Adapter(self.value)] @@ -119,8 +122,12 @@ class Area(GeoFunc): self.output_field = AreaField('sq_m') elif not self.output_field.geodetic(connection): # Getting the area units of the geographic field. - self.output_field = AreaField( - AreaMeasure.unit_attname(self.output_field.units_name(connection))) + units = self.output_field.units_name(connection) + if units: + self.output_field = AreaField( + AreaMeasure.unit_attname(self.output_field.units_name(connection))) + else: + self.output_field = FloatField() else: # TODO: Do we want to support raw number areas for geodetic fields? raise NotImplementedError('Area on geodetic coordinate systems not supported.') @@ -198,8 +205,14 @@ class DistanceResultMixin(object): if geo_field.geodetic(connection): dist_att = 'm' else: - dist_att = DistanceMeasure.unit_attname(geo_field.units_name(connection)) - return DistanceMeasure(**{dist_att: value}) + units = geo_field.units_name(connection) + if units: + dist_att = DistanceMeasure.unit_attname(units) + else: + dist_att = None + if dist_att: + return DistanceMeasure(**{dist_att: value}) + return value class Distance(DistanceResultMixin, GeoFuncWithGeoParam): @@ -263,6 +276,12 @@ class Length(DistanceResultMixin, GeoFunc): self.spheroid = spheroid super(Length, self).__init__(expr1, **extra) + def as_sql(self, compiler, connection): + geo_field = GeometryField(srid=self.srid) # Fake field to get SRID info + if geo_field.geodetic(connection) and not connection.features.supports_length_geodetic: + raise NotImplementedError("This backend doesn't support Length on geodetic fields") + return super(Length, self).as_sql(compiler, connection) + def as_postgresql(self, compiler, connection): geo_field = GeometryField(srid=self.srid) # Fake field to get SRID info src_field = self.get_source_fields()[0] diff --git a/tests/gis_tests/distapp/tests.py b/tests/gis_tests/distapp/tests.py index c59772a782..e6d66d22ae 100644 --- a/tests/gis_tests/distapp/tests.py +++ b/tests/gis_tests/distapp/tests.py @@ -438,7 +438,8 @@ class DistanceFunctionsTests(TestCase): # Tolerance has to be lower for Oracle tol = 2 for i, z in enumerate(SouthTexasZipcode.objects.annotate(area=Area('poly')).order_by('name')): - self.assertAlmostEqual(area_sq_m[i], z.area.sq_m, tol) + # MySQL is returning a raw float value + self.assertAlmostEqual(area_sq_m[i], z.area.sq_m if hasattr(z.area, 'sq_m') else z.area, tol) @skipUnlessDBFeature("has_Distance_function") def test_distance_simple(self): @@ -624,12 +625,12 @@ class DistanceFunctionsTests(TestCase): # TODO: test with spheroid argument (True and False) else: # Does not support geodetic coordinate systems. - with self.assertRaises(ValueError): - Interstate.objects.annotate(length=Length('path')) + with self.assertRaises(NotImplementedError): + list(Interstate.objects.annotate(length=Length('path'))) # Now doing length on a projected coordinate system. i10 = SouthTexasInterstate.objects.annotate(length=Length('path')).get(name='I-10') - self.assertAlmostEqual(len_m2, i10.length.m, 2) + self.assertAlmostEqual(len_m2, i10.length.m if isinstance(i10.length, D) else i10.length, 2) self.assertTrue( SouthTexasInterstate.objects.annotate(length=Length('path')).filter(length__gt=4000).exists() ) @@ -652,7 +653,7 @@ class DistanceFunctionsTests(TestCase): for city in qs: self.assertEqual(0, city.perim.m) - @skipUnlessDBFeature("has_Area_function", "has_Distance_function") + @skipUnlessDBFeature("supports_null_geometries", "has_Area_function", "has_Distance_function") def test_measurement_null_fields(self): """ Test the measurement functions on fields with NULL values. diff --git a/tests/gis_tests/geoapp/test_functions.py b/tests/gis_tests/geoapp/test_functions.py index fb2a88be09..b17a1a2fc0 100644 --- a/tests/gis_tests/geoapp/test_functions.py +++ b/tests/gis_tests/geoapp/test_functions.py @@ -9,7 +9,7 @@ from django.db import connection from django.test import TestCase, skipUnlessDBFeature from django.utils import six -from ..utils import oracle, postgis, spatialite +from ..utils import mysql, oracle, postgis, spatialite if HAS_GEOS: from django.contrib.gis.geos import LineString, Point, Polygon, fromstr @@ -165,8 +165,8 @@ class GISFunctionsTests(TestCase): @skipUnlessDBFeature("has_Centroid_function") def test_centroid(self): qs = State.objects.exclude(poly__isnull=True).annotate(centroid=functions.Centroid('poly')) + tol = 1.8 if mysql else (0.1 if oracle else 0.00001) for state in qs: - tol = 0.1 # High tolerance due to oracle self.assertTrue(state.poly.centroid.equals_exact(state.centroid, tol)) @skipUnlessDBFeature("has_Difference_function") @@ -248,9 +248,9 @@ class GISFunctionsTests(TestCase): qs = City.objects.filter(point__isnull=False).annotate(num_geom=functions.NumGeometries('point')) for city in qs: # Oracle and PostGIS 2.0+ will return 1 for the number of - # geometries on non-collections, whereas PostGIS < 2.0.0 + # geometries on non-collections, whereas PostGIS < 2.0.0 and MySQL # will return None. - if postgis and connection.ops.spatial_version < (2, 0, 0): + if (postgis and connection.ops.spatial_version < (2, 0, 0)) or mysql: self.assertIsNone(city.num_geom) else: self.assertEqual(1, city.num_geom) @@ -261,8 +261,8 @@ class GISFunctionsTests(TestCase): Track.objects.create(name='Foo', line=LineString(coords)) qs = Track.objects.annotate(num_points=functions.NumPoints('line')) self.assertEqual(qs.first().num_points, 2) - if spatialite: - # Spatialite can only count points on LineStrings + if spatialite or mysql: + # Spatialite and MySQL can only count points on LineStrings return for c in Country.objects.annotate(num_points=functions.NumPoints('mpoly')): @@ -455,5 +455,7 @@ class GISFunctionsTests(TestCase): geom = Point(-95.363151, 29.763374, srid=4326) ptown = City.objects.annotate(union=functions.Union('point', geom)).get(name='Dallas') tol = 0.00001 - expected = fromstr('MULTIPOINT(-96.801611 32.782057,-95.363151 29.763374)', srid=4326) - self.assertTrue(expected.equals_exact(ptown.union, tol)) + # Undefined ordering + expected1 = fromstr('MULTIPOINT(-96.801611 32.782057,-95.363151 29.763374)', srid=4326) + expected2 = fromstr('MULTIPOINT(-95.363151 29.763374,-96.801611 32.782057)', srid=4326) + self.assertTrue(expected1.equals_exact(ptown.union, tol) or expected2.equals_exact(ptown.union, tol))