2020-02-26 21:03:51 +08:00
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
2020-02-06 15:40:04 +08:00
import six
2020-10-10 15:51:54 +08:00
import paddle
from paddle . io import Dataset
2020-02-06 15:40:04 +08:00
2020-12-09 15:58:39 +08:00
# Public API of this module. These are the names exported by
# `from <module> import *`, so every entry must match a top-level
# definition exactly (no surrounding whitespace).
__all__ = [
    "split",
    "TransformDataset",
    "CacheDataset",
    "TupleDataset",
    "DictDataset",
    "SliceDataset",
    "SubsetDataset",
    "FilterDataset",
    "ChainDataset",
]
2020-02-06 15:40:04 +08:00
2020-10-10 15:51:54 +08:00
def split(dataset, first_size):
    """Split a dataset into two consecutive parts.

    Args:
        dataset (Dataset): the dataset to split.
        first_size (int): the number of examples in the first part.

    Returns:
        Tuple[SliceDataset, SliceDataset]: a slice holding the first
            ``first_size`` examples and a slice holding the rest. Both keep
            a reference to ``dataset`` instead of copying its examples.
    """
    return (
        SliceDataset(dataset, 0, first_size),
        SliceDataset(dataset, first_size, len(dataset)),
    )
2020-02-06 15:40:04 +08:00
2020-10-10 15:51:54 +08:00
class TransformDataset(Dataset):
    """Dataset whose examples are transformed examples of a base dataset.

    The transform is applied on access, every time an example is requested.
    """

    def __init__(self, dataset, transform):
        """Wrap ``dataset`` so that each example goes through ``transform``.

        Args:
            dataset (Dataset): the base dataset.
            transform (callable): takes an example of the base dataset and
                returns a new example.
        """
        self._transform = transform
        self._dataset = dataset

    def __len__(self):
        """Return the number of examples, i.e. the size of the base dataset."""
        return len(self._dataset)

    def __getitem__(self, i):
        """Return the transformed version of example ``i`` of the base dataset."""
        return self._transform(self._dataset[i])
2020-10-10 15:51:54 +08:00
class CacheDataset(Dataset):
    """Lazy cache in front of a base dataset.

    An example is read from the base dataset the first time it is requested
    and served from an in-memory dict afterwards. Entries are keyed by the
    index exactly as it was passed in and are never evicted.
    """

    def __init__(self, dataset):
        """Create an empty cache for ``dataset``.

        Args:
            dataset (Dataset): the base dataset to cache.
        """
        self._cache = {}
        self._dataset = dataset

    def __len__(self):
        """Return the size of the base dataset."""
        return len(self._dataset)

    def __getitem__(self, i):
        """Return example ``i``, reading the base dataset only on a cache miss."""
        store = self._cache
        if i in store:
            return store[i]
        example = self._dataset[i]
        store[i] = example
        return example
2020-10-10 15:51:54 +08:00
class TupleDataset(Dataset):
    """Compound dataset that zips several datasets of the same length.

    Example ``i`` is the tuple made of example ``i`` of every constituent
    dataset, in the order the datasets were given.
    """

    def __init__(self, *datasets):
        """Combine ``datasets`` into one dataset of tuples.

        Args:
            *datasets (Dataset): the constituent datasets.

        Raises:
            ValueError: if no dataset is given, or if the datasets do not
                all have the same length.
        """
        if not datasets:
            raise ValueError(" no datasets are given ")

        expected = len(datasets[0])
        # Position of the first constituent whose size disagrees, if any.
        mismatched = next(
            (pos for pos, member in enumerate(datasets)
             if len(member) != expected),
            None)
        if mismatched is not None:
            raise ValueError(
                " all the datasets should have the same length. "
                " dataset {} has a different length ".format(mismatched))

        self._length = expected
        self._datasets = datasets

    def __getitem__(self, index):
        """Return one tuple for an int index, or a list of tuples for a slice."""
        # One entry per constituent dataset ("structure of arrays").
        columns = [member[index] for member in self._datasets]
        if not isinstance(index, slice):
            return tuple(columns)

        # Transpose the per-dataset slices into one tuple per example
        # ("array of structures").
        count = len(columns[0])
        return [
            tuple(column[row] for column in columns) for row in range(count)
        ]

    def __len__(self):
        """Return the common length of the constituent datasets."""
        return self._length
2020-10-10 15:51:54 +08:00
class DictDataset(Dataset):
    """Compound dataset that zips several named datasets of the same length.

    Example ``i`` is a dict mapping each keyword given to the constructor to
    example ``i`` of the corresponding constituent dataset.

    WARNING: paddle does not have good support for DictDataset, because every
    batch yielded from a DataLoader is a list and cannot be a dict. You have
    to provide your own collate function; the default one cannot be used.
    """

    def __init__(self, **datasets):
        """Combine ``datasets`` into one dataset of dicts.

        Args:
            **datasets (Dataset): the constituent datasets, passed by keyword.

        Raises:
            ValueError: if no dataset is given, or if the datasets do not
                all have the same length.
        """
        if not datasets:
            raise ValueError(" no datasets are given ")

        expected = None
        for name, member in datasets.items():
            size = len(member)
            if expected is None:
                # The first dataset visited sets the reference length.
                expected = size
                continue
            if size != expected:
                raise ValueError(
                    " all the datasets should have the same length. "
                    " dataset {} has a different length ".format(name))

        self._length = expected
        self._datasets = datasets

    def __getitem__(self, index):
        """Return one dict for an int index, or a list of dicts for a slice."""
        columns = {}
        for name, member in self._datasets.items():
            columns[name] = member[index]
        if not isinstance(index, slice):
            return columns

        # Every column holds a slice; regroup them into one dict per example.
        count = len(next(iter(columns.values())))
        return [
            {name: column[row] for name, column in columns.items()}
            for row in range(count)
        ]

    def __len__(self):
        """Return the common length of the constituent datasets."""
        return self._length
2020-02-06 15:40:04 +08:00
2020-10-10 15:51:54 +08:00
class SliceDataset(Dataset):
    """Dataset made of the positions ``[start, finish)`` of a base dataset.

    When ``order`` is given, position ``k`` of the slice maps to example
    ``order[start + k]`` of the base dataset instead of ``start + k``.
    """

    def __init__(self, dataset, start, finish, order=None):
        """Create a slice of ``dataset``.

        Args:
            dataset (Dataset): the base dataset.
            start (int): the start of the slice.
            finish (int): the end of the slice, not inclusive.
            order (List[int], optional): expected to be a permutation of the
                valid example ids of the base dataset; only its length is
                checked here. If provided, the slice is taken in ``order``.
                Defaults to None.

        Raises:
            ValueError: if the slice overruns the base dataset, if ``start``
                is greater than ``finish``, or if ``order`` does not have
                the same length as the base dataset.
        """
        if start < 0 or finish > len(dataset):
            raise ValueError(" subset overruns the dataset. ")
        # Without this check the slice would get a negative size and `len()`
        # on it would fail later with an unrelated-looking error.
        if start > finish:
            raise ValueError(
                "start should not be greater than finish, "
                "got start = {} and finish = {}".format(start, finish))
        if order is not None and len(order) != len(dataset):
            raise ValueError(
                " order should have the same length as the dataset "
                " len(order) = {} which does not equal len(dataset) = {} ".
                format(len(order), len(dataset)))

        self._dataset = dataset
        self._start = start
        self._finish = finish
        self._size = finish - start
        self._order = order

    def __len__(self):
        """Return the number of examples in the slice."""
        return self._size

    def __getitem__(self, i):
        """Return example ``i`` of the slice.

        Negative values of ``i`` count from the end of the slice.

        Raises:
            IndexError: if ``i`` is outside ``[-len(self), len(self))``.
        """
        if not -self._size <= i < self._size:
            raise IndexError(' dataset index out of range ')
        # Non-negative indices are offsets from `start`, negative ones from
        # `finish`.
        index = (self._start if i >= 0 else self._finish) + i
        if self._order is not None:
            index = self._order[index]
        return self._dataset[index]
2020-10-10 15:51:54 +08:00
class SubsetDataset(Dataset):
    """Dataset made of the examples of a base dataset picked by ``indices``."""

    def __init__(self, dataset, indices):
        """Create a subset of ``dataset``.

        Args:
            dataset (Dataset): the base dataset.
            indices (Sequence[int]): the indices of the examples to pick.
                It must support ``len()`` and indexing.

        Raises:
            ValueError: if there are more indices than examples in the base
                dataset.
        """
        subset_size = len(indices)
        if subset_size > len(dataset):
            raise ValueError(" subset ' s size larger that dataset ' s size! ")
        self._dataset = dataset
        self._indices = indices
        self._size = subset_size

    def __len__(self):
        """Return the number of picked examples."""
        return self._size

    def __getitem__(self, i):
        """Return the base-dataset example referenced by the ``i``-th index."""
        return self._dataset[self._indices[i]]
2020-10-10 15:51:54 +08:00
class FilterDataset(Dataset):
    """Dataset made of the examples of a base dataset accepted by a filter.

    The filter is evaluated eagerly: the constructor reads every example of
    the base dataset once and remembers the indices of the accepted ones.
    """

    def __init__(self, dataset, filter_fn):
        """Filter ``dataset`` with ``filter_fn``.

        Args:
            dataset (Dataset): the base dataset.
            filter_fn (callable): takes an example of the base dataset and
                returns a boolean; examples with a truthy result are kept.
        """
        kept = []
        for position in range(len(dataset)):
            if filter_fn(dataset[position]):
                kept.append(position)

        self._dataset = dataset
        self._indices = kept
        self._size = len(kept)

    def __len__(self):
        """Return the number of examples that passed the filter."""
        return self._size

    def __getitem__(self, i):
        """Return the ``i``-th accepted example, read from the base dataset."""
        return self._dataset[self._indices[i]]
2020-10-10 15:51:54 +08:00
class ChainDataset(Dataset):
    """Concatenation of several datasets that share the same structure."""

    def __init__(self, *datasets):
        """Chain ``datasets`` one after another.

        Args:
            *datasets (Dataset): the datasets to concatenate, in order.
        """
        self._datasets = datasets

    def __len__(self):
        """Return the total number of examples over all chained datasets."""
        return sum(map(len, self._datasets))

    def __getitem__(self, i):
        """Return example ``i`` of the concatenation.

        Raises:
            IndexError: if ``i`` is negative (not supported) or not smaller
                than the total length.
        """
        if i < 0:
            raise IndexError(
                " ChainDataset doesnot support negative indexing. ")

        # Walk the members, turning the global index into a local one.
        offset = i
        for member in self._datasets:
            size = len(member)
            if offset < size:
                return member[offset]
            offset -= size
        raise IndexError(" dataset index out of range ")