Fixed #31766 -- Made GDALRaster.transform() return a clone for the same SRID and driver.

Thanks Daniel Wiesmann for the review.
This commit is contained in:
Barton Ip 2020-08-30 11:35:11 +00:00 committed by Mariusz Felisiak
parent 5362e08624
commit 12d6cae7c0
3 changed files with 108 additions and 0 deletions

View File

@ -110,6 +110,7 @@ answer newbie questions, and generally made Django that much better:
Baptiste Mispelon <bmispelon@gmail.com>
Barry Pederson <bp@barryp.org>
Bartolome Sanchez Salado <i42sasab@uco.es>
Barton Ip <notbartonip@gmail.com>
Bartosz Grabski <bartosz.grabski@gmail.com>
Bashar Al-Abdulhadi
Bastian Kleineidam <calvin@debian.org>

View File

@ -425,6 +425,27 @@ class GDALRaster(GDALRasterBase):
return target
def clone(self, name=None):
"""Return a clone of this GDALRaster."""
if name:
clone_name = name
elif self.driver.name != 'MEM':
clone_name = self.name + '_copy.' + self.driver.name
else:
clone_name = os.path.join(VSI_FILESYSTEM_BASE_PATH, str(uuid.uuid4()))
return GDALRaster(
capi.copy_ds(
self.driver._ptr,
force_bytes(clone_name),
self._ptr,
c_int(),
c_char_p(),
c_void_p(),
c_void_p(),
),
write=self._write,
)
def transform(self, srs, driver=None, name=None, resampling='NearestNeighbour',
max_error=0.0):
"""
@ -443,6 +464,9 @@ class GDALRaster(GDALRasterBase):
'Transform only accepts SpatialReference, string, and integer '
'objects.'
)
if target_srs.srid == self.srid and (not driver or driver == self.driver.name):
return self.clone(name)
# Create warped virtual dataset in the target reference system
target = capi.auto_create_warped_vrt(
self._ptr, self.srs.wkt.encode(), target_srs.wkt.encode(),

View File

@ -2,6 +2,7 @@ import os
import shutil
import struct
import tempfile
from unittest import mock
from django.contrib.gis.gdal import GDAL_VERSION, GDALRaster, SpatialReference
from django.contrib.gis.gdal.error import GDALException
@ -470,6 +471,40 @@ class GDALRasterTests(SimpleTestCase):
# The result is an empty raster filled with the correct nodata value.
self.assertEqual(result, [23] * 16)
def test_raster_clone(self):
rstfile = tempfile.NamedTemporaryFile(suffix='.tif')
tests = [
('MEM', '', 23), # In memory raster.
('tif', rstfile.name, 99), # In file based raster.
]
for driver, name, nodata_value in tests:
with self.subTest(driver=driver):
source = GDALRaster({
'datatype': 1,
'driver': driver,
'name': name,
'width': 4,
'height': 4,
'srid': 3086,
'origin': (500000, 400000),
'scale': (100, -100),
'skew': (0, 0),
'bands': [{
'data': range(16),
'nodata_value': nodata_value,
}],
})
clone = source.clone()
self.assertNotEqual(clone.name, source.name)
self.assertEqual(clone._write, source._write)
self.assertEqual(clone.srs.srid, source.srs.srid)
self.assertEqual(clone.width, source.width)
self.assertEqual(clone.height, source.height)
self.assertEqual(clone.origin, source.origin)
self.assertEqual(clone.scale, source.scale)
self.assertEqual(clone.skew, source.skew)
self.assertIsNot(clone, source)
def test_raster_transform(self):
tests = [
3086,
@ -531,6 +566,54 @@ class GDALRasterTests(SimpleTestCase):
],
)
def test_raster_transform_clone(self):
with mock.patch.object(GDALRaster, 'clone') as mocked_clone:
# Create in file based raster.
rstfile = tempfile.NamedTemporaryFile(suffix='.tif')
source = GDALRaster({
'datatype': 1,
'driver': 'tif',
'name': rstfile.name,
'width': 5,
'height': 5,
'nr_of_bands': 1,
'srid': 4326,
'origin': (-5, 5),
'scale': (2, -2),
'skew': (0, 0),
'bands': [{
'data': range(25),
'nodata_value': 99,
}],
})
# transform() returns a clone because it is the same SRID and
# driver.
source.transform(4326)
self.assertEqual(mocked_clone.call_count, 1)
def test_raster_transform_clone_name(self):
# Create in file based raster.
rstfile = tempfile.NamedTemporaryFile(suffix='.tif')
source = GDALRaster({
'datatype': 1,
'driver': 'tif',
'name': rstfile.name,
'width': 5,
'height': 5,
'nr_of_bands': 1,
'srid': 4326,
'origin': (-5, 5),
'scale': (2, -2),
'skew': (0, 0),
'bands': [{
'data': range(25),
'nodata_value': 99,
}],
})
clone_name = rstfile.name + '_respect_name.GTiff'
target = source.transform(4326, name=clone_name)
self.assertEqual(target.name, clone_name)
class GDALBandTests(SimpleTestCase):
rs_path = os.path.join(os.path.dirname(__file__), '../data/rasters/raster.tif')