From 2ef4b4795e29be8c33a6de9cc0c05b59025d13a5 Mon Sep 17 00:00:00 2001 From: Sergey Fedoseev Date: Tue, 5 Sep 2017 18:54:57 +0500 Subject: [PATCH] Refs #28518 -- Improved performance of loading geometries from DB. --- .../gis/db/backends/mysql/operations.py | 9 +- .../gis/db/backends/oracle/operations.py | 9 +- .../gis/db/backends/postgis/operations.py | 8 +- .../gis/db/backends/spatialite/operations.py | 16 +- django/contrib/gis/db/models/aggregates.py | 12 +- django/contrib/gis/db/models/fields.py | 12 ++ django/contrib/gis/db/models/functions.py | 2 +- django/contrib/gis/geos/geometry.py | 194 +++++++++--------- django/contrib/gis/geos/linestring.py | 2 +- django/contrib/gis/geos/point.py | 2 +- 10 files changed, 156 insertions(+), 110 deletions(-) diff --git a/django/contrib/gis/db/backends/mysql/operations.py b/django/contrib/gis/db/backends/mysql/operations.py index fd3db2c7a45..12f9eff9559 100644 --- a/django/contrib/gis/db/backends/mysql/operations.py +++ b/django/contrib/gis/db/backends/mysql/operations.py @@ -4,7 +4,7 @@ from django.contrib.gis.db.backends.base.operations import ( ) from django.contrib.gis.db.backends.utils import SpatialOperator from django.contrib.gis.db.models import aggregates -from django.contrib.gis.geos import GEOSGeometry +from django.contrib.gis.geos.geometry import GEOSGeometryBase from django.contrib.gis.geos.prototypes.io import wkb_r from django.contrib.gis.measure import Distance from django.db.backends.mysql.operations import DatabaseOperations @@ -100,7 +100,12 @@ class MySQLOperations(BaseSpatialOperations, DatabaseOperations): srid = expression.output_field.srid if srid == -1: srid = None + geom_class = expression.output_field.geom_class def converter(value, expression, connection): - return None if value is None else GEOSGeometry(read(memoryview(value)), srid) + if value is not None: + geom = GEOSGeometryBase(read(memoryview(value)), geom_class) + if srid: + geom.srid = srid + return geom return converter diff --git a/django/contrib/gis/db/backends/oracle/operations.py b/django/contrib/gis/db/backends/oracle/operations.py index b23cfc3ab13..e69b8b8c071 100644 --- a/django/contrib/gis/db/backends/oracle/operations.py +++ b/django/contrib/gis/db/backends/oracle/operations.py @@ -16,7 +16,7 @@ from django.contrib.gis.db.backends.oracle.adapter import OracleSpatialAdapter from django.contrib.gis.db.backends.utils import SpatialOperator from django.contrib.gis.db.models import aggregates from django.contrib.gis.geometry.backend import Geometry -from django.contrib.gis.geos import GEOSGeometry +from django.contrib.gis.geos.geometry import GEOSGeometryBase from django.contrib.gis.geos.prototypes.io import wkb_r from django.contrib.gis.measure import Distance from django.db.backends.oracle.operations import DatabaseOperations @@ -203,7 +203,12 @@ class OracleOperations(BaseSpatialOperations, DatabaseOperations): srid = expression.output_field.srid if srid == -1: srid = None + geom_class = expression.output_field.geom_class def converter(value, expression, connection): - return None if value is None else GEOSGeometry(read(memoryview(value.read())), srid) + if value is not None: + geom = GEOSGeometryBase(read(memoryview(value.read())), geom_class) + if srid: + geom.srid = srid + return geom return converter diff --git a/django/contrib/gis/db/backends/postgis/operations.py b/django/contrib/gis/db/backends/postgis/operations.py index 9199d16e2b8..40860b3d0e1 100644 --- a/django/contrib/gis/db/backends/postgis/operations.py +++ b/django/contrib/gis/db/backends/postgis/operations.py @@ -7,7 +7,7 @@ from django.contrib.gis.db.backends.base.operations import ( from django.contrib.gis.db.backends.utils import SpatialOperator from django.contrib.gis.db.models import GeometryField, RasterField from django.contrib.gis.gdal import GDALRaster -from django.contrib.gis.geos import GEOSGeometry +from django.contrib.gis.geos.geometry import GEOSGeometryBase from django.contrib.gis.geos.prototypes.io import wkb_r from django.contrib.gis.measure import Distance from django.core.exceptions import ImproperlyConfigured @@ -392,4 +392,8 @@ class PostGISOperations(BaseSpatialOperations, DatabaseOperations): def get_geometry_converter(self, expression): read = wkb_r().read - return lambda value, expression, connection: None if value is None else GEOSGeometry(read(value)) + geom_class = expression.output_field.geom_class + + def converter(value, expression, connection): + return None if value is None else GEOSGeometryBase(read(value), geom_class) + return converter diff --git a/django/contrib/gis/db/backends/spatialite/operations.py b/django/contrib/gis/db/backends/spatialite/operations.py index 536834ec8de..69dc3b99f63 100644 --- a/django/contrib/gis/db/backends/spatialite/operations.py +++ b/django/contrib/gis/db/backends/spatialite/operations.py @@ -10,7 +10,7 @@ from django.contrib.gis.db.backends.spatialite.adapter import SpatiaLiteAdapter from django.contrib.gis.db.backends.utils import SpatialOperator from django.contrib.gis.db.models import aggregates from django.contrib.gis.geometry.backend import Geometry -from django.contrib.gis.geos import GEOSGeometry +from django.contrib.gis.geos.geometry import GEOSGeometryBase from django.contrib.gis.geos.prototypes.io import wkb_r, wkt_r from django.contrib.gis.measure import Distance from django.core.exceptions import ImproperlyConfigured @@ -199,12 +199,22 @@ class SpatiaLiteOperations(BaseSpatialOperations, DatabaseOperations): return SpatialiteSpatialRefSys def get_geometry_converter(self, expression): + geom_class = expression.output_field.geom_class if self.spatial_version >= (4, 3, 0): read = wkb_r().read - return lambda value, expression, connection: None if value is None else GEOSGeometry(read(value)) + + def converter(value, expression, connection): + return None if value is None else GEOSGeometryBase(read(value), geom_class) else: read = wkt_r().read srid = expression.output_field.srid if srid == -1: srid = None - return lambda value, expression, connection: None if value is None else GEOSGeometry(read(value), srid) + + def converter(value, expression, connection): + if value is not None: + geom = GEOSGeometryBase(read(value), geom_class) + if srid: + geom.srid = srid + return geom + return converter diff --git a/django/contrib/gis/db/models/aggregates.py b/django/contrib/gis/db/models/aggregates.py index 3b5a89cafd1..e61bb7207df 100644 --- a/django/contrib/gis/db/models/aggregates.py +++ b/django/contrib/gis/db/models/aggregates.py @@ -1,5 +1,8 @@ -from django.contrib.gis.db.models.fields import ExtentField +from django.contrib.gis.db.models.fields import ( + ExtentField, GeometryCollectionField, GeometryField, LineStringField, +) from django.db.models.aggregates import Aggregate +from django.utils.functional import cached_property __all__ = ['Collect', 'Extent', 'Extent3D', 'MakeLine', 'Union'] @@ -8,6 +11,10 @@ class GeoAggregate(Aggregate): function = None is_extent = False + @cached_property + def output_field(self): + return self.output_field_class(self.source_expressions[0].output_field.srid) + def as_sql(self, compiler, connection, function=None, **extra_context): # this will be called again in parent, but it's needed now - before # we get the spatial_aggregate_name @@ -34,6 +41,7 @@ class GeoAggregate(Aggregate): class Collect(GeoAggregate): name = 'Collect' + output_field_class = GeometryCollectionField class Extent(GeoAggregate): @@ -60,7 +68,9 @@ class Extent3D(GeoAggregate): class MakeLine(GeoAggregate): name = 'MakeLine' + output_field_class = LineStringField class Union(GeoAggregate): name = 'Union' + output_field_class = GeometryField diff --git a/django/contrib/gis/db/models/fields.py b/django/contrib/gis/db/models/fields.py index b571f9977a1..8265815c306 100644 --- a/django/contrib/gis/db/models/fields.py +++ b/django/contrib/gis/db/models/fields.py @@ -5,6 +5,10 @@ from django.contrib.gis import forms, gdal from django.contrib.gis.db.models.proxy import SpatialProxy from django.contrib.gis.gdal.error import GDALException from django.contrib.gis.geometry.backend import Geometry, GeometryException +from django.contrib.gis.geos import ( + GeometryCollection, LineString, MultiLineString, MultiPoint, MultiPolygon, + Point, Polygon, +) from django.core.exceptions import ImproperlyConfigured from django.db.models.fields import Field from django.utils.translation import gettext_lazy as _ @@ -196,6 +200,7 @@ class GeometryField(BaseSpatialField): form_class = forms.GeometryField # The OpenGIS Geometry name. geom_type = 'GEOMETRY' + geom_class = None def __init__(self, verbose_name=None, dim=2, geography=False, *, extent=(-180.0, -90.0, 180.0, 90.0), tolerance=0.05, **kwargs): @@ -268,42 +273,49 @@ class GeometryField(BaseSpatialField): # The OpenGIS Geometry Type Fields class PointField(GeometryField): geom_type = 'POINT' + geom_class = Point form_class = forms.PointField description = _("Point") class LineStringField(GeometryField): geom_type = 'LINESTRING' + geom_class = LineString form_class = forms.LineStringField description = _("Line string") class PolygonField(GeometryField): geom_type = 'POLYGON' + geom_class = Polygon form_class = forms.PolygonField description = _("Polygon") class MultiPointField(GeometryField): geom_type = 'MULTIPOINT' + geom_class = MultiPoint form_class = forms.MultiPointField description = _("Multi-point") class MultiLineStringField(GeometryField): geom_type = 'MULTILINESTRING' + geom_class = MultiLineString form_class = forms.MultiLineStringField description = _("Multi-line string") class MultiPolygonField(GeometryField): geom_type = 'MULTIPOLYGON' + geom_class = MultiPolygon form_class = forms.MultiPolygonField description = _("Multi polygon") class GeometryCollectionField(GeometryField): geom_type = 'GEOMETRYCOLLECTION' + geom_class = GeometryCollection form_class = forms.GeometryCollectionField description = _("Geometry collection") diff --git a/django/contrib/gis/db/models/functions.py b/django/contrib/gis/db/models/functions.py index 510d54408ca..e166c21e295 100644 --- a/django/contrib/gis/db/models/functions.py +++ b/django/contrib/gis/db/models/functions.py @@ -97,7 +97,7 @@ class GeoFunc(GeoFuncMixin, Func): class GeomOutputGeoFunc(GeoFunc): @cached_property def output_field(self): - return self.geo_field + return GeometryField(srid=self.geo_field.srid) class SQLiteDecimalToFloatMixin: diff --git a/django/contrib/gis/geos/geometry.py b/django/contrib/gis/geos/geometry.py index 9a8f8ed71f6..325b66f6b8b 100644 --- a/django/contrib/gis/geos/geometry.py +++ b/django/contrib/gis/geos/geometry.py @@ -21,9 +21,7 @@ from django.utils.deconstruct import deconstructible from django.utils.encoding import force_bytes, force_text -@deconstructible -class GEOSGeometry(GEOSBase, ListMixin): - "A class that, generally, encapsulates a GEOS geometry." +class GEOSGeometryBase(GEOSBase): _GEOS_CLASSES = None @@ -31,96 +29,39 @@ class GEOSGeometry(GEOSBase, ListMixin): destructor = capi.destroy_geom has_cs = False # Only Point, LineString, LinearRing have coordinate sequences - def __init__(self, geo_input, srid=None): - """ - The base constructor for GEOS geometry objects, and may take the - following inputs: - - * strings: - - WKT - - HEXEWKB (a PostGIS-specific canonical form) - - GeoJSON (requires GDAL) - * buffer: - - WKB - - The `srid` keyword is used to specify the Source Reference Identifier - (SRID) number for this Geometry. If not set, the SRID will be None. - """ - input_srid = None - if isinstance(geo_input, bytes): - geo_input = force_text(geo_input) - if isinstance(geo_input, str): - wkt_m = wkt_regex.match(geo_input) - if wkt_m: - # Handling WKT input. - if wkt_m.group('srid'): - input_srid = int(wkt_m.group('srid')) - g = self._from_wkt(force_bytes(wkt_m.group('wkt'))) - elif hex_regex.match(geo_input): - # Handling HEXEWKB input. - g = wkb_r().read(force_bytes(geo_input)) - elif json_regex.match(geo_input): - # Handling GeoJSON input. - ogr = gdal.OGRGeometry.from_json(geo_input) - g = ogr._geos_ptr() - input_srid = ogr.srid - else: - raise ValueError('String input unrecognized as WKT EWKT, and HEXEWKB.') - elif isinstance(geo_input, GEOM_PTR): - # When the input is a pointer to a geometry (GEOM_PTR). - g = geo_input - elif isinstance(geo_input, memoryview): - # When the input is a buffer (WKB). - g = wkb_r().read(geo_input) - elif isinstance(geo_input, GEOSGeometry): - g = capi.geom_clone(geo_input.ptr) - else: - # Invalid geometry type. - raise TypeError('Improper geometry input type: %s' % type(geo_input)) - - if not g: - raise GEOSException('Could not initialize GEOS Geometry with given input.') - - input_srid = input_srid or capi.geos_get_srid(g) or None - if input_srid and srid and input_srid != srid: - raise ValueError('Input geometry already has SRID: %d.' % input_srid) - - # Setting the pointer object with a valid pointer. - self.ptr = g - # Post-initialization setup. - self._post_init(input_srid or srid) - - def _post_init(self, srid): - "Perform post-initialization setup." - # Setting the SRID, if given. - if srid and isinstance(srid, int): - self.srid = srid + def __init__(self, ptr, cls): + self._ptr = ptr # Setting the class type (e.g., Point, Polygon, etc.) - if type(self) == GEOSGeometry: - if GEOSGeometry._GEOS_CLASSES is None: - # Lazy-loaded variable to avoid import conflicts with GEOSGeometry. - from .linestring import LineString, LinearRing - from .point import Point - from .polygon import Polygon - from .collections import ( - GeometryCollection, MultiPoint, MultiLineString, MultiPolygon, - ) - GEOSGeometry._GEOS_CLASSES = { - 0: Point, - 1: LineString, - 2: LinearRing, - 3: Polygon, - 4: MultiPoint, - 5: MultiLineString, - 6: MultiPolygon, - 7: GeometryCollection, - } - self.__class__ = GEOSGeometry._GEOS_CLASSES[self.geom_typeid] + if type(self) in (GEOSGeometryBase, GEOSGeometry): + if cls is None: + if GEOSGeometryBase._GEOS_CLASSES is None: + # Inner imports avoid import conflicts with GEOSGeometry. + from .linestring import LineString, LinearRing + from .point import Point + from .polygon import Polygon + from .collections import ( + GeometryCollection, MultiPoint, MultiLineString, MultiPolygon, + ) + GEOSGeometryBase._GEOS_CLASSES = { + 0: Point, + 1: LineString, + 2: LinearRing, + 3: Polygon, + 4: MultiPoint, + 5: MultiLineString, + 6: MultiPolygon, + 7: GeometryCollection, + } + cls = GEOSGeometryBase._GEOS_CLASSES[self.geom_typeid] + self.__class__ = cls + self._post_init() + def _post_init(self): + "Perform post-initialization setup." # Setting the coordinate sequence for the geometry (will be None on # geometries that do not have coordinate sequences) - self._set_cs() + self._cs = GEOSCoordSeq(capi.get_cs(self.ptr), self.hasz) if self.has_cs else None def __copy__(self): """ @@ -158,7 +99,8 @@ class GEOSGeometry(GEOSBase, ListMixin): if not ptr: raise GEOSException('Invalid Geometry loaded from pickled state.') self.ptr = ptr - self._post_init(srid) + self._post_init() + self.srid = srid @classmethod def _from_wkb(cls, wkb): @@ -226,13 +168,6 @@ class GEOSGeometry(GEOSBase, ListMixin): return self.sym_difference(other) # #### Coordinate Sequence Routines #### - def _set_cs(self): - "Set the coordinate sequence for this Geometry." - if self.has_cs: - self._cs = GEOSCoordSeq(capi.get_cs(self.ptr), self.hasz) - else: - self._cs = None - @property def coord_seq(self): "Return a clone of the coordinate sequence for this Geometry." @@ -536,7 +471,8 @@ class GEOSGeometry(GEOSBase, ListMixin): # again due to the reassignment. capi.destroy_geom(self.ptr) self.ptr = ptr - self._post_init(g.srid) + self._post_init() + self.srid = g.srid else: raise GEOSException('Transformed WKB was invalid.') @@ -715,3 +651,67 @@ class LinearGeometryMixin: Return whether or not this Geometry is closed. """ return capi.geos_isclosed(self.ptr) + + +@deconstructible +class GEOSGeometry(GEOSGeometryBase, ListMixin): + "A class that, generally, encapsulates a GEOS geometry." + + def __init__(self, geo_input, srid=None): + """ + The base constructor for GEOS geometry objects. It may take the + following inputs: + + * strings: + - WKT + - HEXEWKB (a PostGIS-specific canonical form) + - GeoJSON (requires GDAL) + * buffer: + - WKB + + The `srid` keyword specifies the Source Reference Identifier (SRID) + number for this Geometry. If not provided, it defaults to None. + """ + input_srid = None + if isinstance(geo_input, bytes): + geo_input = force_text(geo_input) + if isinstance(geo_input, str): + wkt_m = wkt_regex.match(geo_input) + if wkt_m: + # Handle WKT input. + if wkt_m.group('srid'): + input_srid = int(wkt_m.group('srid')) + g = self._from_wkt(force_bytes(wkt_m.group('wkt'))) + elif hex_regex.match(geo_input): + # Handle HEXEWKB input. + g = wkb_r().read(force_bytes(geo_input)) + elif json_regex.match(geo_input): + # Handle GeoJSON input. + ogr = gdal.OGRGeometry.from_json(geo_input) + g = ogr._geos_ptr() + input_srid = ogr.srid + else: + raise ValueError('String input unrecognized as WKT EWKT, and HEXEWKB.') + elif isinstance(geo_input, GEOM_PTR): + # When the input is a pointer to a geometry (GEOM_PTR). + g = geo_input + elif isinstance(geo_input, memoryview): + # When the input is a buffer (WKB). + g = wkb_r().read(geo_input) + elif isinstance(geo_input, GEOSGeometry): + g = capi.geom_clone(geo_input.ptr) + else: + raise TypeError('Improper geometry input type: %s' % type(geo_input)) + + if not g: + raise GEOSException('Could not initialize GEOS Geometry with given input.') + + input_srid = input_srid or capi.geos_get_srid(g) or None + if input_srid and srid and input_srid != srid: + raise ValueError('Input geometry already has SRID: %d.' % input_srid) + + super().__init__(g, None) + # Set the SRID, if given. + srid = input_srid or srid + if srid and isinstance(srid, int): + self.srid = srid diff --git a/django/contrib/gis/geos/linestring.py b/django/contrib/gis/geos/linestring.py index 2b909af7b82..54808a7d3b5 100644 --- a/django/contrib/gis/geos/linestring.py +++ b/django/contrib/gis/geos/linestring.py @@ -116,7 +116,7 @@ class LineString(LinearGeometryMixin, GEOSGeometry): if ptr: capi.destroy_geom(self.ptr) self.ptr = ptr - self._post_init(self.srid) + self._post_init() else: # can this happen? raise GEOSException('Geometry resulting from slice deletion was invalid.') diff --git a/django/contrib/gis/geos/point.py b/django/contrib/gis/geos/point.py index d46844248c3..ccf5b9dbaf2 100644 --- a/django/contrib/gis/geos/point.py +++ b/django/contrib/gis/geos/point.py @@ -72,7 +72,7 @@ class Point(GEOSGeometry): if ptr: capi.destroy_geom(self.ptr) self._ptr = ptr - self._set_cs() + self._post_init() else: # can this happen? raise GEOSException('Geometry resulting from slice deletion was invalid.')