Fixed #25072 -- Prevented GDALRaster memory to be uncollectable

Setting GDALRaster.bands as a cached property was creating a circular
reference with objects having __del__ methods, which means the memory
could never be freed.
Thanks Daniel Wiesmann for the report and test, and Tim Graham for the review.
This commit is contained in:
Claude Paroz 2015-07-09 10:52:11 +02:00
parent 0b02ce54cf
commit d72f8862cb
3 changed files with 32 additions and 12 deletions

View File

@ -1,10 +1,12 @@
from ctypes import byref, c_int from ctypes import byref, c_int
from django.contrib.gis.gdal.base import GDALBase from django.contrib.gis.gdal.base import GDALBase
from django.contrib.gis.gdal.error import GDALException
from django.contrib.gis.gdal.prototypes import raster as capi from django.contrib.gis.gdal.prototypes import raster as capi
from django.contrib.gis.shortcuts import numpy from django.contrib.gis.shortcuts import numpy
from django.utils import six from django.utils import six
from django.utils.encoding import force_text from django.utils.encoding import force_text
from django.utils.six.moves import range
from .const import GDAL_INTEGER_TYPES, GDAL_PIXEL_TYPES, GDAL_TO_CTYPES from .const import GDAL_INTEGER_TYPES, GDAL_PIXEL_TYPES, GDAL_TO_CTYPES
@ -148,3 +150,22 @@ class GDALBand(GDALBase):
return list(data_array) return list(data_array)
else: else:
self.source._flush() self.source._flush()
class BandList(list):
def __init__(self, source):
self.source = source
list.__init__(self)
def __iter__(self):
for idx in range(1, len(self) + 1):
yield GDALBand(self.source, idx)
def __len__(self):
return capi.get_ds_raster_count(self.source._ptr)
def __getitem__(self, index):
try:
return GDALBand(self.source, index + 1)
except GDALException:
raise GDALException('Unable to get band index %d' % index)

View File

@ -6,7 +6,7 @@ from django.contrib.gis.gdal.base import GDALBase
from django.contrib.gis.gdal.driver import Driver from django.contrib.gis.gdal.driver import Driver
from django.contrib.gis.gdal.error import GDALException from django.contrib.gis.gdal.error import GDALException
from django.contrib.gis.gdal.prototypes import raster as capi from django.contrib.gis.gdal.prototypes import raster as capi
from django.contrib.gis.gdal.raster.band import GDALBand from django.contrib.gis.gdal.raster.band import BandList
from django.contrib.gis.gdal.raster.const import GDAL_RESAMPLE_ALGORITHMS from django.contrib.gis.gdal.raster.const import GDAL_RESAMPLE_ALGORITHMS
from django.contrib.gis.gdal.srs import SpatialReference, SRSException from django.contrib.gis.gdal.srs import SpatialReference, SRSException
from django.contrib.gis.geometry.regex import json_regex from django.contrib.gis.geometry.regex import json_regex
@ -15,7 +15,6 @@ from django.utils.encoding import (
force_bytes, force_text, python_2_unicode_compatible, force_bytes, force_text, python_2_unicode_compatible,
) )
from django.utils.functional import cached_property from django.utils.functional import cached_property
from django.utils.six.moves import range
class TransformPoint(list): class TransformPoint(list):
@ -108,9 +107,10 @@ class GDALRaster(GDALBase):
# Set band data if provided # Set band data if provided
for i, band_input in enumerate(ds_input.get('bands', [])): for i, band_input in enumerate(ds_input.get('bands', [])):
self.bands[i].data(band_input['data']) band = self.bands[i]
band.data(band_input['data'])
if 'nodata_value' in band_input: if 'nodata_value' in band_input:
self.bands[i].nodata_value = band_input['nodata_value'] band.nodata_value = band_input['nodata_value']
# Set SRID # Set SRID
self.srs = ds_input.get('srid') self.srs = ds_input.get('srid')
@ -273,15 +273,9 @@ class GDALRaster(GDALBase):
return xmin, ymin, xmax, ymax return xmin, ymin, xmax, ymax
@cached_property @property
def bands(self): def bands(self):
""" return BandList(self)
Returns the bands of this raster as a list of GDALBand instances.
"""
bands = []
for idx in range(1, capi.get_ds_raster_count(self._ptr) + 1):
bands.append(GDALBand(self, idx))
return bands
def warp(self, ds_input, resampling='NearestNeighbour', max_error=0.0): def warp(self, ds_input, resampling='NearestNeighbour', max_error=0.0):
""" """

View File

@ -23,6 +23,11 @@ class RasterFieldTest(TransactionTestCase):
r.refresh_from_db() r.refresh_from_db()
self.assertIsNone(r.rast) self.assertIsNone(r.rast)
def test_access_band_data_directly_from_queryset(self):
RasterModel.objects.create(rast=JSON_RASTER)
qs = RasterModel.objects.all()
qs[0].rast.bands[0].data()
def test_model_creation(self): def test_model_creation(self):
""" """
Test RasterField through a test model. Test RasterField through a test model.