mirror of https://github.com/django/django.git
Added MySQL support to GIS functions
This commit is contained in:
parent
44bdbbc316
commit
71e20814fc
|
@ -6,6 +6,8 @@ from django.db.backends.mysql.features import \
|
||||||
class DatabaseFeatures(BaseSpatialFeatures, MySQLDatabaseFeatures):
|
class DatabaseFeatures(BaseSpatialFeatures, MySQLDatabaseFeatures):
|
||||||
has_spatialrefsys_table = False
|
has_spatialrefsys_table = False
|
||||||
supports_add_srs_entry = False
|
supports_add_srs_entry = False
|
||||||
|
supports_distance_geodetic = False
|
||||||
|
supports_length_geodetic = False
|
||||||
supports_distances_lookups = False
|
supports_distances_lookups = False
|
||||||
supports_transform = False
|
supports_transform = False
|
||||||
supports_real_shape_operations = False
|
supports_real_shape_operations = False
|
||||||
|
|
|
@ -4,6 +4,7 @@ from django.contrib.gis.db.backends.base.operations import \
|
||||||
from django.contrib.gis.db.backends.utils import SpatialOperator
|
from django.contrib.gis.db.backends.utils import SpatialOperator
|
||||||
from django.contrib.gis.db.models import aggregates
|
from django.contrib.gis.db.models import aggregates
|
||||||
from django.db.backends.mysql.operations import DatabaseOperations
|
from django.db.backends.mysql.operations import DatabaseOperations
|
||||||
|
from django.utils.functional import cached_property
|
||||||
|
|
||||||
|
|
||||||
class MySQLOperations(BaseSpatialOperations, DatabaseOperations):
|
class MySQLOperations(BaseSpatialOperations, DatabaseOperations):
|
||||||
|
@ -32,7 +33,28 @@ class MySQLOperations(BaseSpatialOperations, DatabaseOperations):
|
||||||
'within': SpatialOperator(func='MBRWithin'),
|
'within': SpatialOperator(func='MBRWithin'),
|
||||||
}
|
}
|
||||||
|
|
||||||
disallowed_aggregates = (aggregates.Collect, aggregates.Extent, aggregates.Extent3D, aggregates.MakeLine, aggregates.Union)
|
function_names = {
|
||||||
|
'Distance': 'ST_Distance',
|
||||||
|
'Length': 'GLength',
|
||||||
|
'Union': 'ST_Union',
|
||||||
|
}
|
||||||
|
|
||||||
|
disallowed_aggregates = (
|
||||||
|
aggregates.Collect, aggregates.Extent, aggregates.Extent3D,
|
||||||
|
aggregates.MakeLine, aggregates.Union,
|
||||||
|
)
|
||||||
|
|
||||||
|
@cached_property
|
||||||
|
def unsupported_functions(self):
|
||||||
|
unsupported = {
|
||||||
|
'AsGeoJSON', 'AsGML', 'AsKML', 'AsSVG', 'BoundingCircle',
|
||||||
|
'Difference', 'ForceRHR', 'GeoHash', 'Intersection', 'MemSize',
|
||||||
|
'Perimeter', 'PointOnSurface', 'Reverse', 'Scale', 'SnapToGrid',
|
||||||
|
'SymDifference', 'Transform', 'Translate',
|
||||||
|
}
|
||||||
|
if self.connection.mysql_version < (5, 6, 1):
|
||||||
|
unsupported.update({'Distance', 'Union'})
|
||||||
|
return unsupported
|
||||||
|
|
||||||
def geo_db_type(self, f):
|
def geo_db_type(self, f):
|
||||||
return f.geom_type
|
return f.geom_type
|
||||||
|
|
|
@ -169,7 +169,10 @@ class GeometryField(GeoSelectFormatMixin, Field):
|
||||||
Returns true if this field's SRID corresponds with a coordinate
|
Returns true if this field's SRID corresponds with a coordinate
|
||||||
system that uses non-projected units (e.g., latitude/longitude).
|
system that uses non-projected units (e.g., latitude/longitude).
|
||||||
"""
|
"""
|
||||||
return self.units_name(connection).lower() in self.geodetic_units
|
units_name = self.units_name(connection)
|
||||||
|
# Some backends like MySQL cannot determine units name. In that case,
|
||||||
|
# test if srid is 4326 (WGS84), even if this is over-simplification.
|
||||||
|
return units_name.lower() in self.geodetic_units if units_name else self.srid == 4326
|
||||||
|
|
||||||
def get_distance(self, value, lookup_type, connection):
|
def get_distance(self, value, lookup_type, connection):
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -79,6 +79,9 @@ class GeomValue(Value):
|
||||||
self.value = connection.ops.Adapter(self.value)
|
self.value = connection.ops.Adapter(self.value)
|
||||||
return super(GeomValue, self).as_sql(compiler, connection)
|
return super(GeomValue, self).as_sql(compiler, connection)
|
||||||
|
|
||||||
|
def as_mysql(self, compiler, connection):
|
||||||
|
return 'GeomFromText(%%s, %s)' % self.srid, [connection.ops.Adapter(self.value)]
|
||||||
|
|
||||||
def as_sqlite(self, compiler, connection):
|
def as_sqlite(self, compiler, connection):
|
||||||
return 'GeomFromText(%%s, %s)' % self.srid, [connection.ops.Adapter(self.value)]
|
return 'GeomFromText(%%s, %s)' % self.srid, [connection.ops.Adapter(self.value)]
|
||||||
|
|
||||||
|
@ -119,8 +122,12 @@ class Area(GeoFunc):
|
||||||
self.output_field = AreaField('sq_m')
|
self.output_field = AreaField('sq_m')
|
||||||
elif not self.output_field.geodetic(connection):
|
elif not self.output_field.geodetic(connection):
|
||||||
# Getting the area units of the geographic field.
|
# Getting the area units of the geographic field.
|
||||||
|
units = self.output_field.units_name(connection)
|
||||||
|
if units:
|
||||||
self.output_field = AreaField(
|
self.output_field = AreaField(
|
||||||
AreaMeasure.unit_attname(self.output_field.units_name(connection)))
|
AreaMeasure.unit_attname(self.output_field.units_name(connection)))
|
||||||
|
else:
|
||||||
|
self.output_field = FloatField()
|
||||||
else:
|
else:
|
||||||
# TODO: Do we want to support raw number areas for geodetic fields?
|
# TODO: Do we want to support raw number areas for geodetic fields?
|
||||||
raise NotImplementedError('Area on geodetic coordinate systems not supported.')
|
raise NotImplementedError('Area on geodetic coordinate systems not supported.')
|
||||||
|
@ -198,8 +205,14 @@ class DistanceResultMixin(object):
|
||||||
if geo_field.geodetic(connection):
|
if geo_field.geodetic(connection):
|
||||||
dist_att = 'm'
|
dist_att = 'm'
|
||||||
else:
|
else:
|
||||||
dist_att = DistanceMeasure.unit_attname(geo_field.units_name(connection))
|
units = geo_field.units_name(connection)
|
||||||
|
if units:
|
||||||
|
dist_att = DistanceMeasure.unit_attname(units)
|
||||||
|
else:
|
||||||
|
dist_att = None
|
||||||
|
if dist_att:
|
||||||
return DistanceMeasure(**{dist_att: value})
|
return DistanceMeasure(**{dist_att: value})
|
||||||
|
return value
|
||||||
|
|
||||||
|
|
||||||
class Distance(DistanceResultMixin, GeoFuncWithGeoParam):
|
class Distance(DistanceResultMixin, GeoFuncWithGeoParam):
|
||||||
|
@ -263,6 +276,12 @@ class Length(DistanceResultMixin, GeoFunc):
|
||||||
self.spheroid = spheroid
|
self.spheroid = spheroid
|
||||||
super(Length, self).__init__(expr1, **extra)
|
super(Length, self).__init__(expr1, **extra)
|
||||||
|
|
||||||
|
def as_sql(self, compiler, connection):
|
||||||
|
geo_field = GeometryField(srid=self.srid) # Fake field to get SRID info
|
||||||
|
if geo_field.geodetic(connection) and not connection.features.supports_length_geodetic:
|
||||||
|
raise NotImplementedError("This backend doesn't support Length on geodetic fields")
|
||||||
|
return super(Length, self).as_sql(compiler, connection)
|
||||||
|
|
||||||
def as_postgresql(self, compiler, connection):
|
def as_postgresql(self, compiler, connection):
|
||||||
geo_field = GeometryField(srid=self.srid) # Fake field to get SRID info
|
geo_field = GeometryField(srid=self.srid) # Fake field to get SRID info
|
||||||
src_field = self.get_source_fields()[0]
|
src_field = self.get_source_fields()[0]
|
||||||
|
|
|
@ -438,7 +438,8 @@ class DistanceFunctionsTests(TestCase):
|
||||||
# Tolerance has to be lower for Oracle
|
# Tolerance has to be lower for Oracle
|
||||||
tol = 2
|
tol = 2
|
||||||
for i, z in enumerate(SouthTexasZipcode.objects.annotate(area=Area('poly')).order_by('name')):
|
for i, z in enumerate(SouthTexasZipcode.objects.annotate(area=Area('poly')).order_by('name')):
|
||||||
self.assertAlmostEqual(area_sq_m[i], z.area.sq_m, tol)
|
# MySQL is returning a raw float value
|
||||||
|
self.assertAlmostEqual(area_sq_m[i], z.area.sq_m if hasattr(z.area, 'sq_m') else z.area, tol)
|
||||||
|
|
||||||
@skipUnlessDBFeature("has_Distance_function")
|
@skipUnlessDBFeature("has_Distance_function")
|
||||||
def test_distance_simple(self):
|
def test_distance_simple(self):
|
||||||
|
@ -624,12 +625,12 @@ class DistanceFunctionsTests(TestCase):
|
||||||
# TODO: test with spheroid argument (True and False)
|
# TODO: test with spheroid argument (True and False)
|
||||||
else:
|
else:
|
||||||
# Does not support geodetic coordinate systems.
|
# Does not support geodetic coordinate systems.
|
||||||
with self.assertRaises(ValueError):
|
with self.assertRaises(NotImplementedError):
|
||||||
Interstate.objects.annotate(length=Length('path'))
|
list(Interstate.objects.annotate(length=Length('path')))
|
||||||
|
|
||||||
# Now doing length on a projected coordinate system.
|
# Now doing length on a projected coordinate system.
|
||||||
i10 = SouthTexasInterstate.objects.annotate(length=Length('path')).get(name='I-10')
|
i10 = SouthTexasInterstate.objects.annotate(length=Length('path')).get(name='I-10')
|
||||||
self.assertAlmostEqual(len_m2, i10.length.m, 2)
|
self.assertAlmostEqual(len_m2, i10.length.m if isinstance(i10.length, D) else i10.length, 2)
|
||||||
self.assertTrue(
|
self.assertTrue(
|
||||||
SouthTexasInterstate.objects.annotate(length=Length('path')).filter(length__gt=4000).exists()
|
SouthTexasInterstate.objects.annotate(length=Length('path')).filter(length__gt=4000).exists()
|
||||||
)
|
)
|
||||||
|
@ -652,7 +653,7 @@ class DistanceFunctionsTests(TestCase):
|
||||||
for city in qs:
|
for city in qs:
|
||||||
self.assertEqual(0, city.perim.m)
|
self.assertEqual(0, city.perim.m)
|
||||||
|
|
||||||
@skipUnlessDBFeature("has_Area_function", "has_Distance_function")
|
@skipUnlessDBFeature("supports_null_geometries", "has_Area_function", "has_Distance_function")
|
||||||
def test_measurement_null_fields(self):
|
def test_measurement_null_fields(self):
|
||||||
"""
|
"""
|
||||||
Test the measurement functions on fields with NULL values.
|
Test the measurement functions on fields with NULL values.
|
||||||
|
|
|
@ -9,7 +9,7 @@ from django.db import connection
|
||||||
from django.test import TestCase, skipUnlessDBFeature
|
from django.test import TestCase, skipUnlessDBFeature
|
||||||
from django.utils import six
|
from django.utils import six
|
||||||
|
|
||||||
from ..utils import oracle, postgis, spatialite
|
from ..utils import mysql, oracle, postgis, spatialite
|
||||||
|
|
||||||
if HAS_GEOS:
|
if HAS_GEOS:
|
||||||
from django.contrib.gis.geos import LineString, Point, Polygon, fromstr
|
from django.contrib.gis.geos import LineString, Point, Polygon, fromstr
|
||||||
|
@ -165,8 +165,8 @@ class GISFunctionsTests(TestCase):
|
||||||
@skipUnlessDBFeature("has_Centroid_function")
|
@skipUnlessDBFeature("has_Centroid_function")
|
||||||
def test_centroid(self):
|
def test_centroid(self):
|
||||||
qs = State.objects.exclude(poly__isnull=True).annotate(centroid=functions.Centroid('poly'))
|
qs = State.objects.exclude(poly__isnull=True).annotate(centroid=functions.Centroid('poly'))
|
||||||
|
tol = 1.8 if mysql else (0.1 if oracle else 0.00001)
|
||||||
for state in qs:
|
for state in qs:
|
||||||
tol = 0.1 # High tolerance due to oracle
|
|
||||||
self.assertTrue(state.poly.centroid.equals_exact(state.centroid, tol))
|
self.assertTrue(state.poly.centroid.equals_exact(state.centroid, tol))
|
||||||
|
|
||||||
@skipUnlessDBFeature("has_Difference_function")
|
@skipUnlessDBFeature("has_Difference_function")
|
||||||
|
@ -248,9 +248,9 @@ class GISFunctionsTests(TestCase):
|
||||||
qs = City.objects.filter(point__isnull=False).annotate(num_geom=functions.NumGeometries('point'))
|
qs = City.objects.filter(point__isnull=False).annotate(num_geom=functions.NumGeometries('point'))
|
||||||
for city in qs:
|
for city in qs:
|
||||||
# Oracle and PostGIS 2.0+ will return 1 for the number of
|
# Oracle and PostGIS 2.0+ will return 1 for the number of
|
||||||
# geometries on non-collections, whereas PostGIS < 2.0.0
|
# geometries on non-collections, whereas PostGIS < 2.0.0 and MySQL
|
||||||
# will return None.
|
# will return None.
|
||||||
if postgis and connection.ops.spatial_version < (2, 0, 0):
|
if (postgis and connection.ops.spatial_version < (2, 0, 0)) or mysql:
|
||||||
self.assertIsNone(city.num_geom)
|
self.assertIsNone(city.num_geom)
|
||||||
else:
|
else:
|
||||||
self.assertEqual(1, city.num_geom)
|
self.assertEqual(1, city.num_geom)
|
||||||
|
@ -261,8 +261,8 @@ class GISFunctionsTests(TestCase):
|
||||||
Track.objects.create(name='Foo', line=LineString(coords))
|
Track.objects.create(name='Foo', line=LineString(coords))
|
||||||
qs = Track.objects.annotate(num_points=functions.NumPoints('line'))
|
qs = Track.objects.annotate(num_points=functions.NumPoints('line'))
|
||||||
self.assertEqual(qs.first().num_points, 2)
|
self.assertEqual(qs.first().num_points, 2)
|
||||||
if spatialite:
|
if spatialite or mysql:
|
||||||
# Spatialite can only count points on LineStrings
|
# Spatialite and MySQL can only count points on LineStrings
|
||||||
return
|
return
|
||||||
|
|
||||||
for c in Country.objects.annotate(num_points=functions.NumPoints('mpoly')):
|
for c in Country.objects.annotate(num_points=functions.NumPoints('mpoly')):
|
||||||
|
@ -455,5 +455,7 @@ class GISFunctionsTests(TestCase):
|
||||||
geom = Point(-95.363151, 29.763374, srid=4326)
|
geom = Point(-95.363151, 29.763374, srid=4326)
|
||||||
ptown = City.objects.annotate(union=functions.Union('point', geom)).get(name='Dallas')
|
ptown = City.objects.annotate(union=functions.Union('point', geom)).get(name='Dallas')
|
||||||
tol = 0.00001
|
tol = 0.00001
|
||||||
expected = fromstr('MULTIPOINT(-96.801611 32.782057,-95.363151 29.763374)', srid=4326)
|
# Undefined ordering
|
||||||
self.assertTrue(expected.equals_exact(ptown.union, tol))
|
expected1 = fromstr('MULTIPOINT(-96.801611 32.782057,-95.363151 29.763374)', srid=4326)
|
||||||
|
expected2 = fromstr('MULTIPOINT(-95.363151 29.763374,-96.801611 32.782057)', srid=4326)
|
||||||
|
self.assertTrue(expected1.equals_exact(ptown.union, tol) or expected2.equals_exact(ptown.union, tol))
|
||||||
|
|
Loading…
Reference in New Issue