diff --git a/django/contrib/gis/db/models/functions.py b/django/contrib/gis/db/models/functions.py index 87d9ba41c2..7ddb2ed708 100644 --- a/django/contrib/gis/db/models/functions.py +++ b/django/contrib/gis/db/models/functions.py @@ -376,7 +376,7 @@ class LineLocatePoint(GeoFunc): geom_param_pos = (0, 1) -class MakeValid(GeoFunc): +class MakeValid(GeomOutputGeoFunc): pass diff --git a/tests/gis_tests/geoapp/test_functions.py b/tests/gis_tests/geoapp/test_functions.py index 18746f48df..1b1b5c7b6d 100644 --- a/tests/gis_tests/geoapp/test_functions.py +++ b/tests/gis_tests/geoapp/test_functions.py @@ -3,13 +3,13 @@ import math import re from decimal import Decimal -from django.contrib.gis.db.models import functions +from django.contrib.gis.db.models import GeometryField, PolygonField, functions from django.contrib.gis.geos import ( GEOSGeometry, LineString, Point, Polygon, fromstr, ) from django.contrib.gis.measure import Area from django.db import NotSupportedError, connection -from django.db.models import Sum +from django.db.models import Sum, Value from django.test import TestCase, skipUnlessDBFeature from ..utils import FuncTestMixin, mariadb, mysql, oracle, postgis, spatialite @@ -348,6 +348,34 @@ class GISFunctionsTests(FuncTestMixin, TestCase): self.assertIs(invalid.repaired.valid, True) self.assertEqual(invalid.repaired, fromstr('POLYGON((0 0, 0 1, 1 1, 1 0, 0 0))', srid=invalid.poly.srid)) + @skipUnlessDBFeature('has_MakeValid_function') + def test_make_valid_multipolygon(self): + invalid_geom = fromstr( + 'POLYGON((0 0, 0 1 , 1 1 , 1 0, 0 0), ' + '(10 0, 10 1, 11 1, 11 0, 10 0))' + ) + State.objects.create(name='invalid', poly=invalid_geom) + invalid = State.objects.filter(name='invalid').annotate( + repaired=functions.MakeValid('poly'), + ).get() + self.assertIs(invalid.repaired.valid, True) + self.assertEqual(invalid.repaired, fromstr( + 'MULTIPOLYGON (((0 0, 0 1, 1 1, 1 0, 0 0)), ' + '((10 0, 10 1, 11 1, 11 0, 10 0)))', + srid=invalid.poly.srid, + )) + self.assertEqual(len(invalid.repaired), 2) + + @skipUnlessDBFeature('has_MakeValid_function') + def test_make_valid_output_field(self): + # output_field is GeometryField instance because different geometry + # types can be returned. + output_field = functions.MakeValid( + Value(Polygon(), PolygonField(srid=42)), + ).output_field + self.assertIs(output_field.__class__, GeometryField) + self.assertEqual(output_field.srid, 42) + @skipUnlessDBFeature("has_MemSize_function") def test_memsize(self): ptown = City.objects.annotate(size=functions.MemSize('point')).get(name='Pueblo')