2020-02-26 21:03:51 +08:00
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
2020-02-06 15:40:04 +08:00
import six
import numpy as np
2020-03-05 15:08:12 +08:00
from tqdm import tqdm
2020-02-06 15:40:04 +08:00
class DatasetMixin(object):
    """Mixin that supplies the standard indexing protocol for datasets.

    A subclass implements ``get_example`` and ``__len__``; the mixin then
    provides ``__getitem__`` and ``__iter__`` on top of them. The mixin
    defines no ``__init__``, so subclasses do not need to call
    ``super().__init__()``.
    """

    def __getitem__(self, index):
        """Fetch one example, or a list of examples, via ``get_example``.

        Args:
            index (int, slice, list[int] or np.ndarray): what to pick.
                A slice is resolved against ``len(self)``; a list or an
                ndarray is iterated element by element; any other value
                is passed to ``get_example`` unchanged.

        Returns:
            Example or List[Example]: a single example when ``index`` is
            not a slice, list or ndarray; otherwise a list of examples in
            the order given by ``index``.
        """
        if isinstance(index, slice):
            begin, end, stride = index.indices(len(self))
            positions = six.moves.range(begin, end, stride)
        elif isinstance(index, (list, np.ndarray)):
            positions = index
        else:
            # Anything else is assumed to be a single integer index.
            return self.get_example(index)
        return [self.get_example(position) for position in positions]

    def get_example(self, i):
        """Return the example at position ``i``.

        Concrete datasets must override this method.

        Args:
            i (int): example index.

        Raises:
            NotImplementedError: always, in this base implementation.
        """
        raise NotImplementedError

    def __len__(self):
        """Return the number of examples; must be overridden.

        Raises:
            NotImplementedError: always, in this base implementation.
        """
        raise NotImplementedError

    def __iter__(self):
        """Yield the examples at positions ``0 .. len(self) - 1`` in order."""
        for position in range(len(self)):
            yield self.get_example(position)
class TransformDataset(DatasetMixin):
    """View of a base dataset with a transform applied to each example."""

    def __init__(self, dataset, transform):
        """Wrap ``dataset`` so that ``transform`` is applied on access.

        Args:
            dataset (DatasetMixin): the base dataset.
            transform (callable): called with one example of the base
                dataset; its return value becomes the new example.
        """
        self._transform = transform
        self._dataset = dataset

    def __len__(self):
        """Return the length of the base dataset."""
        return len(self._dataset)

    def get_example(self, i):
        """Return ``transform(dataset[i])``; nothing is stored."""
        return self._transform(self._dataset[i])
2020-03-06 10:55:42 +08:00
class CacheDataset(DatasetMixin):
    """Lazy, index-keyed cache in front of a base dataset."""

    def __init__(self, dataset):
        """Create an empty cache over ``dataset``.

        Args:
            dataset (DatasetMixin): the base dataset to cache.
        """
        self._cache = {}
        self._dataset = dataset

    def __len__(self):
        """Return the length of the base dataset."""
        return len(self._dataset)

    def get_example(self, i):
        """Return example ``i``, reading the base dataset only on a miss.

        The cache is keyed by ``i`` itself and entries are never evicted.
        """
        cache = self._cache
        if i not in cache:
            # First access: fetch from the base dataset and remember it.
            cache[i] = self._dataset[i]
        return cache[i]
2020-02-06 15:40:04 +08:00
class TupleDataset(object):
    """Compound dataset whose examples are tuples of constituent examples."""

    def __init__(self, *datasets):
        """Combine several datasets of the same length.

        An example of the ``TupleDataset`` is a tuple holding one example
        from each constituent dataset, in the order they were passed.

        Args:
            *datasets: the constituent datasets; each must support
                ``len()`` and indexing.

        Raises:
            ValueError: if no dataset is given, or if the datasets do not
                all have the same length.
        """
        if not datasets:
            raise ValueError(" no datasets are given ")
        length = len(datasets[0])
        for i, dataset in enumerate(datasets):
            # Compare the length of each constituent (not the number of
            # constituents) against the length of the first one.
            if len(dataset) != length:
                raise ValueError(
                    " all the datasets should have the same length. "
                    " dataset {} has a different length ".format(i))
        self._datasets = datasets
        self._length = length

    def __getitem__(self, index):
        """Return the example(s) selected by ``index``.

        Args:
            index: forwarded unchanged to every constituent dataset.

        Returns:
            For a slice, a list of tuples (one tuple per selected
            position). For any other index, a tuple holding whatever each
            constituent returned for ``index``.
        """
        # Structure of arrays: one batch per constituent dataset.
        batches = [dataset[index] for dataset in self._datasets]
        if isinstance(index, slice):
            length = len(batches[0])
            # Array of structures: regroup the batches position by position.
            return [
                tuple([batch[i] for batch in batches])
                for i in range(length)
            ]
        return tuple(batches)

    def __len__(self):
        """Return the common length of the constituent datasets."""
        return self._length
class DictDataset(object):
    """Compound dataset whose examples are dicts of constituent examples."""

    def __init__(self, **datasets):
        """Combine several named datasets of the same length.

        An example of the ``DictDataset`` is a dict mapping each keyword
        to the corresponding example of that constituent dataset.

        Args:
            **datasets: the constituent datasets, keyed by name; each must
                support ``len()`` and indexing.

        Raises:
            ValueError: if no dataset is given, or if the datasets do not
                all have the same length.
        """
        if not datasets:
            raise ValueError(" no datasets are given ")
        length = None
        for key, dataset in datasets.items():
            if length is None:
                # The first dataset seen defines the reference length.
                length = len(dataset)
            elif len(dataset) != length:
                # Compare the length of this constituent (not the number
                # of constituents) against the reference length.
                raise ValueError(
                    " all the datasets should have the same length. "
                    " dataset {} has a different length ".format(key))
        self._datasets = datasets
        self._length = length

    def __getitem__(self, index):
        """Return the example(s) selected by ``index``.

        Args:
            index: forwarded unchanged to every constituent dataset.

        Returns:
            For a slice, a list of dicts (one dict per selected position).
            For any other index, a dict mapping each key to whatever that
            constituent returned for ``index``.
        """
        batches = {
            key: dataset[index]
            for key, dataset in self._datasets.items()
        }
        if isinstance(index, slice):
            # Any batch gives the slice length.
            length = len(next(iter(batches.values())))
            return [
                {key: batch[i] for key, batch in batches.items()}
                for i in range(length)
            ]
        return batches

    def __len__(self):
        """Return the common length of the constituent datasets."""
        return self._length
class SliceDataset(DatasetMixin):
    """Contiguous slice ``[start, finish)`` of a base dataset."""

    def __init__(self, dataset, start, finish, order=None):
        """Create a slice view of ``dataset``.

        Args:
            dataset (DatasetMixin): the base dataset.
            start (int): the start of the slice.
            finish (int): the end of the slice, not inclusive.
            order (List[int], optional): a sequence with the same length
                as ``dataset``. When given, position ``k`` of the slice
                maps to ``dataset[order[k]]`` instead of ``dataset[k]``.
                Defaults to None.

        Raises:
            ValueError: if ``start`` is negative, ``finish`` exceeds
                ``len(dataset)``, ``start`` is greater than ``finish``, or
                ``order`` has a different length than ``dataset``.
        """
        if start < 0 or finish > len(dataset):
            raise ValueError(" subset overruns the dataset. ")
        if start > finish:
            # A reversed range would give a negative size, which makes
            # len() fail and every index be rejected later on.
            raise ValueError(
                "start ({}) should not be greater than finish ({}).".format(
                    start, finish))
        self._dataset = dataset
        self._start = start
        self._finish = finish
        self._size = finish - start

        if order is not None and len(order) != len(dataset):
            raise ValueError(
                " order should have the same length as the dataset "
                " len(order) = {} which does not euqals len(dataset) = {} ".
                format(len(order), len(dataset)))
        self._order = order

    def __len__(self):
        """Return the number of examples in the slice."""
        return self._size

    def get_example(self, i):
        """Return example ``i`` of the slice.

        Args:
            i (int): position within the slice; negative values count
                from the end of the slice.

        Raises:
            IndexError: if ``i`` is outside ``[-len(self), len(self))``.
        """
        if i >= 0:
            if i >= self._size:
                raise IndexError(' dataset index out of range ')
            index = self._start + i
        else:
            if i < -self._size:
                raise IndexError(' dataset index out of range ')
            # Negative positions are offsets from the end of the slice.
            index = self._finish + i

        if self._order is not None:
            index = self._order[index]
        return self._dataset[index]
class SubsetDataset(DatasetMixin):
    """Subset of a base dataset selected by explicit indices."""

    def __init__(self, dataset, indices):
        """Create a subset view of ``dataset``.

        Args:
            dataset (DatasetMixin): the base dataset.
            indices: the indices of the examples to pick. ``len()`` is
                called on it and it is indexed by position, so it has to
                be a sequence rather than a one-shot iterable.

        Raises:
            ValueError: if there are more indices than examples in
                ``dataset``.
        """
        self._dataset = dataset
        subset_size = len(indices)
        if subset_size > len(dataset):
            raise ValueError(" subset ' s size larger that dataset ' s size! ")
        self._indices = indices
        self._size = subset_size

    def __len__(self):
        """Return the number of selected examples."""
        return self._size

    def get_example(self, i):
        """Return the base example stored at ``indices[i]``."""
        return self._dataset[self._indices[i]]
class FilterDataset(DatasetMixin):
    """Examples of a base dataset that satisfy a predicate."""

    def __init__(self, dataset, filter_fn):
        """Evaluate ``filter_fn`` on every example and keep the matches.

        The whole base dataset is scanned once here; only the positions
        of the accepted examples are stored.

        Args:
            dataset (DatasetMixin): the base dataset.
            filter_fn (callable): takes an example of the base dataset;
                the example is kept when the result is truthy.
        """
        self._dataset = dataset
        kept = []
        for position in range(len(dataset)):
            if filter_fn(dataset[position]):
                kept.append(position)
        self._indices = kept
        self._size = len(kept)

    def __len__(self):
        """Return the number of examples that passed the filter."""
        return self._size

    def get_example(self, i):
        """Return the ``i``-th accepted example from the base dataset."""
        return self._dataset[self._indices[i]]
class ChainDataset(DatasetMixin):
    """Concatenation of several datasets, indexed as one."""

    def __init__(self, *datasets):
        """Chain ``datasets`` one after another.

        Args:
            *datasets: the datasets to concatenate, in order.
        """
        self._datasets = datasets

    def __len__(self):
        """Return the sum of the lengths of the chained datasets."""
        return sum(map(len, self._datasets))

    def get_example(self, i):
        """Return example ``i`` of the concatenation.

        Args:
            i (int): non-negative position in the chained dataset.

        Raises:
            IndexError: if ``i`` is negative, or not smaller than the
                total length.
        """
        if i < 0:
            raise IndexError(" ChainDataset doesnot support negative indexing. ")

        offset = i
        for member in self._datasets:
            size = len(member)
            if offset < size:
                return member[offset]
            # Skip this member and make the offset relative to the next one.
            offset -= size
        raise IndexError(" dataset index out of range ")