diff --git a/django/contrib/gis/geos/geometry.py b/django/contrib/gis/geos/geometry.py index 1a13cdbbc8..07a22cb751 100644 --- a/django/contrib/gis/geos/geometry.py +++ b/django/contrib/gis/geos/geometry.py @@ -95,24 +95,26 @@ class GEOSGeometry(GEOSBase, ListMixin): self.srid = srid # Setting the class type (e.g., Point, Polygon, etc.) - if GEOSGeometry._GEOS_CLASSES is None: - # Lazy-loaded variable to avoid import conflicts with GEOSGeometry. - from .linestring import LineString, LinearRing - from .point import Point - from .polygon import Polygon - from .collections import ( - GeometryCollection, MultiPoint, MultiLineString, MultiPolygon) - GEOSGeometry._GEOS_CLASSES = { - 0: Point, - 1: LineString, - 2: LinearRing, - 3: Polygon, - 4: MultiPoint, - 5: MultiLineString, - 6: MultiPolygon, - 7: GeometryCollection, - } - self.__class__ = GEOSGeometry._GEOS_CLASSES[self.geom_typeid] + if type(self) == GEOSGeometry: + if GEOSGeometry._GEOS_CLASSES is None: + # Lazy-loaded variable to avoid import conflicts with GEOSGeometry. + from .linestring import LineString, LinearRing + from .point import Point + from .polygon import Polygon + from .collections import ( + GeometryCollection, MultiPoint, MultiLineString, MultiPolygon, + ) + GEOSGeometry._GEOS_CLASSES = { + 0: Point, + 1: LineString, + 2: LinearRing, + 3: Polygon, + 4: MultiPoint, + 5: MultiLineString, + 6: MultiPolygon, + 7: GeometryCollection, + } + self.__class__ = GEOSGeometry._GEOS_CLASSES[self.geom_typeid] # Setting the coordinate sequence for the geometry (will be None on # geometries that do not have coordinate sequences) diff --git a/tests/gis_tests/geos_tests/test_geos.py b/tests/gis_tests/geos_tests/test_geos.py index 544bd3f48f..9bb55a1b4d 100644 --- a/tests/gis_tests/geos_tests/test_geos.py +++ b/tests/gis_tests/geos_tests/test_geos.py @@ -1308,6 +1308,25 @@ class GEOSTest(SimpleTestCase, TestDataMixin): self.assertEqual(args, (Point(0, 0), MultiPoint(Point(0, 0), Point(1, 1)), poly)) self.assertEqual(kwargs, {}) + def test_subclassing(self): + """ + GEOSGeometry subclass may itself be subclassed without being forced-cast + to the parent class during `__init__`. + """ + class ExtendedPolygon(Polygon): + def __init__(self, *args, **kwargs): + data = kwargs.pop('data', 0) + super(ExtendedPolygon, self).__init__(*args, **kwargs) + self._data = data + + def __str__(self): + return "EXT_POLYGON - data: %d - %s" % (self._data, self.wkt) + + ext_poly = ExtendedPolygon(((0, 0), (0, 1), (1, 1), (0, 0)), data=3) + self.assertEqual(type(ext_poly), ExtendedPolygon) + # ExtendedPolygon.__str__ should be called (instead of Polygon.__str__). + self.assertEqual(str(ext_poly), "EXT_POLYGON - data: 3 - POLYGON ((0 0, 0 1, 1 1, 0 0))") + def test_geos_version(self): """Testing the GEOS version regular expression.""" from django.contrib.gis.geos.libgeos import version_regex