Refs #23804 -- Improved value validation in GDALRaster.geotransform setter.

This commit is contained in:
Adam Johnson 2017-05-30 15:02:50 +02:00 committed by Tim Graham
parent 037d6540ec
commit 9509268cea
2 changed files with 16 additions and 1 deletions

View File

@ -250,7 +250,7 @@ class GDALRaster(GDALBase):
@geotransform.setter @geotransform.setter
def geotransform(self, values): def geotransform(self, values):
"Set the geotransform for the data source." "Set the geotransform for the data source."
if sum([isinstance(x, (int, float)) for x in values]) != 6: if len(values) != 6 or not all(isinstance(x, (int, float)) for x in values):
raise ValueError('Geotransform must consist of 6 numeric values.') raise ValueError('Geotransform must consist of 6 numeric values.')
# Create ctypes double array with input and write data # Create ctypes double array with input and write data
values = (c_double * 6)(*values) values = (c_double * 6)(*values)

View File

@ -103,6 +103,9 @@ class GDALRasterTests(SimpleTestCase):
self.assertEqual(self.rs.skew.y, 0) self.assertEqual(self.rs.skew.y, 0)
# Create in-memory rasters and change gtvalues # Create in-memory rasters and change gtvalues
rsmem = GDALRaster(JSON_RASTER) rsmem = GDALRaster(JSON_RASTER)
# geotransform accepts both floats and ints
rsmem.geotransform = [0.0, 1.0, 2.0, 3.0, 4.0, 5.0]
self.assertEqual(rsmem.geotransform, [0.0, 1.0, 2.0, 3.0, 4.0, 5.0])
rsmem.geotransform = range(6) rsmem.geotransform = range(6)
self.assertEqual(rsmem.geotransform, [float(x) for x in range(6)]) self.assertEqual(rsmem.geotransform, [float(x) for x in range(6)])
self.assertEqual(rsmem.origin, [0, 3]) self.assertEqual(rsmem.origin, [0, 3])
@ -117,6 +120,18 @@ class GDALRasterTests(SimpleTestCase):
self.assertEqual(rsmem.width, 5) self.assertEqual(rsmem.width, 5)
self.assertEqual(rsmem.height, 5) self.assertEqual(rsmem.height, 5)
def test_geotransform_bad_inputs(self):
rsmem = GDALRaster(JSON_RASTER)
error_geotransforms = [
[1, 2],
[1, 2, 3, 4, 5, 'foo'],
[1, 2, 3, 4, 5, 6, 'foo'],
]
msg = 'Geotransform must consist of 6 numeric values.'
for geotransform in error_geotransforms:
with self.subTest(i=geotransform), self.assertRaisesMessage(ValueError, msg):
rsmem.geotransform = geotransform
def test_rs_extent(self): def test_rs_extent(self):
self.assertEqual( self.assertEqual(
self.rs.extent, self.rs.extent,