mirror of https://github.com/django/django.git
Fixed #34411 -- Updated GDAL API to handle vector DataSource's.
Co-authored-by: David Smith <smithdc@gmail.com>
This commit is contained in:
parent
77278929c8
commit
08306bad57
|
@ -33,7 +33,6 @@
|
|||
# OFTReal returns floats, all else returns string.
|
||||
val = field.value
|
||||
"""
|
||||
from ctypes import byref
|
||||
from pathlib import Path
|
||||
|
||||
from django.contrib.gis.gdal.base import GDALBase
|
||||
|
@ -54,21 +53,22 @@ class DataSource(GDALBase):
|
|||
|
||||
def __init__(self, ds_input, ds_driver=False, write=False, encoding="utf-8"):
|
||||
# The write flag.
|
||||
if write:
|
||||
self._write = 1
|
||||
else:
|
||||
self._write = 0
|
||||
self._write = capi.GDAL_OF_UPDATE if write else capi.GDAL_OF_READONLY
|
||||
# See also https://gdal.org/development/rfc/rfc23_ogr_unicode.html
|
||||
self.encoding = encoding
|
||||
|
||||
Driver.ensure_registered()
|
||||
|
||||
if isinstance(ds_input, (str, Path)):
|
||||
# The data source driver is a void pointer.
|
||||
ds_driver = Driver.ptr_type()
|
||||
try:
|
||||
# OGROpen will auto-detect the data source type.
|
||||
ds = capi.open_ds(force_bytes(ds_input), self._write, byref(ds_driver))
|
||||
# GDALOpenEx will auto-detect the data source type.
|
||||
ds = capi.open_ds(
|
||||
force_bytes(ds_input),
|
||||
self._write | capi.GDAL_OF_VECTOR,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
)
|
||||
except GDALException:
|
||||
# Making the error message more clear rather than something
|
||||
# like "Invalid pointer returned from OGROpen".
|
||||
|
@ -82,7 +82,8 @@ class DataSource(GDALBase):
|
|||
|
||||
if ds:
|
||||
self.ptr = ds
|
||||
self.driver = Driver(ds_driver)
|
||||
driver = capi.get_dataset_driver(ds)
|
||||
self.driver = Driver(driver)
|
||||
else:
|
||||
# Raise an exception if the returned pointer is NULL
|
||||
raise GDALException('Invalid data source file "%s"' % ds_input)
|
||||
|
|
|
@ -2,8 +2,7 @@ from ctypes import c_void_p
|
|||
|
||||
from django.contrib.gis.gdal.base import GDALBase
|
||||
from django.contrib.gis.gdal.error import GDALException
|
||||
from django.contrib.gis.gdal.prototypes import ds as vcapi
|
||||
from django.contrib.gis.gdal.prototypes import raster as rcapi
|
||||
from django.contrib.gis.gdal.prototypes import ds as capi
|
||||
from django.utils.encoding import force_bytes, force_str
|
||||
|
||||
|
||||
|
@ -49,16 +48,10 @@ class Driver(GDALBase):
|
|||
name = dr_input
|
||||
|
||||
# Attempting to get the GDAL/OGR driver by the string name.
|
||||
for iface in (vcapi, rcapi):
|
||||
driver = c_void_p(iface.get_driver_by_name(force_bytes(name)))
|
||||
if driver:
|
||||
break
|
||||
driver = c_void_p(capi.get_driver_by_name(force_bytes(name)))
|
||||
elif isinstance(dr_input, int):
|
||||
self.ensure_registered()
|
||||
for iface in (vcapi, rcapi):
|
||||
driver = iface.get_driver(dr_input)
|
||||
if driver:
|
||||
break
|
||||
driver = capi.get_driver(dr_input)
|
||||
elif isinstance(dr_input, c_void_p):
|
||||
driver = dr_input
|
||||
else:
|
||||
|
@ -81,23 +74,21 @@ class Driver(GDALBase):
|
|||
"""
|
||||
Attempt to register all the data source drivers.
|
||||
"""
|
||||
# Only register all if the driver counts are 0 (or else all drivers
|
||||
# will be registered over and over again)
|
||||
if not vcapi.get_driver_count():
|
||||
vcapi.register_all()
|
||||
if not rcapi.get_driver_count():
|
||||
rcapi.register_all()
|
||||
# Only register all if the driver count is 0 (or else all drivers will
|
||||
# be registered over and over again).
|
||||
if not capi.get_driver_count():
|
||||
capi.register_all()
|
||||
|
||||
@classmethod
|
||||
def driver_count(cls):
|
||||
"""
|
||||
Return the number of GDAL/OGR data source drivers registered.
|
||||
"""
|
||||
return vcapi.get_driver_count() + rcapi.get_driver_count()
|
||||
return capi.get_driver_count()
|
||||
|
||||
@property
|
||||
def name(self):
|
||||
"""
|
||||
Return description/name string for this driver.
|
||||
"""
|
||||
return force_str(rcapi.get_driver_description(self.ptr))
|
||||
return force_str(capi.get_driver_description(self.ptr))
|
||||
|
|
|
@ -3,7 +3,7 @@
|
|||
related data structures. OGR_Dr_*, OGR_DS_*, OGR_L_*, OGR_F_*,
|
||||
OGR_Fld_* routines are relevant here.
|
||||
"""
|
||||
from ctypes import POINTER, c_char_p, c_double, c_int, c_long, c_void_p
|
||||
from ctypes import POINTER, c_char_p, c_double, c_int, c_long, c_uint, c_void_p
|
||||
|
||||
from django.contrib.gis.gdal.envelope import OGREnvelope
|
||||
from django.contrib.gis.gdal.libgdal import lgdal
|
||||
|
@ -21,26 +21,36 @@ from django.contrib.gis.gdal.prototypes.generation import (
|
|||
|
||||
c_int_p = POINTER(c_int) # shortcut type
|
||||
|
||||
GDAL_OF_READONLY = 0x00
|
||||
GDAL_OF_UPDATE = 0x01
|
||||
|
||||
GDAL_OF_ALL = 0x00
|
||||
GDAL_OF_RASTER = 0x02
|
||||
GDAL_OF_VECTOR = 0x04
|
||||
|
||||
# Driver Routines
|
||||
register_all = void_output(lgdal.OGRRegisterAll, [], errcheck=False)
|
||||
cleanup_all = void_output(lgdal.OGRCleanupAll, [], errcheck=False)
|
||||
get_driver = voidptr_output(lgdal.OGRGetDriver, [c_int])
|
||||
register_all = void_output(lgdal.GDALAllRegister, [], errcheck=False)
|
||||
cleanup_all = void_output(lgdal.GDALDestroyDriverManager, [], errcheck=False)
|
||||
get_driver = voidptr_output(lgdal.GDALGetDriver, [c_int])
|
||||
get_driver_by_name = voidptr_output(
|
||||
lgdal.OGRGetDriverByName, [c_char_p], errcheck=False
|
||||
)
|
||||
get_driver_count = int_output(lgdal.OGRGetDriverCount, [])
|
||||
get_driver_name = const_string_output(
|
||||
lgdal.OGR_Dr_GetName, [c_void_p], decoding="ascii"
|
||||
lgdal.GDALGetDriverByName, [c_char_p], errcheck=False
|
||||
)
|
||||
get_driver_count = int_output(lgdal.GDALGetDriverCount, [])
|
||||
get_driver_description = const_string_output(lgdal.GDALGetDescription, [c_void_p])
|
||||
|
||||
# DataSource
|
||||
open_ds = voidptr_output(lgdal.OGROpen, [c_char_p, c_int, POINTER(c_void_p)])
|
||||
destroy_ds = void_output(lgdal.OGR_DS_Destroy, [c_void_p], errcheck=False)
|
||||
release_ds = void_output(lgdal.OGRReleaseDataSource, [c_void_p])
|
||||
get_ds_name = const_string_output(lgdal.OGR_DS_GetName, [c_void_p])
|
||||
get_layer = voidptr_output(lgdal.OGR_DS_GetLayer, [c_void_p, c_int])
|
||||
get_layer_by_name = voidptr_output(lgdal.OGR_DS_GetLayerByName, [c_void_p, c_char_p])
|
||||
get_layer_count = int_output(lgdal.OGR_DS_GetLayerCount, [c_void_p])
|
||||
open_ds = voidptr_output(
|
||||
lgdal.GDALOpenEx,
|
||||
[c_char_p, c_uint, POINTER(c_char_p), POINTER(c_char_p), POINTER(c_char_p)],
|
||||
)
|
||||
destroy_ds = void_output(lgdal.GDALClose, [c_void_p], errcheck=False)
|
||||
get_ds_name = const_string_output(lgdal.GDALGetDescription, [c_void_p])
|
||||
get_dataset_driver = voidptr_output(lgdal.GDALGetDatasetDriver, [c_void_p])
|
||||
get_layer = voidptr_output(lgdal.GDALDatasetGetLayer, [c_void_p, c_int])
|
||||
get_layer_by_name = voidptr_output(
|
||||
lgdal.GDALDatasetGetLayerByName, [c_void_p, c_char_p]
|
||||
)
|
||||
get_layer_count = int_output(lgdal.GDALDatasetGetLayerCount, [c_void_p])
|
||||
|
||||
# Layer Routines
|
||||
get_extent = void_output(lgdal.OGR_L_GetExtent, [c_void_p, POINTER(OGREnvelope), c_int])
|
||||
|
|
|
@ -25,15 +25,6 @@ void_output = partial(void_output, cpl=True)
|
|||
const_string_output = partial(const_string_output, cpl=True)
|
||||
double_output = partial(double_output, cpl=True)
|
||||
|
||||
# Raster Driver Routines
|
||||
register_all = void_output(std_call("GDALAllRegister"), [], errcheck=False)
|
||||
get_driver = voidptr_output(std_call("GDALGetDriver"), [c_int])
|
||||
get_driver_by_name = voidptr_output(
|
||||
std_call("GDALGetDriverByName"), [c_char_p], errcheck=False
|
||||
)
|
||||
get_driver_count = int_output(std_call("GDALGetDriverCount"), [])
|
||||
get_driver_description = const_string_output(std_call("GDALGetDescription"), [c_void_p])
|
||||
|
||||
# Raster Data Source Routines
|
||||
create_ds = voidptr_output(
|
||||
std_call("GDALCreate"), [c_void_p, c_char_p, c_int, c_int, c_int, c_int, c_void_p]
|
||||
|
|
|
@ -54,32 +54,21 @@ class DriverTest(unittest.TestCase):
|
|||
dr = Driver(alias)
|
||||
self.assertEqual(full_name, str(dr))
|
||||
|
||||
@mock.patch("django.contrib.gis.gdal.driver.vcapi.get_driver_count")
|
||||
@mock.patch("django.contrib.gis.gdal.driver.rcapi.get_driver_count")
|
||||
@mock.patch("django.contrib.gis.gdal.driver.vcapi.register_all")
|
||||
@mock.patch("django.contrib.gis.gdal.driver.rcapi.register_all")
|
||||
def test_registered(self, rreg, vreg, rcount, vcount):
|
||||
@mock.patch("django.contrib.gis.gdal.driver.capi.get_driver_count")
|
||||
@mock.patch("django.contrib.gis.gdal.driver.capi.register_all")
|
||||
def test_registered(self, reg, count):
|
||||
"""
|
||||
Prototypes are registered only if their respective driver counts are
|
||||
zero.
|
||||
Prototypes are registered only if the driver count is zero.
|
||||
"""
|
||||
|
||||
def check(rcount_val, vcount_val):
|
||||
vreg.reset_mock()
|
||||
rreg.reset_mock()
|
||||
rcount.return_value = rcount_val
|
||||
vcount.return_value = vcount_val
|
||||
def check(count_val):
|
||||
reg.reset_mock()
|
||||
count.return_value = count_val
|
||||
Driver.ensure_registered()
|
||||
if rcount_val:
|
||||
self.assertFalse(rreg.called)
|
||||
if count_val:
|
||||
self.assertFalse(reg.called)
|
||||
else:
|
||||
rreg.assert_called_once_with()
|
||||
if vcount_val:
|
||||
self.assertFalse(vreg.called)
|
||||
else:
|
||||
vreg.assert_called_once_with()
|
||||
reg.assert_called_once_with()
|
||||
|
||||
check(0, 0)
|
||||
check(120, 0)
|
||||
check(0, 120)
|
||||
check(120, 120)
|
||||
check(0)
|
||||
check(120)
|
||||
|
|
Loading…
Reference in New Issue