From 837749a32c71672f0ba09b67ff44653ce6379605 Mon Sep 17 00:00:00 2001 From: chenfeiyu Date: Thu, 6 Feb 2020 15:40:04 +0800 Subject: [PATCH] update statset and datacargo's design --- parakeet/__init__.py | 2 + parakeet/data/datacargo.py | 45 +++++--- parakeet/data/dataset.py | 209 +++++++++++++++++++++++++++++++++---- 3 files changed, 220 insertions(+), 36 deletions(-) diff --git a/parakeet/__init__.py b/parakeet/__init__.py index 6c8e6b9..328cdce 100644 --- a/parakeet/__init__.py +++ b/parakeet/__init__.py @@ -1 +1,3 @@ __version__ = "0.0.0" + +from . import data, g2p, models, modules, utils diff --git a/parakeet/data/datacargo.py b/parakeet/data/datacargo.py index 1d7d8d5..2685bcc 100644 --- a/parakeet/data/datacargo.py +++ b/parakeet/data/datacargo.py @@ -1,10 +1,18 @@ from .sampler import SequentialSampler, RandomSampler, BatchSampler + class DataCargo(object): - def __init__(self, dataset, batch_size=1, sampler=None, - shuffle=False, batch_sampler=None, drop_last=False): + def __init__(self, + dataset, + batch_fn=None, + batch_size=1, + sampler=None, + shuffle=False, + batch_sampler=None, + drop_last=False): self.dataset = dataset - + self.batch_fn = batch_fn or self.dataset._batch_examples + if batch_sampler is not None: # auto_collation with custom batch_sampler if batch_size != 1 or shuffle or sampler is not None or drop_last: @@ -15,7 +23,8 @@ class DataCargo(object): drop_last = False shuffle = False elif batch_size is None: - raise ValueError('batch sampler is none. then batch size must not be none.') + raise ValueError( + 'batch sampler is none. then batch size must not be none.') elif sampler is None: if shuffle: sampler = RandomSampler(dataset) @@ -23,18 +32,20 @@ class DataCargo(object): sampler = SequentialSampler(dataset) # auto_collation without custom batch_sampler batch_sampler = BatchSampler(sampler, batch_size, drop_last) + else: + batch_sampler = BatchSampler(sampler, batch_size, drop_last) self.batch_size = batch_size self.drop_last = drop_last self.sampler = sampler self.batch_sampler = batch_sampler - + def __iter__(self): return DataIterator(self) def __call__(self): return DataIterator(self) - + @property def _auto_collation(self): # we will auto batching @@ -49,26 +60,30 @@ class DataCargo(object): def __len__(self): return len(self._index_sampler) - + + class DataIterator(object): def __init__(self, loader): self.loader = loader self._dataset = loader.dataset - + + self._batch_fn = loader.batch_fn self._index_sampler = loader._index_sampler self._sampler_iter = iter(self._index_sampler) - + def __iter__(self): return self - + def __next__(self): - index = self._next_index() # may raise StopIteration, TODO(chenfeiyu): use dynamic batch size - minibatch = [self._dataset[i] for i in index] # we can abstract it, too to use dynamic batch size - minibatch = self._dataset._batch_examples(minibatch) # list[Example] -> Batch + index = self._next_index( + ) # may raise StopIteration, TODO(chenfeiyu): use dynamic batch size + minibatch = [self._dataset[i] for i in index + ] # we can abstract it, too to use dynamic batch size + minibatch = self._batch_fn(minibatch) # list[Example] -> Batch return minibatch - + def _next_index(self): return next(self._sampler_iter) - + def __len__(self): return len(self._index_sampler) diff --git a/parakeet/data/dataset.py b/parakeet/data/dataset.py index cfec912..d9f9a1f 100644 --- a/parakeet/data/dataset.py +++ b/parakeet/data/dataset.py @@ -1,24 +1,191 @@ -class Dataset(object): - def __init__(self): - pass - - def _load_metadata(self): - raise NotImplementedError - - def _get_example(self): - """return a Record (or Example, Instance according to your glossary)""" - raise NotImplementedError - - def _batch_examples(self, minibatch): - """get a list of examples, return a batch, whose structure is the same as an example""" - raise NotImplementedError - - def _prepare_metadata(self): - raise NotImplementedError - +import six +import numpy as np + + +class DatasetMixin(object): + """standard indexing interface for dataset.""" + def __getitem__(self, index): - raise NotImplementedError - - def __iter__(self): + if isinstance(index, slice): + start, stop, step = index.indices(len(self)) + return [ + self.get_example(i) + for i in six.moves.range(start, stop, step) + ] + elif isinstance(index, (list, np.ndarray)): + return [self.get_example(i) for i in index] + else: + # assumes it an integer + return self.get_example(index) + + def get_example(self, i): raise NotImplementedError + def __len__(self): + raise NotImplementedError + + def __iter__(self): + for i in range(len(self)): + yield self.get_example(i) + + +class TransformDataset(DatasetMixin): + """Transform a dataset to another with a transform.""" + + def __init__(self, dataset, transform): + self._dataset = dataset + self._transform = transform + + def __len__(self): + return len(self._dataset) + + def get_example(self, i): + # CAUTION: only int is supported? + # CAUTION: dataset support support __getitem__ and __len__ + in_data = self._dataset[i] + return self._transform(in_data) + + +class TupleDataset(object): + def __init__(self, *datasets): + if not datasets: + raise ValueError("no datasets are given") + length = len(datasets[0]) + for i, dataset in enumerate(datasets): + if len(datasets) != length: + raise ValueError( + "all the datasets should have the same length." + "dataset {} has a different length".format(i)) + self._datasets = datasets + self._length = length + + def __getitem__(self, index): + # SOA + batches = [dataset[index] for dataset in self._datasets] + if isinstance(index, slice): + length = len(batches[0]) + # AOS + return [ + tuple([batch[i] for batch in batches]) + for i in six.moves.range(length) + ] + else: + return tuple(batches) + + def __len__(self): + return self._length + + +class DictDataset(object): + def __init__(self, **datasets): + if not datasets: + raise ValueError("no datasets are given") + length = None + for key, dataset in six.iteritems(datasets): + if length is None: + length = len(dataset) + elif len(datasets) != length: + raise ValueError( + "all the datasets should have the same length." + "dataset {} has a different length".format(key)) + self._datasets = datasets + self._length = length + + def __getitem__(self, index): + batches = { + key: dataset[index] + for key, dataset in six.iteritems(self._datasets) + } + if isinstance(index, slice): + length = len(six.next(six.itervalues(batches))) + return [{key: batch[i] + for key, batch in six.iteritems(batches)} + for i in six.moves.range(length)] + else: + return batches + + +class SliceDataset(DatasetMixin): + def __init__(self, dataset, start, finish, order=None): + if start < 0 or finish > len(dataset): + raise ValueError("subset overruns the dataset.") + self._dataset = dataset + self._start = start + self._finish = finish + self._size = finish - start + + if order is not None and len(order) != len(dataset): + raise ValueError( + "order should have the same length as the dataset" + "len(order) = {} which does not euqals len(dataset) = {} ". + format(len(order), len(dataset))) + self._order = order + + def len(self): + return self._size + + def get_example(self, i): + if i >= 0: + if i >= self._size: + raise IndexError('dataset index out of range') + index = self._start + i + else: + if i < -self._size: + raise IndexError('dataset index out of range') + index = self._finish + i + + if self._order is not None: + index = self._order[index] + return self._dataset[index] + + +class SubsetDataset(DatasetMixin): + def __init__(self, dataset, indices): + self._dataset = dataset + if len(indices) > len(dataset): + raise ValueError("subset's size larger that dataset's size!") + self._indices = indices + self._size = len(indices) + + def __len__(self): + return self._size + + def get_example(self, i): + index = self._indices[i] + return self._dataset[index] + + +class FilterDataset(DatasetMixin): + def __init__(self, dataset, filter_fn): + self._dataset = dataset + self._indices = [ + i for i in range(len(dataset)) if filter_fn(dataset[i]) + ] + self._size = len(self._indices) + + def __len__(self): + return self._size + + def get_example(self, i): + index = self._indices[i] + return self._dataset[index] + + +class ChainDataset(DatasetMixin): + def __init__(self, *datasets): + self._datasets = datasets + + def __len__(self): + return sum(len(dataset) for dataset in self._datasets) + + def get_example(self, i): + if i < 0: + raise IndexError( + "ChainDataset doesnot support negative indexing.") + + for dataset in self._datasets: + if i < len(dataset): + return dataset[i] + i -= len(dataset) + + raise IndexError("dataset index out of range")