mirror of https://github.com/django/django.git
Added MySQL support to GIS functions
This commit is contained in:
parent
44bdbbc316
commit
71e20814fc
|
@ -6,6 +6,8 @@ from django.db.backends.mysql.features import \
|
|||
class DatabaseFeatures(BaseSpatialFeatures, MySQLDatabaseFeatures):
|
||||
has_spatialrefsys_table = False
|
||||
supports_add_srs_entry = False
|
||||
supports_distance_geodetic = False
|
||||
supports_length_geodetic = False
|
||||
supports_distances_lookups = False
|
||||
supports_transform = False
|
||||
supports_real_shape_operations = False
|
||||
|
|
|
@ -4,6 +4,7 @@ from django.contrib.gis.db.backends.base.operations import \
|
|||
from django.contrib.gis.db.backends.utils import SpatialOperator
|
||||
from django.contrib.gis.db.models import aggregates
|
||||
from django.db.backends.mysql.operations import DatabaseOperations
|
||||
from django.utils.functional import cached_property
|
||||
|
||||
|
||||
class MySQLOperations(BaseSpatialOperations, DatabaseOperations):
|
||||
|
@ -32,7 +33,28 @@ class MySQLOperations(BaseSpatialOperations, DatabaseOperations):
|
|||
'within': SpatialOperator(func='MBRWithin'),
|
||||
}
|
||||
|
||||
disallowed_aggregates = (aggregates.Collect, aggregates.Extent, aggregates.Extent3D, aggregates.MakeLine, aggregates.Union)
|
||||
function_names = {
|
||||
'Distance': 'ST_Distance',
|
||||
'Length': 'GLength',
|
||||
'Union': 'ST_Union',
|
||||
}
|
||||
|
||||
disallowed_aggregates = (
|
||||
aggregates.Collect, aggregates.Extent, aggregates.Extent3D,
|
||||
aggregates.MakeLine, aggregates.Union,
|
||||
)
|
||||
|
||||
@cached_property
|
||||
def unsupported_functions(self):
|
||||
unsupported = {
|
||||
'AsGeoJSON', 'AsGML', 'AsKML', 'AsSVG', 'BoundingCircle',
|
||||
'Difference', 'ForceRHR', 'GeoHash', 'Intersection', 'MemSize',
|
||||
'Perimeter', 'PointOnSurface', 'Reverse', 'Scale', 'SnapToGrid',
|
||||
'SymDifference', 'Transform', 'Translate',
|
||||
}
|
||||
if self.connection.mysql_version < (5, 6, 1):
|
||||
unsupported.update({'Distance', 'Union'})
|
||||
return unsupported
|
||||
|
||||
def geo_db_type(self, f):
|
||||
return f.geom_type
|
||||
|
|
|
@ -169,7 +169,10 @@ class GeometryField(GeoSelectFormatMixin, Field):
|
|||
Returns true if this field's SRID corresponds with a coordinate
|
||||
system that uses non-projected units (e.g., latitude/longitude).
|
||||
"""
|
||||
return self.units_name(connection).lower() in self.geodetic_units
|
||||
units_name = self.units_name(connection)
|
||||
# Some backends like MySQL cannot determine units name. In that case,
|
||||
# test if srid is 4326 (WGS84), even if this is over-simplification.
|
||||
return units_name.lower() in self.geodetic_units if units_name else self.srid == 4326
|
||||
|
||||
def get_distance(self, value, lookup_type, connection):
|
||||
"""
|
||||
|
|
|
@ -79,6 +79,9 @@ class GeomValue(Value):
|
|||
self.value = connection.ops.Adapter(self.value)
|
||||
return super(GeomValue, self).as_sql(compiler, connection)
|
||||
|
||||
def as_mysql(self, compiler, connection):
|
||||
return 'GeomFromText(%%s, %s)' % self.srid, [connection.ops.Adapter(self.value)]
|
||||
|
||||
def as_sqlite(self, compiler, connection):
|
||||
return 'GeomFromText(%%s, %s)' % self.srid, [connection.ops.Adapter(self.value)]
|
||||
|
||||
|
@ -119,8 +122,12 @@ class Area(GeoFunc):
|
|||
self.output_field = AreaField('sq_m')
|
||||
elif not self.output_field.geodetic(connection):
|
||||
# Getting the area units of the geographic field.
|
||||
units = self.output_field.units_name(connection)
|
||||
if units:
|
||||
self.output_field = AreaField(
|
||||
AreaMeasure.unit_attname(self.output_field.units_name(connection)))
|
||||
else:
|
||||
self.output_field = FloatField()
|
||||
else:
|
||||
# TODO: Do we want to support raw number areas for geodetic fields?
|
||||
raise NotImplementedError('Area on geodetic coordinate systems not supported.')
|
||||
|
@ -198,8 +205,14 @@ class DistanceResultMixin(object):
|
|||
if geo_field.geodetic(connection):
|
||||
dist_att = 'm'
|
||||
else:
|
||||
dist_att = DistanceMeasure.unit_attname(geo_field.units_name(connection))
|
||||
units = geo_field.units_name(connection)
|
||||
if units:
|
||||
dist_att = DistanceMeasure.unit_attname(units)
|
||||
else:
|
||||
dist_att = None
|
||||
if dist_att:
|
||||
return DistanceMeasure(**{dist_att: value})
|
||||
return value
|
||||
|
||||
|
||||
class Distance(DistanceResultMixin, GeoFuncWithGeoParam):
|
||||
|
@ -263,6 +276,12 @@ class Length(DistanceResultMixin, GeoFunc):
|
|||
self.spheroid = spheroid
|
||||
super(Length, self).__init__(expr1, **extra)
|
||||
|
||||
def as_sql(self, compiler, connection):
|
||||
geo_field = GeometryField(srid=self.srid) # Fake field to get SRID info
|
||||
if geo_field.geodetic(connection) and not connection.features.supports_length_geodetic:
|
||||
raise NotImplementedError("This backend doesn't support Length on geodetic fields")
|
||||
return super(Length, self).as_sql(compiler, connection)
|
||||
|
||||
def as_postgresql(self, compiler, connection):
|
||||
geo_field = GeometryField(srid=self.srid) # Fake field to get SRID info
|
||||
src_field = self.get_source_fields()[0]
|
||||
|
|
|
@ -438,7 +438,8 @@ class DistanceFunctionsTests(TestCase):
|
|||
# Tolerance has to be lower for Oracle
|
||||
tol = 2
|
||||
for i, z in enumerate(SouthTexasZipcode.objects.annotate(area=Area('poly')).order_by('name')):
|
||||
self.assertAlmostEqual(area_sq_m[i], z.area.sq_m, tol)
|
||||
# MySQL is returning a raw float value
|
||||
self.assertAlmostEqual(area_sq_m[i], z.area.sq_m if hasattr(z.area, 'sq_m') else z.area, tol)
|
||||
|
||||
@skipUnlessDBFeature("has_Distance_function")
|
||||
def test_distance_simple(self):
|
||||
|
@ -624,12 +625,12 @@ class DistanceFunctionsTests(TestCase):
|
|||
# TODO: test with spheroid argument (True and False)
|
||||
else:
|
||||
# Does not support geodetic coordinate systems.
|
||||
with self.assertRaises(ValueError):
|
||||
Interstate.objects.annotate(length=Length('path'))
|
||||
with self.assertRaises(NotImplementedError):
|
||||
list(Interstate.objects.annotate(length=Length('path')))
|
||||
|
||||
# Now doing length on a projected coordinate system.
|
||||
i10 = SouthTexasInterstate.objects.annotate(length=Length('path')).get(name='I-10')
|
||||
self.assertAlmostEqual(len_m2, i10.length.m, 2)
|
||||
self.assertAlmostEqual(len_m2, i10.length.m if isinstance(i10.length, D) else i10.length, 2)
|
||||
self.assertTrue(
|
||||
SouthTexasInterstate.objects.annotate(length=Length('path')).filter(length__gt=4000).exists()
|
||||
)
|
||||
|
@ -652,7 +653,7 @@ class DistanceFunctionsTests(TestCase):
|
|||
for city in qs:
|
||||
self.assertEqual(0, city.perim.m)
|
||||
|
||||
@skipUnlessDBFeature("has_Area_function", "has_Distance_function")
|
||||
@skipUnlessDBFeature("supports_null_geometries", "has_Area_function", "has_Distance_function")
|
||||
def test_measurement_null_fields(self):
|
||||
"""
|
||||
Test the measurement functions on fields with NULL values.
|
||||
|
|
|
@ -9,7 +9,7 @@ from django.db import connection
|
|||
from django.test import TestCase, skipUnlessDBFeature
|
||||
from django.utils import six
|
||||
|
||||
from ..utils import oracle, postgis, spatialite
|
||||
from ..utils import mysql, oracle, postgis, spatialite
|
||||
|
||||
if HAS_GEOS:
|
||||
from django.contrib.gis.geos import LineString, Point, Polygon, fromstr
|
||||
|
@ -165,8 +165,8 @@ class GISFunctionsTests(TestCase):
|
|||
@skipUnlessDBFeature("has_Centroid_function")
|
||||
def test_centroid(self):
|
||||
qs = State.objects.exclude(poly__isnull=True).annotate(centroid=functions.Centroid('poly'))
|
||||
tol = 1.8 if mysql else (0.1 if oracle else 0.00001)
|
||||
for state in qs:
|
||||
tol = 0.1 # High tolerance due to oracle
|
||||
self.assertTrue(state.poly.centroid.equals_exact(state.centroid, tol))
|
||||
|
||||
@skipUnlessDBFeature("has_Difference_function")
|
||||
|
@ -248,9 +248,9 @@ class GISFunctionsTests(TestCase):
|
|||
qs = City.objects.filter(point__isnull=False).annotate(num_geom=functions.NumGeometries('point'))
|
||||
for city in qs:
|
||||
# Oracle and PostGIS 2.0+ will return 1 for the number of
|
||||
# geometries on non-collections, whereas PostGIS < 2.0.0
|
||||
# geometries on non-collections, whereas PostGIS < 2.0.0 and MySQL
|
||||
# will return None.
|
||||
if postgis and connection.ops.spatial_version < (2, 0, 0):
|
||||
if (postgis and connection.ops.spatial_version < (2, 0, 0)) or mysql:
|
||||
self.assertIsNone(city.num_geom)
|
||||
else:
|
||||
self.assertEqual(1, city.num_geom)
|
||||
|
@ -261,8 +261,8 @@ class GISFunctionsTests(TestCase):
|
|||
Track.objects.create(name='Foo', line=LineString(coords))
|
||||
qs = Track.objects.annotate(num_points=functions.NumPoints('line'))
|
||||
self.assertEqual(qs.first().num_points, 2)
|
||||
if spatialite:
|
||||
# Spatialite can only count points on LineStrings
|
||||
if spatialite or mysql:
|
||||
# Spatialite and MySQL can only count points on LineStrings
|
||||
return
|
||||
|
||||
for c in Country.objects.annotate(num_points=functions.NumPoints('mpoly')):
|
||||
|
@ -455,5 +455,7 @@ class GISFunctionsTests(TestCase):
|
|||
geom = Point(-95.363151, 29.763374, srid=4326)
|
||||
ptown = City.objects.annotate(union=functions.Union('point', geom)).get(name='Dallas')
|
||||
tol = 0.00001
|
||||
expected = fromstr('MULTIPOINT(-96.801611 32.782057,-95.363151 29.763374)', srid=4326)
|
||||
self.assertTrue(expected.equals_exact(ptown.union, tol))
|
||||
# Undefined ordering
|
||||
expected1 = fromstr('MULTIPOINT(-96.801611 32.782057,-95.363151 29.763374)', srid=4326)
|
||||
expected2 = fromstr('MULTIPOINT(-95.363151 29.763374,-96.801611 32.782057)', srid=4326)
|
||||
self.assertTrue(expected1.equals_exact(ptown.union, tol) or expected2.equals_exact(ptown.union, tol))
|
||||
|
|
Loading…
Reference in New Issue