diff --git a/django/contrib/gis/gdal/raster/source.py b/django/contrib/gis/gdal/raster/source.py index 904d2f783a..432a1eb7f4 100644 --- a/django/contrib/gis/gdal/raster/source.py +++ b/django/contrib/gis/gdal/raster/source.py @@ -108,9 +108,22 @@ class GDALRaster(GDALBase): # Set band data if provided for i, band_input in enumerate(ds_input.get('bands', [])): band = self.bands[i] - band.data(band_input['data']) if 'nodata_value' in band_input: band.nodata_value = band_input['nodata_value'] + # Instantiate band filled with nodata values if only + # partial input data has been provided. + if band.nodata_value is not None and ( + 'data' not in band_input or + 'size' in band_input or + 'shape' in band_input): + band.data(data=(band.nodata_value,), shape=(1, 1)) + # Set band data values from input. + band.data( + data=band_input.get('data'), + size=band_input.get('size'), + shape=band_input.get('shape'), + offset=band_input.get('offset'), + ) # Set SRID self.srs = ds_input.get('srid') @@ -339,16 +352,12 @@ class GDALRaster(GDALBase): if 'datatype' not in ds_input: ds_input['datatype'] = self.bands[0].datatype() - # Set the number of bands - ds_input['nr_of_bands'] = len(self.bands) + # Instantiate raster bands filled with nodata values. + ds_input['bands'] = [{'nodata_value': bnd.nodata_value} for bnd in self.bands] # Create target raster target = GDALRaster(ds_input, write=True) - # Copy nodata values to warped raster - for index, band in enumerate(self.bands): - target.bands[index].nodata_value = band.nodata_value - # Select resampling algorithm algorithm = GDAL_RESAMPLE_ALGORITHMS[resampling] diff --git a/docs/ref/contrib/gis/gdal.txt b/docs/ref/contrib/gis/gdal.txt index 9bd2df8dad..3f9219dc0d 100644 --- a/docs/ref/contrib/gis/gdal.txt +++ b/docs/ref/contrib/gis/gdal.txt @@ -1117,16 +1117,39 @@ blue. >>> rst = GDALRaster('/path/to/your/raster.tif', write=False) >>> rst.name '/path/to/your/raster.tif' - >>> rst.width, rst.height # This file has 163 x 174 pixels + >>> rst.width, rst.height # This file has 163 x 174 pixels (163, 174) - >>> rst = GDALRaster({'srid': 4326, 'width': 1, 'height': 2, 'datatype': 1 - ... 'bands': [{'data': [0, 1]}]}) # Creates in-memory raster + >>> rst = GDALRaster({ # Creates an in-memory raster + ... 'srid': 4326, + ... 'width': 4, + ... 'height': 4, + ... 'datatype': 1, + ... 'bands': [{ + ... 'data': (2, 3), + ... 'offset': (1, 1), + ... 'size': (2, 2), + ... 'shape': (2, 1), + ... 'nodata_value': 5, + ... }] + ... }) >>> rst.srs.srid 4326 >>> rst.width, rst.height - (1, 2) + (4, 4) >>> rst.bands[0].data() - array([[0, 1]], dtype=int8) + array([[5, 5, 5, 5], + [5, 2, 3, 5], + [5, 2, 3, 5], + [5, 5, 5, 5]], dtype=uint8) + + .. versionchanged:: 1.11 + + Added the ability to pass the ``size``, ``shape``, and ``offset`` + parameters when creating :class:`GDALRaster` objects. The parameters + can be passed through the ``ds_input`` dictionary. This allows to + finely control initial pixel values. The functionality is similar to + the :meth:`GDALBand.data()` + method. .. attribute:: name diff --git a/docs/releases/1.11.txt b/docs/releases/1.11.txt index a612cb0e12..77a1c3ea12 100644 --- a/docs/releases/1.11.txt +++ b/docs/releases/1.11.txt @@ -148,6 +148,9 @@ Minor features * PostGIS migrations can now change field dimensions. +* Added the ability to pass the `size`, `shape`, and `offset` parameter when + creating :class:`~django.contrib.gis.gdal.GDALRaster` objects. + :mod:`django.contrib.messages` ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/tests/gis_tests/gdal_tests/test_raster.py b/tests/gis_tests/gdal_tests/test_raster.py index 53f1bf985a..bddb3c8826 100644 --- a/tests/gis_tests/gdal_tests/test_raster.py +++ b/tests/gis_tests/gdal_tests/test_raster.py @@ -190,6 +190,64 @@ class GDALRasterTests(unittest.TestCase): else: self.assertEqual(restored_raster.bands[0].data(), self.rs.bands[0].data()) + def test_offset_size_and_shape_on_raster_creation(self): + rast = GDALRaster({ + 'datatype': 1, + 'width': 4, + 'height': 4, + 'srid': 4326, + 'bands': [{ + 'data': (1,), + 'offset': (1, 1), + 'size': (2, 2), + 'shape': (1, 1), + 'nodata_value': 2, + }], + }) + # Get array from raster. + result = rast.bands[0].data() + if numpy: + result = result.flatten().tolist() + # Band data is equal to nodata value except on input block of ones. + self.assertEqual( + result, + [2, 2, 2, 2, 2, 1, 1, 2, 2, 1, 1, 2, 2, 2, 2, 2] + ) + + def test_set_nodata_value_on_raster_creation(self): + # Create raster filled with nodata values. + rast = GDALRaster({ + 'datatype': 1, + 'width': 2, + 'height': 2, + 'srid': 4326, + 'bands': [{'nodata_value': 23}], + }) + # Get array from raster. + result = rast.bands[0].data() + if numpy: + result = result.flatten().tolist() + # All band data is equal to nodata value. + self.assertEqual(result, [23, ] * 4) + + def test_set_nodata_none_on_raster_creation(self): + if GDAL_VERSION < (2, 1): + self.skipTest("GDAL >= 2.1 is required for this test.") + # Create raster without data and without nodata value. + rast = GDALRaster({ + 'datatype': 1, + 'width': 2, + 'height': 2, + 'srid': 4326, + 'bands': [{'nodata_value': None}], + }) + # Get array from raster. + result = rast.bands[0].data() + if numpy: + result = result.flatten().tolist() + # Band data is equal to zero becaues no nodata value has been specified. + self.assertEqual(result, [0] * 4) + def test_raster_warp(self): # Create in memory raster source = GDALRaster({ @@ -246,6 +304,29 @@ class GDALRasterTests(unittest.TestCase): 12.0, 13.0, 14.0, 15.0] ) + def test_raster_warp_nodata_zone(self): + # Create in memory raster. + source = GDALRaster({ + 'datatype': 1, + 'driver': 'MEM', + 'width': 4, + 'height': 4, + 'srid': 3086, + 'origin': (500000, 400000), + 'scale': (100, -100), + 'skew': (0, 0), + 'bands': [{ + 'data': range(16), + 'nodata_value': 23, + }], + }) + # Warp raster onto a location that does not cover any pixels of the original. + result = source.warp({'origin': (200000, 200000)}).bands[0].data() + if numpy: + result = result.flatten().tolist() + # The result is an empty raster filled with the correct nodata value. + self.assertEqual(result, [23] * 16) + def test_raster_transform(self): if GDAL_VERSION < (1, 8, 1): self.skipTest("GDAL >= 1.8.1 is required for this test")