diff --git a/django/contrib/gis/gdal/datasource.py b/django/contrib/gis/gdal/datasource.py index dfd043ab0c2..89007771173 100644 --- a/django/contrib/gis/gdal/datasource.py +++ b/django/contrib/gis/gdal/datasource.py @@ -33,7 +33,6 @@ # OFTReal returns floats, all else returns string. val = field.value """ -from ctypes import byref from pathlib import Path from django.contrib.gis.gdal.base import GDALBase @@ -54,21 +53,22 @@ class DataSource(GDALBase): def __init__(self, ds_input, ds_driver=False, write=False, encoding="utf-8"): # The write flag. - if write: - self._write = 1 - else: - self._write = 0 + self._write = capi.GDAL_OF_UPDATE if write else capi.GDAL_OF_READONLY # See also https://gdal.org/development/rfc/rfc23_ogr_unicode.html self.encoding = encoding Driver.ensure_registered() if isinstance(ds_input, (str, Path)): - # The data source driver is a void pointer. - ds_driver = Driver.ptr_type() try: - # OGROpen will auto-detect the data source type. - ds = capi.open_ds(force_bytes(ds_input), self._write, byref(ds_driver)) + # GDALOpenEx will auto-detect the data source type. + ds = capi.open_ds( + force_bytes(ds_input), + self._write | capi.GDAL_OF_VECTOR, + None, + None, + None, + ) except GDALException: # Making the error message more clear rather than something # like "Invalid pointer returned from OGROpen". @@ -82,7 +82,8 @@ class DataSource(GDALBase): if ds: self.ptr = ds - self.driver = Driver(ds_driver) + driver = capi.get_dataset_driver(ds) + self.driver = Driver(driver) else: # Raise an exception if the returned pointer is NULL raise GDALException('Invalid data source file "%s"' % ds_input) diff --git a/django/contrib/gis/gdal/driver.py b/django/contrib/gis/gdal/driver.py index 0ce7a2cdc87..1d4ca4c07a4 100644 --- a/django/contrib/gis/gdal/driver.py +++ b/django/contrib/gis/gdal/driver.py @@ -2,8 +2,7 @@ from ctypes import c_void_p from django.contrib.gis.gdal.base import GDALBase from django.contrib.gis.gdal.error import GDALException -from django.contrib.gis.gdal.prototypes import ds as vcapi -from django.contrib.gis.gdal.prototypes import raster as rcapi +from django.contrib.gis.gdal.prototypes import ds as capi from django.utils.encoding import force_bytes, force_str @@ -49,16 +48,10 @@ class Driver(GDALBase): name = dr_input # Attempting to get the GDAL/OGR driver by the string name. - for iface in (vcapi, rcapi): - driver = c_void_p(iface.get_driver_by_name(force_bytes(name))) - if driver: - break + driver = c_void_p(capi.get_driver_by_name(force_bytes(name))) elif isinstance(dr_input, int): self.ensure_registered() - for iface in (vcapi, rcapi): - driver = iface.get_driver(dr_input) - if driver: - break + driver = capi.get_driver(dr_input) elif isinstance(dr_input, c_void_p): driver = dr_input else: @@ -81,23 +74,21 @@ class Driver(GDALBase): """ Attempt to register all the data source drivers. """ - # Only register all if the driver counts are 0 (or else all drivers - # will be registered over and over again) - if not vcapi.get_driver_count(): - vcapi.register_all() - if not rcapi.get_driver_count(): - rcapi.register_all() + # Only register all if the driver count is 0 (or else all drivers will + # be registered over and over again). + if not capi.get_driver_count(): + capi.register_all() @classmethod def driver_count(cls): """ Return the number of GDAL/OGR data source drivers registered. """ - return vcapi.get_driver_count() + rcapi.get_driver_count() + return capi.get_driver_count() @property def name(self): """ Return description/name string for this driver. """ - return force_str(rcapi.get_driver_description(self.ptr)) + return force_str(capi.get_driver_description(self.ptr)) diff --git a/django/contrib/gis/gdal/prototypes/ds.py b/django/contrib/gis/gdal/prototypes/ds.py index bc5250e2db2..e3ef2699e9b 100644 --- a/django/contrib/gis/gdal/prototypes/ds.py +++ b/django/contrib/gis/gdal/prototypes/ds.py @@ -3,7 +3,7 @@ related data structures. OGR_Dr_*, OGR_DS_*, OGR_L_*, OGR_F_*, OGR_Fld_* routines are relevant here. """ -from ctypes import POINTER, c_char_p, c_double, c_int, c_long, c_void_p +from ctypes import POINTER, c_char_p, c_double, c_int, c_long, c_uint, c_void_p from django.contrib.gis.gdal.envelope import OGREnvelope from django.contrib.gis.gdal.libgdal import lgdal @@ -21,26 +21,36 @@ from django.contrib.gis.gdal.prototypes.generation import ( c_int_p = POINTER(c_int) # shortcut type +GDAL_OF_READONLY = 0x00 +GDAL_OF_UPDATE = 0x01 + +GDAL_OF_ALL = 0x00 +GDAL_OF_RASTER = 0x02 +GDAL_OF_VECTOR = 0x04 + # Driver Routines -register_all = void_output(lgdal.OGRRegisterAll, [], errcheck=False) -cleanup_all = void_output(lgdal.OGRCleanupAll, [], errcheck=False) -get_driver = voidptr_output(lgdal.OGRGetDriver, [c_int]) +register_all = void_output(lgdal.GDALAllRegister, [], errcheck=False) +cleanup_all = void_output(lgdal.GDALDestroyDriverManager, [], errcheck=False) +get_driver = voidptr_output(lgdal.GDALGetDriver, [c_int]) get_driver_by_name = voidptr_output( - lgdal.OGRGetDriverByName, [c_char_p], errcheck=False -) -get_driver_count = int_output(lgdal.OGRGetDriverCount, []) -get_driver_name = const_string_output( - lgdal.OGR_Dr_GetName, [c_void_p], decoding="ascii" + lgdal.GDALGetDriverByName, [c_char_p], errcheck=False ) +get_driver_count = int_output(lgdal.GDALGetDriverCount, []) +get_driver_description = const_string_output(lgdal.GDALGetDescription, [c_void_p]) # DataSource -open_ds = voidptr_output(lgdal.OGROpen, [c_char_p, c_int, POINTER(c_void_p)]) -destroy_ds = void_output(lgdal.OGR_DS_Destroy, [c_void_p], errcheck=False) -release_ds = void_output(lgdal.OGRReleaseDataSource, [c_void_p]) -get_ds_name = const_string_output(lgdal.OGR_DS_GetName, [c_void_p]) -get_layer = voidptr_output(lgdal.OGR_DS_GetLayer, [c_void_p, c_int]) -get_layer_by_name = voidptr_output(lgdal.OGR_DS_GetLayerByName, [c_void_p, c_char_p]) -get_layer_count = int_output(lgdal.OGR_DS_GetLayerCount, [c_void_p]) +open_ds = voidptr_output( + lgdal.GDALOpenEx, + [c_char_p, c_uint, POINTER(c_char_p), POINTER(c_char_p), POINTER(c_char_p)], +) +destroy_ds = void_output(lgdal.GDALClose, [c_void_p], errcheck=False) +get_ds_name = const_string_output(lgdal.GDALGetDescription, [c_void_p]) +get_dataset_driver = voidptr_output(lgdal.GDALGetDatasetDriver, [c_void_p]) +get_layer = voidptr_output(lgdal.GDALDatasetGetLayer, [c_void_p, c_int]) +get_layer_by_name = voidptr_output( + lgdal.GDALDatasetGetLayerByName, [c_void_p, c_char_p] +) +get_layer_count = int_output(lgdal.GDALDatasetGetLayerCount, [c_void_p]) # Layer Routines get_extent = void_output(lgdal.OGR_L_GetExtent, [c_void_p, POINTER(OGREnvelope), c_int]) diff --git a/django/contrib/gis/gdal/prototypes/raster.py b/django/contrib/gis/gdal/prototypes/raster.py index 59b930cb02d..17ee4a1926a 100644 --- a/django/contrib/gis/gdal/prototypes/raster.py +++ b/django/contrib/gis/gdal/prototypes/raster.py @@ -25,15 +25,6 @@ void_output = partial(void_output, cpl=True) const_string_output = partial(const_string_output, cpl=True) double_output = partial(double_output, cpl=True) -# Raster Driver Routines -register_all = void_output(std_call("GDALAllRegister"), [], errcheck=False) -get_driver = voidptr_output(std_call("GDALGetDriver"), [c_int]) -get_driver_by_name = voidptr_output( - std_call("GDALGetDriverByName"), [c_char_p], errcheck=False -) -get_driver_count = int_output(std_call("GDALGetDriverCount"), []) -get_driver_description = const_string_output(std_call("GDALGetDescription"), [c_void_p]) - # Raster Data Source Routines create_ds = voidptr_output( std_call("GDALCreate"), [c_void_p, c_char_p, c_int, c_int, c_int, c_int, c_void_p] diff --git a/tests/gis_tests/gdal_tests/test_driver.py b/tests/gis_tests/gdal_tests/test_driver.py index e7c03ae98d4..0d9423bc854 100644 --- a/tests/gis_tests/gdal_tests/test_driver.py +++ b/tests/gis_tests/gdal_tests/test_driver.py @@ -54,32 +54,21 @@ class DriverTest(unittest.TestCase): dr = Driver(alias) self.assertEqual(full_name, str(dr)) - @mock.patch("django.contrib.gis.gdal.driver.vcapi.get_driver_count") - @mock.patch("django.contrib.gis.gdal.driver.rcapi.get_driver_count") - @mock.patch("django.contrib.gis.gdal.driver.vcapi.register_all") - @mock.patch("django.contrib.gis.gdal.driver.rcapi.register_all") - def test_registered(self, rreg, vreg, rcount, vcount): + @mock.patch("django.contrib.gis.gdal.driver.capi.get_driver_count") + @mock.patch("django.contrib.gis.gdal.driver.capi.register_all") + def test_registered(self, reg, count): """ - Prototypes are registered only if their respective driver counts are - zero. + Prototypes are registered only if the driver count is zero. """ - def check(rcount_val, vcount_val): - vreg.reset_mock() - rreg.reset_mock() - rcount.return_value = rcount_val - vcount.return_value = vcount_val + def check(count_val): + reg.reset_mock() + count.return_value = count_val Driver.ensure_registered() - if rcount_val: - self.assertFalse(rreg.called) + if count_val: + self.assertFalse(reg.called) else: - rreg.assert_called_once_with() - if vcount_val: - self.assertFalse(vreg.called) - else: - vreg.assert_called_once_with() + reg.assert_called_once_with() - check(0, 0) - check(120, 0) - check(0, 120) - check(120, 120) + check(0) + check(120)