Fixed #34411 -- Updated GDAL API to handle vector DataSource's.

Co-authored-by: David Smith <smithdc@gmail.com>
This commit is contained in:
Claude Paroz 2023-03-13 18:38:23 +01:00 committed by Mariusz Felisiak
parent 77278929c8
commit 08306bad57
5 changed files with 58 additions and 76 deletions

View File

@ -33,7 +33,6 @@
# OFTReal returns floats, all else returns string.
val = field.value
"""
from ctypes import byref
from pathlib import Path
from django.contrib.gis.gdal.base import GDALBase
@ -54,21 +53,22 @@ class DataSource(GDALBase):
def __init__(self, ds_input, ds_driver=False, write=False, encoding="utf-8"):
# The write flag.
if write:
self._write = 1
else:
self._write = 0
self._write = capi.GDAL_OF_UPDATE if write else capi.GDAL_OF_READONLY
# See also https://gdal.org/development/rfc/rfc23_ogr_unicode.html
self.encoding = encoding
Driver.ensure_registered()
if isinstance(ds_input, (str, Path)):
# The data source driver is a void pointer.
ds_driver = Driver.ptr_type()
try:
# OGROpen will auto-detect the data source type.
ds = capi.open_ds(force_bytes(ds_input), self._write, byref(ds_driver))
# GDALOpenEx will auto-detect the data source type.
ds = capi.open_ds(
force_bytes(ds_input),
self._write | capi.GDAL_OF_VECTOR,
None,
None,
None,
)
except GDALException:
# Making the error message more clear rather than something
# like "Invalid pointer returned from OGROpen".
@ -82,7 +82,8 @@ class DataSource(GDALBase):
if ds:
self.ptr = ds
self.driver = Driver(ds_driver)
driver = capi.get_dataset_driver(ds)
self.driver = Driver(driver)
else:
# Raise an exception if the returned pointer is NULL
raise GDALException('Invalid data source file "%s"' % ds_input)

View File

@ -2,8 +2,7 @@ from ctypes import c_void_p
from django.contrib.gis.gdal.base import GDALBase
from django.contrib.gis.gdal.error import GDALException
from django.contrib.gis.gdal.prototypes import ds as vcapi
from django.contrib.gis.gdal.prototypes import raster as rcapi
from django.contrib.gis.gdal.prototypes import ds as capi
from django.utils.encoding import force_bytes, force_str
@ -49,16 +48,10 @@ class Driver(GDALBase):
name = dr_input
# Attempting to get the GDAL/OGR driver by the string name.
for iface in (vcapi, rcapi):
driver = c_void_p(iface.get_driver_by_name(force_bytes(name)))
if driver:
break
driver = c_void_p(capi.get_driver_by_name(force_bytes(name)))
elif isinstance(dr_input, int):
self.ensure_registered()
for iface in (vcapi, rcapi):
driver = iface.get_driver(dr_input)
if driver:
break
driver = capi.get_driver(dr_input)
elif isinstance(dr_input, c_void_p):
driver = dr_input
else:
@ -81,23 +74,21 @@ class Driver(GDALBase):
"""
Attempt to register all the data source drivers.
"""
# Only register all if the driver counts are 0 (or else all drivers
# will be registered over and over again)
if not vcapi.get_driver_count():
vcapi.register_all()
if not rcapi.get_driver_count():
rcapi.register_all()
# Only register all if the driver count is 0 (or else all drivers will
# be registered over and over again).
if not capi.get_driver_count():
capi.register_all()
@classmethod
def driver_count(cls):
"""
Return the number of GDAL/OGR data source drivers registered.
"""
return vcapi.get_driver_count() + rcapi.get_driver_count()
return capi.get_driver_count()
@property
def name(self):
"""
Return description/name string for this driver.
"""
return force_str(rcapi.get_driver_description(self.ptr))
return force_str(capi.get_driver_description(self.ptr))

View File

@ -3,7 +3,7 @@
related data structures. OGR_Dr_*, OGR_DS_*, OGR_L_*, OGR_F_*,
OGR_Fld_* routines are relevant here.
"""
from ctypes import POINTER, c_char_p, c_double, c_int, c_long, c_void_p
from ctypes import POINTER, c_char_p, c_double, c_int, c_long, c_uint, c_void_p
from django.contrib.gis.gdal.envelope import OGREnvelope
from django.contrib.gis.gdal.libgdal import lgdal
@ -21,26 +21,36 @@ from django.contrib.gis.gdal.prototypes.generation import (
c_int_p = POINTER(c_int) # shortcut type
GDAL_OF_READONLY = 0x00
GDAL_OF_UPDATE = 0x01
GDAL_OF_ALL = 0x00
GDAL_OF_RASTER = 0x02
GDAL_OF_VECTOR = 0x04
# Driver Routines
register_all = void_output(lgdal.OGRRegisterAll, [], errcheck=False)
cleanup_all = void_output(lgdal.OGRCleanupAll, [], errcheck=False)
get_driver = voidptr_output(lgdal.OGRGetDriver, [c_int])
register_all = void_output(lgdal.GDALAllRegister, [], errcheck=False)
cleanup_all = void_output(lgdal.GDALDestroyDriverManager, [], errcheck=False)
get_driver = voidptr_output(lgdal.GDALGetDriver, [c_int])
get_driver_by_name = voidptr_output(
lgdal.OGRGetDriverByName, [c_char_p], errcheck=False
)
get_driver_count = int_output(lgdal.OGRGetDriverCount, [])
get_driver_name = const_string_output(
lgdal.OGR_Dr_GetName, [c_void_p], decoding="ascii"
lgdal.GDALGetDriverByName, [c_char_p], errcheck=False
)
get_driver_count = int_output(lgdal.GDALGetDriverCount, [])
get_driver_description = const_string_output(lgdal.GDALGetDescription, [c_void_p])
# DataSource
open_ds = voidptr_output(lgdal.OGROpen, [c_char_p, c_int, POINTER(c_void_p)])
destroy_ds = void_output(lgdal.OGR_DS_Destroy, [c_void_p], errcheck=False)
release_ds = void_output(lgdal.OGRReleaseDataSource, [c_void_p])
get_ds_name = const_string_output(lgdal.OGR_DS_GetName, [c_void_p])
get_layer = voidptr_output(lgdal.OGR_DS_GetLayer, [c_void_p, c_int])
get_layer_by_name = voidptr_output(lgdal.OGR_DS_GetLayerByName, [c_void_p, c_char_p])
get_layer_count = int_output(lgdal.OGR_DS_GetLayerCount, [c_void_p])
open_ds = voidptr_output(
lgdal.GDALOpenEx,
[c_char_p, c_uint, POINTER(c_char_p), POINTER(c_char_p), POINTER(c_char_p)],
)
destroy_ds = void_output(lgdal.GDALClose, [c_void_p], errcheck=False)
get_ds_name = const_string_output(lgdal.GDALGetDescription, [c_void_p])
get_dataset_driver = voidptr_output(lgdal.GDALGetDatasetDriver, [c_void_p])
get_layer = voidptr_output(lgdal.GDALDatasetGetLayer, [c_void_p, c_int])
get_layer_by_name = voidptr_output(
lgdal.GDALDatasetGetLayerByName, [c_void_p, c_char_p]
)
get_layer_count = int_output(lgdal.GDALDatasetGetLayerCount, [c_void_p])
# Layer Routines
get_extent = void_output(lgdal.OGR_L_GetExtent, [c_void_p, POINTER(OGREnvelope), c_int])

View File

@ -25,15 +25,6 @@ void_output = partial(void_output, cpl=True)
const_string_output = partial(const_string_output, cpl=True)
double_output = partial(double_output, cpl=True)
# Raster Driver Routines
register_all = void_output(std_call("GDALAllRegister"), [], errcheck=False)
get_driver = voidptr_output(std_call("GDALGetDriver"), [c_int])
get_driver_by_name = voidptr_output(
std_call("GDALGetDriverByName"), [c_char_p], errcheck=False
)
get_driver_count = int_output(std_call("GDALGetDriverCount"), [])
get_driver_description = const_string_output(std_call("GDALGetDescription"), [c_void_p])
# Raster Data Source Routines
create_ds = voidptr_output(
std_call("GDALCreate"), [c_void_p, c_char_p, c_int, c_int, c_int, c_int, c_void_p]

View File

@ -54,32 +54,21 @@ class DriverTest(unittest.TestCase):
dr = Driver(alias)
self.assertEqual(full_name, str(dr))
@mock.patch("django.contrib.gis.gdal.driver.vcapi.get_driver_count")
@mock.patch("django.contrib.gis.gdal.driver.rcapi.get_driver_count")
@mock.patch("django.contrib.gis.gdal.driver.vcapi.register_all")
@mock.patch("django.contrib.gis.gdal.driver.rcapi.register_all")
def test_registered(self, rreg, vreg, rcount, vcount):
@mock.patch("django.contrib.gis.gdal.driver.capi.get_driver_count")
@mock.patch("django.contrib.gis.gdal.driver.capi.register_all")
def test_registered(self, reg, count):
"""
Prototypes are registered only if their respective driver counts are
zero.
Prototypes are registered only if the driver count is zero.
"""
def check(rcount_val, vcount_val):
vreg.reset_mock()
rreg.reset_mock()
rcount.return_value = rcount_val
vcount.return_value = vcount_val
def check(count_val):
reg.reset_mock()
count.return_value = count_val
Driver.ensure_registered()
if rcount_val:
self.assertFalse(rreg.called)
if count_val:
self.assertFalse(reg.called)
else:
rreg.assert_called_once_with()
if vcount_val:
self.assertFalse(vreg.called)
else:
vreg.assert_called_once_with()
reg.assert_called_once_with()
check(0, 0)
check(120, 0)
check(0, 120)
check(120, 120)
check(0)
check(120)