Fixed #25961 -- Removed handling of thread-non-safe GEOS functions.

This commit is contained in:
Sergey Fedoseev 2015-12-23 14:35:17 +05:00 committed by Tim Graham
parent febe1321da
commit 312fc1af7b
3 changed files with 80 additions and 104 deletions

View File

@ -9,12 +9,13 @@
import logging import logging
import os import os
import re import re
import threading
from ctypes import CDLL, CFUNCTYPE, POINTER, Structure, c_char_p from ctypes import CDLL, CFUNCTYPE, POINTER, Structure, c_char_p
from ctypes.util import find_library from ctypes.util import find_library
from django.contrib.gis.geos.error import GEOSException from django.contrib.gis.geos.error import GEOSException
from django.core.exceptions import ImproperlyConfigured from django.core.exceptions import ImproperlyConfigured
from django.utils.functional import SimpleLazyObject from django.utils.functional import SimpleLazyObject, cached_property
logger = logging.getLogger('django.contrib.gis') logger = logging.getLogger('django.contrib.gis')
@ -63,10 +64,11 @@ def load_geos():
_lgeos = CDLL(lib_path) _lgeos = CDLL(lib_path)
# Here we set up the prototypes for the initGEOS_r and finishGEOS_r # Here we set up the prototypes for the initGEOS_r and finishGEOS_r
# routines. These functions aren't actually called until they are # routines. These functions aren't actually called until they are
# attached to a GEOS context handle -- this actually occurs in # attached to a GEOS context handle.
# geos/prototypes/threadsafe.py.
_lgeos.initGEOS_r.restype = CONTEXT_PTR _lgeos.initGEOS_r.restype = CONTEXT_PTR
_lgeos.finishGEOS_r.argtypes = [CONTEXT_PTR] _lgeos.finishGEOS_r.argtypes = [CONTEXT_PTR]
# Ensures compatibility across 32 and 64-bit platforms.
_lgeos.GEOSversion.restype = c_char_p
return _lgeos return _lgeos
@ -134,6 +136,27 @@ def get_pointer_arr(n):
lgeos = SimpleLazyObject(load_geos) lgeos = SimpleLazyObject(load_geos)
class GEOSContextHandle(object):
def __init__(self):
# Initializing the context handle for this thread with
# the notice and error handler.
self.ptr = lgeos.initGEOS_r(notice_h, error_h)
def __del__(self):
if self.ptr and lgeos:
lgeos.finishGEOS_r(self.ptr)
class GEOSContext(threading.local):
@cached_property
def ptr(self):
# Assign handle so it will will garbage collected when
# thread is finished.
self.handle = GEOSContextHandle()
return self.handle.ptr
class GEOSFuncFactory(object): class GEOSFuncFactory(object):
""" """
Lazy loading of GEOS functions. Lazy loading of GEOS functions.
@ -141,6 +164,7 @@ class GEOSFuncFactory(object):
argtypes = None argtypes = None
restype = None restype = None
errcheck = None errcheck = None
thread_context = GEOSContext()
def __init__(self, func_name, *args, **kwargs): def __init__(self, func_name, *args, **kwargs):
self.func_name = func_name self.func_name = func_name
@ -154,21 +178,23 @@ class GEOSFuncFactory(object):
def __call__(self, *args, **kwargs): def __call__(self, *args, **kwargs):
if self.func is None: if self.func is None:
self.func = self.get_func(*self.args, **self.kwargs) self.func = self.get_func(*self.args, **self.kwargs)
return self.func(*args, **kwargs) # Call the threaded GEOS routine with pointer of the context handle
# as the first argument.
return self.func(self.thread_context.ptr, *args)
def get_func(self, *args, **kwargs): def get_func(self, *args, **kwargs):
from django.contrib.gis.geos.prototypes.threadsafe import GEOSFunc # GEOS thread-safe function signatures end with '_r' and
func = GEOSFunc(self.func_name) # take an additional context handle parameter.
func.argtypes = self.argtypes or [] func = getattr(lgeos, self.func_name + '_r')
func.argtypes = [CONTEXT_PTR] + (self.argtypes or [])
func.restype = self.restype func.restype = self.restype
if self.errcheck: if self.errcheck:
func.errcheck = self.errcheck func.errcheck = self.errcheck
return func return func
# Returns the string version of the GEOS library. Have to set the restype # Returns the string version of the GEOS library.
# explicitly to c_char_p to ensure compatibility across 32 and 64-bit platforms. geos_version = lambda: lgeos.GEOSversion()
geos_version = GEOSFuncFactory('GEOSversion', restype=c_char_p)
# Regular expression should be able to parse version strings such as # Regular expression should be able to parse version strings such as
# '3.0.0rc4-CAPI-1.3.3', '3.0.0-CAPI-1.4.1', '3.4.0dev-CAPI-1.8.0' or '3.4.0dev-CAPI-1.8.0 r0' # '3.0.0rc4-CAPI-1.3.3', '3.0.0-CAPI-1.4.1', '3.4.0dev-CAPI-1.8.0' or '3.4.0dev-CAPI-1.8.0 r0'

View File

@ -1,93 +0,0 @@
import threading
from django.contrib.gis.geos.libgeos import (
CONTEXT_PTR, error_h, lgeos, notice_h,
)
class GEOSContextHandle(object):
"""
Python object representing a GEOS context handle.
"""
def __init__(self):
# Initializing the context handler for this thread with
# the notice and error handler.
self.ptr = lgeos.initGEOS_r(notice_h, error_h)
def __del__(self):
if self.ptr and lgeos:
lgeos.finishGEOS_r(self.ptr)
# Defining a thread-local object and creating an instance
# to hold a reference to GEOSContextHandle for this thread.
class GEOSContext(threading.local):
handle = None
thread_context = GEOSContext()
class GEOSFunc(object):
"""
Class that serves as a wrapper for GEOS C Functions, and will
use thread-safe function variants when available.
"""
def __init__(self, func_name):
try:
# GEOS thread-safe function signatures end with '_r', and
# take an additional context handle parameter.
self.cfunc = getattr(lgeos, func_name + '_r')
self.threaded = True
# Create a reference here to thread_context so it's not
# garbage-collected before an attempt to call this object.
self.thread_context = thread_context
except AttributeError:
# Otherwise, use usual function.
self.cfunc = getattr(lgeos, func_name)
self.threaded = False
def __call__(self, *args):
if self.threaded:
# If a context handle does not exist for this thread, initialize one.
if not self.thread_context.handle:
self.thread_context.handle = GEOSContextHandle()
# Call the threaded GEOS routine with pointer of the context handle
# as the first argument.
return self.cfunc(self.thread_context.handle.ptr, *args)
else:
return self.cfunc(*args)
def __str__(self):
return self.cfunc.__name__
# argtypes property
def _get_argtypes(self):
return self.cfunc.argtypes
def _set_argtypes(self, argtypes):
if self.threaded:
new_argtypes = [CONTEXT_PTR]
new_argtypes.extend(argtypes)
self.cfunc.argtypes = new_argtypes
else:
self.cfunc.argtypes = argtypes
argtypes = property(_get_argtypes, _set_argtypes)
# restype property
def _get_restype(self):
return self.cfunc.restype
def _set_restype(self, restype):
self.cfunc.restype = restype
restype = property(_get_restype, _set_restype)
# errcheck property
def _get_errcheck(self):
return self.cfunc.errcheck
def _set_errcheck(self, errcheck):
self.cfunc.errcheck = errcheck
errcheck = property(_get_errcheck, _set_errcheck)

View File

@ -3,6 +3,7 @@ from __future__ import unicode_literals
import ctypes import ctypes
import json import json
import random import random
import threading
from binascii import a2b_hex, b2a_hex from binascii import a2b_hex, b2a_hex
from io import BytesIO from io import BytesIO
from unittest import skipUnless from unittest import skipUnless
@ -12,7 +13,7 @@ from django.contrib.gis.gdal import HAS_GDAL
from django.contrib.gis.geos import ( from django.contrib.gis.geos import (
HAS_GEOS, GeometryCollection, GEOSException, GEOSGeometry, LinearRing, HAS_GEOS, GeometryCollection, GEOSException, GEOSGeometry, LinearRing,
LineString, MultiLineString, MultiPoint, MultiPolygon, Point, Polygon, LineString, MultiLineString, MultiPoint, MultiPolygon, Point, Polygon,
fromfile, fromstr, fromfile, fromstr, libgeos,
) )
from django.contrib.gis.geos.base import GEOSBase from django.contrib.gis.geos.base import GEOSBase
from django.contrib.gis.geos.libgeos import geos_version_info from django.contrib.gis.geos.libgeos import geos_version_info
@ -1232,6 +1233,48 @@ class GEOSTest(SimpleTestCase, TestDataMixin):
self.assertEqual(m.group('version'), v_geos) self.assertEqual(m.group('version'), v_geos)
self.assertEqual(m.group('capi_version'), v_capi) self.assertEqual(m.group('capi_version'), v_capi)
def test_geos_threads(self):
pnt = Point()
context_ptrs = []
geos_init = libgeos.lgeos.initGEOS_r
geos_finish = libgeos.lgeos.finishGEOS_r
def init(*args, **kwargs):
result = geos_init(*args, **kwargs)
context_ptrs.append(result)
return result
def finish(*args, **kwargs):
result = geos_finish(*args, **kwargs)
destructor_called.set()
return result
for i in range(2):
destructor_called = threading.Event()
patch_path = 'django.contrib.gis.geos.libgeos.lgeos'
with mock.patch.multiple(patch_path, initGEOS_r=mock.DEFAULT, finishGEOS_r=mock.DEFAULT) as mocked:
mocked['initGEOS_r'].side_effect = init
mocked['finishGEOS_r'].side_effect = finish
with mock.patch('django.contrib.gis.geos.prototypes.predicates.geos_hasz.func') as mocked_hasz:
thread = threading.Thread(target=lambda: pnt.hasz)
thread.start()
thread.join()
# We can't be sure that members of thread locals are
# garbage collected right after `thread.join()` so
# we must wait until destructor is actually called.
# Fail if destructor wasn't called within a second.
self.assertTrue(destructor_called.wait(1))
context_ptr = context_ptrs[i]
self.assertIsInstance(context_ptr, libgeos.CONTEXT_PTR)
mocked_hasz.assert_called_once_with(context_ptr, pnt.ptr)
mocked['finishGEOS_r'].assert_called_once_with(context_ptr)
# Check that different contexts were used for the different threads.
self.assertNotEqual(context_ptrs[0], context_ptrs[1])
@ignore_warnings(category=RemovedInDjango20Warning) @ignore_warnings(category=RemovedInDjango20Warning)
def test_deprecated_srid_getters_setters(self): def test_deprecated_srid_getters_setters(self):
p = Point(1, 2, srid=123) p = Point(1, 2, srid=123)