diff --git a/django/contrib/gis/geos/libgeos.py b/django/contrib/gis/geos/libgeos.py index 78665894bce..df7e2aa5d58 100644 --- a/django/contrib/gis/geos/libgeos.py +++ b/django/contrib/gis/geos/libgeos.py @@ -9,13 +9,12 @@ import logging import os import re -import threading from ctypes import CDLL, CFUNCTYPE, POINTER, Structure, c_char_p from ctypes.util import find_library from django.contrib.gis.geos.error import GEOSException from django.core.exceptions import ImproperlyConfigured -from django.utils.functional import SimpleLazyObject, cached_property +from django.utils.functional import SimpleLazyObject logger = logging.getLogger('django.contrib.gis') @@ -64,11 +63,10 @@ def load_geos(): _lgeos = CDLL(lib_path) # Here we set up the prototypes for the initGEOS_r and finishGEOS_r # routines. These functions aren't actually called until they are - # attached to a GEOS context handle. + # attached to a GEOS context handle -- this actually occurs in + # geos/prototypes/threadsafe.py. _lgeos.initGEOS_r.restype = CONTEXT_PTR _lgeos.finishGEOS_r.argtypes = [CONTEXT_PTR] - # Ensures compatibility across 32 and 64-bit platforms. - _lgeos.GEOSversion.restype = c_char_p return _lgeos @@ -136,27 +134,6 @@ def get_pointer_arr(n): lgeos = SimpleLazyObject(load_geos) -class GEOSContextHandle(object): - def __init__(self): - # Initializing the context handle for this thread with - # the notice and error handler. - self.ptr = lgeos.initGEOS_r(notice_h, error_h) - - def __del__(self): - if self.ptr and lgeos: - lgeos.finishGEOS_r(self.ptr) - - -class GEOSContext(threading.local): - - @cached_property - def ptr(self): - # Assign handle so it will will garbage collected when - # thread is finished. - self.handle = GEOSContextHandle() - return self.handle.ptr - - class GEOSFuncFactory(object): """ Lazy loading of GEOS functions. @@ -164,7 +141,6 @@ class GEOSFuncFactory(object): argtypes = None restype = None errcheck = None - thread_context = GEOSContext() def __init__(self, func_name, *args, **kwargs): self.func_name = func_name @@ -178,23 +154,21 @@ class GEOSFuncFactory(object): def __call__(self, *args, **kwargs): if self.func is None: self.func = self.get_func(*self.args, **self.kwargs) - # Call the threaded GEOS routine with pointer of the context handle - # as the first argument. - return self.func(self.thread_context.ptr, *args) + return self.func(*args, **kwargs) def get_func(self, *args, **kwargs): - # GEOS thread-safe function signatures end with '_r' and - # take an additional context handle parameter. - func = getattr(lgeos, self.func_name + '_r') - func.argtypes = [CONTEXT_PTR] + (self.argtypes or []) + from django.contrib.gis.geos.prototypes.threadsafe import GEOSFunc + func = GEOSFunc(self.func_name) + func.argtypes = self.argtypes or [] func.restype = self.restype if self.errcheck: func.errcheck = self.errcheck return func -# Returns the string version of the GEOS library. -geos_version = lambda: lgeos.GEOSversion() +# Returns the string version of the GEOS library. Have to set the restype +# explicitly to c_char_p to ensure compatibility across 32 and 64-bit platforms. +geos_version = GEOSFuncFactory('GEOSversion', restype=c_char_p) # Regular expression should be able to parse version strings such as # '3.0.0rc4-CAPI-1.3.3', '3.0.0-CAPI-1.4.1', '3.4.0dev-CAPI-1.8.0' or '3.4.0dev-CAPI-1.8.0 r0' diff --git a/django/contrib/gis/geos/prototypes/threadsafe.py b/django/contrib/gis/geos/prototypes/threadsafe.py new file mode 100644 index 00000000000..45c87d60045 --- /dev/null +++ b/django/contrib/gis/geos/prototypes/threadsafe.py @@ -0,0 +1,93 @@ +import threading + +from django.contrib.gis.geos.libgeos import ( + CONTEXT_PTR, error_h, lgeos, notice_h, +) + + +class GEOSContextHandle(object): + """ + Python object representing a GEOS context handle. + """ + def __init__(self): + # Initializing the context handler for this thread with + # the notice and error handler. + self.ptr = lgeos.initGEOS_r(notice_h, error_h) + + def __del__(self): + if self.ptr and lgeos: + lgeos.finishGEOS_r(self.ptr) + + +# Defining a thread-local object and creating an instance +# to hold a reference to GEOSContextHandle for this thread. +class GEOSContext(threading.local): + handle = None + +thread_context = GEOSContext() + + +class GEOSFunc(object): + """ + Class that serves as a wrapper for GEOS C Functions, and will + use thread-safe function variants when available. + """ + def __init__(self, func_name): + try: + # GEOS thread-safe function signatures end with '_r', and + # take an additional context handle parameter. + self.cfunc = getattr(lgeos, func_name + '_r') + self.threaded = True + # Create a reference here to thread_context so it's not + # garbage-collected before an attempt to call this object. + self.thread_context = thread_context + except AttributeError: + # Otherwise, use usual function. + self.cfunc = getattr(lgeos, func_name) + self.threaded = False + + def __call__(self, *args): + if self.threaded: + # If a context handle does not exist for this thread, initialize one. + if not self.thread_context.handle: + self.thread_context.handle = GEOSContextHandle() + # Call the threaded GEOS routine with pointer of the context handle + # as the first argument. + return self.cfunc(self.thread_context.handle.ptr, *args) + else: + return self.cfunc(*args) + + def __str__(self): + return self.cfunc.__name__ + + # argtypes property + def _get_argtypes(self): + return self.cfunc.argtypes + + def _set_argtypes(self, argtypes): + if self.threaded: + new_argtypes = [CONTEXT_PTR] + new_argtypes.extend(argtypes) + self.cfunc.argtypes = new_argtypes + else: + self.cfunc.argtypes = argtypes + + argtypes = property(_get_argtypes, _set_argtypes) + + # restype property + def _get_restype(self): + return self.cfunc.restype + + def _set_restype(self, restype): + self.cfunc.restype = restype + + restype = property(_get_restype, _set_restype) + + # errcheck property + def _get_errcheck(self): + return self.cfunc.errcheck + + def _set_errcheck(self, errcheck): + self.cfunc.errcheck = errcheck + + errcheck = property(_get_errcheck, _set_errcheck) diff --git a/tests/gis_tests/geos_tests/test_geos.py b/tests/gis_tests/geos_tests/test_geos.py index 8c67139e92f..116e8d35f8c 100644 --- a/tests/gis_tests/geos_tests/test_geos.py +++ b/tests/gis_tests/geos_tests/test_geos.py @@ -3,7 +3,6 @@ from __future__ import unicode_literals import ctypes import json import random -import threading from binascii import a2b_hex, b2a_hex from io import BytesIO from unittest import skipUnless @@ -13,7 +12,7 @@ from django.contrib.gis.gdal import HAS_GDAL from django.contrib.gis.geos import ( HAS_GEOS, GeometryCollection, GEOSException, GEOSGeometry, LinearRing, LineString, MultiLineString, MultiPoint, MultiPolygon, Point, Polygon, - fromfile, fromstr, libgeos, + fromfile, fromstr, ) from django.contrib.gis.geos.base import GEOSBase from django.contrib.gis.geos.libgeos import geos_version_info @@ -1233,48 +1232,6 @@ class GEOSTest(SimpleTestCase, TestDataMixin): self.assertEqual(m.group('version'), v_geos) self.assertEqual(m.group('capi_version'), v_capi) - def test_geos_threads(self): - pnt = Point() - context_ptrs = [] - - geos_init = libgeos.lgeos.initGEOS_r - geos_finish = libgeos.lgeos.finishGEOS_r - - def init(*args, **kwargs): - result = geos_init(*args, **kwargs) - context_ptrs.append(result) - return result - - def finish(*args, **kwargs): - result = geos_finish(*args, **kwargs) - destructor_called.set() - return result - - for i in range(2): - destructor_called = threading.Event() - patch_path = 'django.contrib.gis.geos.libgeos.lgeos' - with mock.patch.multiple(patch_path, initGEOS_r=mock.DEFAULT, finishGEOS_r=mock.DEFAULT) as mocked: - mocked['initGEOS_r'].side_effect = init - mocked['finishGEOS_r'].side_effect = finish - with mock.patch('django.contrib.gis.geos.prototypes.predicates.geos_hasz.func') as mocked_hasz: - thread = threading.Thread(target=lambda: pnt.hasz) - thread.start() - thread.join() - - # We can't be sure that members of thread locals are - # garbage collected right after `thread.join()` so - # we must wait until destructor is actually called. - # Fail if destructor wasn't called within a second. - self.assertTrue(destructor_called.wait(1)) - - context_ptr = context_ptrs[i] - self.assertIsInstance(context_ptr, libgeos.CONTEXT_PTR) - mocked_hasz.assert_called_once_with(context_ptr, pnt.ptr) - mocked['finishGEOS_r'].assert_called_once_with(context_ptr) - - # Check that different contexts were used for the different threads. - self.assertNotEqual(context_ptrs[0], context_ptrs[1]) - @ignore_warnings(category=RemovedInDjango20Warning) def test_deprecated_srid_getters_setters(self): p = Point(1, 2, srid=123)