diff --git a/django/contrib/gis/gdal/raster/band.py b/django/contrib/gis/gdal/raster/band.py index f1eb50e506d..bd273b29855 100644 --- a/django/contrib/gis/gdal/raster/band.py +++ b/django/contrib/gis/gdal/raster/band.py @@ -1,10 +1,12 @@ from ctypes import byref, c_int from django.contrib.gis.gdal.base import GDALBase +from django.contrib.gis.gdal.error import GDALException from django.contrib.gis.gdal.prototypes import raster as capi from django.contrib.gis.shortcuts import numpy from django.utils import six from django.utils.encoding import force_text +from django.utils.six.moves import range from .const import GDAL_INTEGER_TYPES, GDAL_PIXEL_TYPES, GDAL_TO_CTYPES @@ -148,3 +150,22 @@ class GDALBand(GDALBase): return list(data_array) else: self.source._flush() + + +class BandList(list): + def __init__(self, source): + self.source = source + list.__init__(self) + + def __iter__(self): + for idx in range(1, len(self) + 1): + yield GDALBand(self.source, idx) + + def __len__(self): + return capi.get_ds_raster_count(self.source._ptr) + + def __getitem__(self, index): + try: + return GDALBand(self.source, index + 1) + except GDALException: + raise GDALException('Unable to get band index %d' % index) diff --git a/django/contrib/gis/gdal/raster/source.py b/django/contrib/gis/gdal/raster/source.py index e56b442fee9..3b4eb183b1c 100644 --- a/django/contrib/gis/gdal/raster/source.py +++ b/django/contrib/gis/gdal/raster/source.py @@ -6,7 +6,7 @@ from django.contrib.gis.gdal.base import GDALBase from django.contrib.gis.gdal.driver import Driver from django.contrib.gis.gdal.error import GDALException from django.contrib.gis.gdal.prototypes import raster as capi -from django.contrib.gis.gdal.raster.band import GDALBand +from django.contrib.gis.gdal.raster.band import BandList from django.contrib.gis.gdal.raster.const import GDAL_RESAMPLE_ALGORITHMS from django.contrib.gis.gdal.srs import SpatialReference, SRSException from django.contrib.gis.geometry.regex import json_regex @@ -15,7 +15,6 @@ from django.utils.encoding import ( force_bytes, force_text, python_2_unicode_compatible, ) from django.utils.functional import cached_property -from django.utils.six.moves import range class TransformPoint(list): @@ -108,9 +107,10 @@ class GDALRaster(GDALBase): # Set band data if provided for i, band_input in enumerate(ds_input.get('bands', [])): - self.bands[i].data(band_input['data']) + band = self.bands[i] + band.data(band_input['data']) if 'nodata_value' in band_input: - self.bands[i].nodata_value = band_input['nodata_value'] + band.nodata_value = band_input['nodata_value'] # Set SRID self.srs = ds_input.get('srid') @@ -273,15 +273,9 @@ class GDALRaster(GDALBase): return xmin, ymin, xmax, ymax - @cached_property + @property def bands(self): - """ - Returns the bands of this raster as a list of GDALBand instances. - """ - bands = [] - for idx in range(1, capi.get_ds_raster_count(self._ptr) + 1): - bands.append(GDALBand(self, idx)) - return bands + return BandList(self) def warp(self, ds_input, resampling='NearestNeighbour', max_error=0.0): """ diff --git a/tests/gis_tests/rasterapp/test_rasterfield.py b/tests/gis_tests/rasterapp/test_rasterfield.py index d1c015f499e..03691ef0b39 100644 --- a/tests/gis_tests/rasterapp/test_rasterfield.py +++ b/tests/gis_tests/rasterapp/test_rasterfield.py @@ -23,6 +23,11 @@ class RasterFieldTest(TransactionTestCase): r.refresh_from_db() self.assertIsNone(r.rast) + def test_access_band_data_directly_from_queryset(self): + RasterModel.objects.create(rast=JSON_RASTER) + qs = RasterModel.objects.all() + qs[0].rast.bands[0].data() + def test_model_creation(self): """ Test RasterField through a test model.