From 94b95c7b2c8bec6b14d9fe90d6d914b42cecf096 Mon Sep 17 00:00:00 2001
From: chenfeiyu
Date: Wed, 20 Nov 2019 20:18:52 +0800
Subject: [PATCH] add dataset prototype and an example for ljspeech

---
 data/collate.py    |  71 +++++++++++++++
 data/dataloader.py |  83 ++++++++++++++++++
 data/dataset.py    |  25 ++++++
 data/ljspeech.py   | 105 +++++++++++++++++++++++
 data/sampler.py    | 209 +++++++++++++++++++++++++++++++++++++++++++++
 5 files changed, 493 insertions(+)
 create mode 100644 data/collate.py
 create mode 100644 data/dataloader.py
 create mode 100644 data/dataset.py
 create mode 100644 data/ljspeech.py
 create mode 100644 data/sampler.py

diff --git a/data/collate.py b/data/collate.py
new file mode 100644
index 0000000..bdd582c
--- /dev/null
+++ b/data/collate.py
@@ -0,0 +1,71 @@
+"""
+Functions to build a batch from a list of arrays that follow certain shape conventions.
+"""
+import numpy as np
+
+
+def text_collate(minibatch):
+    """
+    minibatch: List[Example]
+    Example: ndarray, shape(T,), dtype: int64
+    """
+    peek_example = minibatch[0]
+    assert len(peek_example.shape) == 1, "text example is a 1D tensor"
+
+    lengths = [example.shape[0] for example in minibatch]  # assume (T,)
+    max_len = np.max(lengths)
+
+    batch = []
+    for example in minibatch:
+        pad_len = max_len - example.shape[0]
+        batch.append(np.pad(example, [(0, pad_len)], mode='constant', constant_values=0))
+
+    return np.array(batch, dtype=np.int64)
+
+
+def wav_collate(minibatch):
+    """
+    minibatch: List[Example]
+    Example: ndarray, shape(C, T) for multi-channel wav, shape(T,) for mono-channel wav, dtype: float32
+    """
+    peek_example = minibatch[0]
+    if len(peek_example.shape) == 1:
+        mono_channel = True
+    elif len(peek_example.shape) == 2:
+        mono_channel = False
+    else:
+        raise ValueError("wav example should be a 1D or 2D tensor")
+
+    lengths = [example.shape[-1] for example in minibatch]  # assume (channel, n_samples) or (n_samples,)
+    max_len = np.max(lengths)
+
+    batch = []
+    for example in minibatch:
+        pad_len = max_len - example.shape[-1]
+        if mono_channel:
+            batch.append(np.pad(example, [(0, pad_len)], mode='constant', constant_values=0.))
+        else:
+            # NOTE: zero padding assumes float waveforms; integer PCM may need different handling
+            batch.append(np.pad(example, [(0, 0), (0, pad_len)], mode='constant', constant_values=0.))
+
+    return np.array(batch, dtype=np.float32)
+
+
+def spec_collate(minibatch):
+    """
+    minibatch: List[Example]
+    Example: ndarray, shape(C, F, T) for multi-channel spectrogram, shape(F, T) for mono-channel spectrogram, dtype: float32
+    """
+    peek_example = minibatch[0]
+    if len(peek_example.shape) == 2:
+        mono_channel = True
+    elif len(peek_example.shape) == 3:
+        mono_channel = False
+    else:
+        raise ValueError("spectrogram example should be a 2D or 3D tensor")
+
+    lengths = [example.shape[-1] for example in minibatch]  # assume (channel, F, n_frame) or (F, n_frame)
+    max_len = np.max(lengths)
+
+    batch = []
+    for example in minibatch:
+        pad_len = max_len - example.shape[-1]
+        if mono_channel:
+            batch.append(np.pad(example, [(0, 0), (0, pad_len)], mode='constant', constant_values=0.))
+        else:
+            batch.append(np.pad(example, [(0, 0), (0, 0), (0, pad_len)], mode='constant', constant_values=0.))
+
+    return np.array(batch, dtype=np.float32)
\ No newline at end of file
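A quick sketch of how these collate functions are meant to be used (illustrative only, not part of the patch): each one pads ragged examples to the longest length in the minibatch along the trailing (time) axis.

# Illustrative usage of the collate functions above (not part of the patch).
import numpy as np
from collate import text_collate, spec_collate

texts = [np.array([3, 1, 4], dtype=np.int64), np.array([1, 5], dtype=np.int64)]
print(text_collate(texts).shape)   # (2, 3); the second row is zero-padded

specs = [np.random.rand(80, 120).astype(np.float32),
         np.random.rand(80, 97).astype(np.float32)]
print(spec_collate(specs).shape)   # (2, 80, 120); padded along the time axis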
diff --git a/data/dataloader.py b/data/dataloader.py
new file mode 100644
index 0000000..231322c
--- /dev/null
+++ b/data/dataloader.py
@@ -0,0 +1,83 @@
+from sampler import SequentialSampler, RandomSampler, BatchSampler
+
+
+class DataLoader(object):
+    def __init__(self, dataset, batch_size=1, collate_fn=lambda x: x,
+                 sampler=None, shuffle=False, batch_sampler=None, drop_last=False):
+        self.dataset = dataset
+        self.collate_fn = collate_fn
+
+        if batch_sampler is not None:
+            # auto_collation with custom batch_sampler
+            if batch_size != 1 or shuffle or sampler is not None or drop_last:
+                raise ValueError('batch_sampler option is mutually exclusive '
+                                 'with batch_size, shuffle, sampler, and '
+                                 'drop_last')
+            batch_size = None
+            drop_last = False
+        elif batch_size is None:
+            # no auto_collation
+            if shuffle or drop_last:
+                raise ValueError('batch_size=None option disables auto-batching '
+                                 'and is mutually exclusive with '
+                                 'shuffle and drop_last')
+
+        if sampler is None:  # give default samplers
+            if shuffle:
+                sampler = RandomSampler(dataset)
+            else:
+                sampler = SequentialSampler(dataset)
+
+        if batch_size is not None and batch_sampler is None:
+            # auto_collation without custom batch_sampler
+            batch_sampler = BatchSampler(sampler, batch_size, drop_last)
+
+        self.batch_size = batch_size
+        self.drop_last = drop_last
+        self.sampler = sampler
+        self.batch_sampler = batch_sampler
+
+    def __iter__(self):
+        return DataIterator(self)
+
+    @property
+    def _auto_collation(self):
+        # whether auto-batching is enabled
+        return self.batch_sampler is not None
+
+    @property
+    def _index_sampler(self):
+        # The actual sampler used for generating the indices to read data with
+        # at each step. This is `.batch_sampler` in auto-collation mode, and
+        # `.sampler` otherwise. We can't change the `.sampler` and
+        # `.batch_sampler` attributes for BC reasons.
+        if self._auto_collation:
+            return self.batch_sampler
+        else:
+            return self.sampler
+
+    def __len__(self):
+        return len(self._index_sampler)  # with an iterable-style dataset, this will error
+
+
+class DataIterator(object):
+    def __init__(self, loader):
+        self.loader = loader
+        self._dataset = loader.dataset
+
+        self._index_sampler = loader._index_sampler
+        self._sampler_iter = iter(self._index_sampler)
+
+    def __iter__(self):
+        return self
+
+    def __next__(self):
+        index = self._next_index()  # may raise StopIteration, TODO(chenfeiyu): use dynamic batch size
+        minibatch = [self._dataset[i] for i in index]  # we could abstract this too, to support dynamic batch size
+        minibatch = self.loader.collate_fn(minibatch)  # List[Example] -> Batch
+        return minibatch
+
+    def _next_index(self):
+        return next(self._sampler_iter)
+
+    def __len__(self):
+        return len(self._index_sampler)
diff --git a/data/dataset.py b/data/dataset.py
new file mode 100644
index 0000000..0c6bb34
--- /dev/null
+++ b/data/dataset.py
@@ -0,0 +1,25 @@
+class Dataset(object):
+    def __init__(self, lazy=True, stream=False):
+        # note that lazy and stream mean two different things in our glossary:
+        # lazy means preprocessing happens in __getitem__;
+        # stream means the data source is itself a stream
+        self.lazy = lazy
+        self.stream = stream
+
+    def _load_metadata(self):
+        raise NotImplementedError
+
+    def _get_example(self):
+        """Return a Record."""
+        raise NotImplementedError
+
+    def _prepare_metadata(self):
+        raise NotImplementedError
+
+    def __getitem__(self, index):
+        raise NotImplementedError
+
+    def __iter__(self):
+        raise NotImplementedError
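A minimal sketch of the intended usage (illustrative, not part of the patch): a map-style Dataset subclass driven by the DataLoader above. RangeDataset is a hypothetical toy dataset made up for this example.

# Hypothetical toy dataset: each example is a 1D array [i, i, i].
import numpy as np
from dataset import Dataset
from dataloader import DataLoader

class RangeDataset(Dataset):
    def __init__(self, n):
        super(RangeDataset, self).__init__(lazy=True, stream=False)
        self.n = n

    def __getitem__(self, index):
        return np.full((3,), index, dtype=np.int64)

    def __len__(self):
        return self.n

loader = DataLoader(RangeDataset(10), batch_size=4, shuffle=True,
                    collate_fn=lambda batch: np.stack(batch))
for batch in loader:
    print(batch.shape)  # (4, 3) for full batches, (2, 3) for the last one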
diff --git a/data/ljspeech.py b/data/ljspeech.py
new file mode 100644
index 0000000..b2eb56e
--- /dev/null
+++ b/data/ljspeech.py
@@ -0,0 +1,105 @@
+from pathlib import Path
+import numpy as np
+import pandas as pd
+import librosa
+import g2p
+
+from dataset import Dataset
+from dataloader import DataLoader
+from collate import text_collate, spec_collate
+
+LJSPEECH_ROOT = Path("/Users/chenfeiyu/projects/LJSpeech-1.1")
+
+
+class LJSpeech(Dataset):
+    def __init__(self, root=LJSPEECH_ROOT, lazy=True, stream=False):
+        super(LJSpeech, self).__init__(lazy, stream)
+        self.root = root
+        self.metadata = self._prepare_metadata()  # loading the metadata eagerly is cheap
+
+        if self.stream:
+            self.examples_generator = self._read()
+
+    def _prepare_metadata(self):
+        # in the pure-stream case, _prepare_metadata would return a generator instead
+        csv_path = self.root.joinpath("metadata.csv")
+        metadata = pd.read_csv(csv_path, sep="|", header=None, quoting=3,
+                               names=["fname", "raw_text", "normalized_text"])
+        return metadata
+
+    def _read(self):
+        for _, metadatum in self.metadata.iterrows():
+            example = self._get_example(metadatum)
+            yield example
+
+    def _get_example(self, metadatum):
+        """All the code for generating an Example from a metadatum. If you want a
+        different preprocessing pipeline, you can override this method.
+        This method may require several processors, each of which has many options.
+        In that case, it is better to compose them into a single transform and pass
+        that to the constructor.
+        """
+        fname, raw_text, normalized_text = metadatum
+        wav_path = self.root.joinpath("wavs", fname + ".wav")
+
+        # load -> trim -> preemphasis -> stft -> magnitude -> mel_scale -> logscale -> normalize
+        wav, sample_rate = librosa.load(wav_path, sr=None)  # a functor holding these parameters would be preferable
+        trimmed, _ = librosa.effects.trim(wav)
+        preemphasized = librosa.effects.preemphasis(trimmed)
+        D = librosa.stft(preemphasized)
+        mag, phase = librosa.magphase(D)
+        mel = librosa.feature.melspectrogram(S=mag)
+
+        mag = librosa.amplitude_to_db(S=mag)
+        mel = librosa.amplitude_to_db(S=mel)
+
+        ref_db = 20
+        max_db = 100
+        mel = np.clip((mel - ref_db + max_db) / max_db, 1e-8, 1)
+        mag = np.clip((mag - ref_db + max_db) / max_db, 1e-8, 1)
+
+        phonemes = np.array(g2p.en.text_to_sequence(normalized_text), dtype=np.int64)
+        return (mag, mel, phonemes)  # maybe we should implement Example as a map in the future
+
+    def __getitem__(self, index):
+        if self.stream:
+            raise ValueError("__getitem__ is invalid in stream mode")
+        metadatum = self.metadata.iloc[index]
+        example = self._get_example(metadatum)
+        return example
+
+    def __iter__(self):
+        if self.stream:
+            for example in self.examples_generator:
+                yield example
+        else:
+            for i in range(len(self)):
+                yield self[i]
+
+    def __len__(self):
+        if self.stream:
+            raise ValueError("__len__ is invalid in stream mode")
+        return len(self.metadata)
+
+
+def fn(minibatch):
+    mag_batch = []
+    mel_batch = []
+    phoneme_batch = []
+    for example in minibatch:
+        mag, mel, phoneme = example
+        mag_batch.append(mag)
+        mel_batch.append(mel)
+        phoneme_batch.append(phoneme)
+    mag_batch = spec_collate(mag_batch)
+    mel_batch = spec_collate(mel_batch)
+    phoneme_batch = text_collate(phoneme_batch)
+    return (mag_batch, mel_batch, phoneme_batch)
+
+
+if __name__ == "__main__":
+    ljspeech = LJSpeech(LJSPEECH_ROOT)
+    ljspeech_loader = DataLoader(ljspeech, batch_size=16, shuffle=True, collate_fn=fn)
+    for i, batch in enumerate(ljspeech_loader):
+        print(i)
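The dB normalization in _get_example maps log magnitudes into [0, 1]. A small worked check (illustrative; ref_db and max_db as in the patch):

# Worked check of the dB normalization used in _get_example.
import numpy as np

ref_db, max_db = 20, 100
db = np.array([-80.0, 0.0, 20.0])  # raw values from amplitude_to_db
normalized = np.clip((db - ref_db + max_db) / max_db, 1e-8, 1)
print(normalized)                  # [1e-08 0.8 1.0]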
diff --git a/data/sampler.py b/data/sampler.py
new file mode 100644
index 0000000..ff6d5d7
--- /dev/null
+++ b/data/sampler.py
@@ -0,0 +1,209 @@
+"""
+In most cases we have a non-stream dataset, which means we can randomly access
+it with __getitem__ and get its length with __len__.
+
+This suffices for a sampler. We implement a sampler as an iterable of valid
+indices. By valid, we mean 0 <= index < N, where N is the length of the
+dataset. We then collect several indices into a batch, use them to fetch
+examples from the dataset with __getitem__, and collate these examples to
+form a batch.
+
+So the sampler is only responsible for generating valid indices.
+"""
+import random
+
+import numpy as np
+
+
+class Sampler(object):
+    def __init__(self, data_source):
+        pass
+
+    def __iter__(self):
+        # return an iterator of indices,
+        # or an iterator of List[int] in the case of BatchSampler
+        raise NotImplementedError
+
+
+class SequentialSampler(Sampler):
+    def __init__(self, data_source):
+        self.data_source = data_source
+
+    def __iter__(self):
+        return iter(range(len(self.data_source)))
+
+    def __len__(self):
+        return len(self.data_source)
+
+
+class RandomSampler(Sampler):
+    def __init__(self, data_source, replacement=False, num_samples=None):
+        self.data_source = data_source
+        self.replacement = replacement
+        self._num_samples = num_samples
+
+        if not isinstance(self.replacement, bool):
+            raise ValueError("replacement should be a boolean value, but got "
+                             "replacement={}".format(self.replacement))
+
+        if self._num_samples is not None and not replacement:
+            raise ValueError("With replacement=False, num_samples should not be specified, "
+                             "since a random permutation will be performed.")
+
+        if not isinstance(self.num_samples, int) or self.num_samples <= 0:
+            raise ValueError("num_samples should be a positive integer "
+                             "value, but got num_samples={}".format(self.num_samples))
+
+    @property
+    def num_samples(self):
+        # dataset size might change at runtime
+        if self._num_samples is None:
+            return len(self.data_source)
+        return self._num_samples
+
+    def __iter__(self):
+        n = len(self.data_source)
+        if self.replacement:
+            return iter(np.random.randint(0, n, size=(self.num_samples,), dtype=np.int64).tolist())
+        return iter(np.random.permutation(n).tolist())
+
+    def __len__(self):
+        return self.num_samples
+
+
+class SubsetRandomSampler(Sampler):
+    r"""Samples elements randomly from a given list of indices, without replacement.
+    Arguments:
+        indices (sequence): a sequence of indices
+    """
+
+    def __init__(self, indices):
+        self.indices = indices
+
+    def __iter__(self):
+        return (self.indices[i] for i in np.random.permutation(len(self.indices)))
+
+    def __len__(self):
+        return len(self.indices)
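An illustrative use of SubsetRandomSampler for a train/validation split (not part of the patch; the split sizes are arbitrary):

# Train/validation split via SubsetRandomSampler.
import numpy as np
from sampler import SubsetRandomSampler

indices = np.random.permutation(100)
train_sampler = SubsetRandomSampler(indices[:90].tolist())
valid_sampler = SubsetRandomSampler(indices[90:].tolist())
print(len(train_sampler), len(valid_sampler))  # 90 10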
+class PartialyRandomizedSimilarTimeLengthSampler(Sampler):
+    """Partially randomized sampler, implemented as an index-level sampler:
+    1. Sort examples by length.
+    2. Pick a small patch and randomize it.
+    3. Permute mini-batches.
+    """
+
+    def __init__(self, lengths, batch_size=4, batch_group_size=None,
+                 permutate=True):
+        _lengths = np.array(lengths, dtype=np.int64)  # it might be better to implement length as a sort key
+        self.lengths = np.sort(_lengths)
+        self.sorted_indices = np.argsort(_lengths)
+
+        self.batch_size = batch_size
+        if batch_group_size is None:
+            batch_group_size = min(batch_size * 32, len(self.lengths))
+            if batch_group_size % batch_size != 0:
+                batch_group_size -= batch_group_size % batch_size
+
+        self.batch_group_size = batch_group_size
+        assert batch_group_size % batch_size == 0
+        self.permutate = permutate
+
+    def __iter__(self):
+        indices = np.copy(self.sorted_indices)
+        batch_group_size = self.batch_group_size
+        s, e = 0, 0
+        for i in range(len(indices) // batch_group_size):
+            s = i * batch_group_size
+            e = s + batch_group_size
+            random.shuffle(indices[s:e])  # shuffles the view in place
+
+        # permute batches
+        if self.permutate:
+            perm = np.arange(len(indices[:e]) // self.batch_size)
+            random.shuffle(perm)
+            indices[:e] = indices[:e].reshape(-1, self.batch_size)[perm, :].reshape(-1)
+
+        # handle the last elements
+        s += batch_group_size
+        if s < len(indices):
+            random.shuffle(indices[s:])
+
+        return iter(indices)
+
+    def __len__(self):
+        return len(self.sorted_indices)
+
+
+class WeightedRandomSampler(Sampler):
+    r"""Samples elements from ``[0,..,len(weights)-1]`` with given probabilities (weights).
+    Args:
+        weights (sequence): a sequence of weights, not necessarily summing up to one
+        num_samples (int): number of samples to draw
+        replacement (bool): if ``True``, samples are drawn with replacement.
+            If not, they are drawn without replacement, which means that when a
+            sample index is drawn for a row, it cannot be drawn again for that row.
+    Example:
+        >>> list(WeightedRandomSampler([0.1, 0.9, 0.4, 0.7, 3.0, 0.6], 5, replacement=True))
+        [0, 0, 0, 1, 0]
+        >>> list(WeightedRandomSampler([0.9, 0.4, 0.05, 0.2, 0.3, 0.1], 5, replacement=False))
+        [0, 1, 4, 3, 2]
+    """
+
+    def __init__(self, weights, num_samples, replacement):
+        if not isinstance(num_samples, int) or num_samples <= 0:
+            raise ValueError("num_samples should be a positive integer "
+                             "value, but got num_samples={}".format(num_samples))
+        self.weights = np.array(weights, dtype=np.float64)
+        self.num_samples = num_samples
+        self.replacement = replacement
+
+    def __iter__(self):
+        # np.random.choice requires probabilities that sum to one,
+        # so normalize the weights here
+        p = self.weights / np.sum(self.weights)
+        return iter(np.random.choice(len(self.weights), size=(self.num_samples,),
+                                     replace=self.replacement, p=p).tolist())
+
+    def __len__(self):
+        return self.num_samples
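An illustrative check of PartialyRandomizedSimilarTimeLengthSampler (not part of the patch): because indices are sorted by length before the local shuffle, consecutive indices, and hence each mini-batch, group examples of similar length.

# Peek at the length distribution inside a few mini-batches.
import numpy as np
from sampler import PartialyRandomizedSimilarTimeLengthSampler

lengths = np.random.randint(50, 500, size=64).tolist()  # fake utterance lengths
sampler = PartialyRandomizedSimilarTimeLengthSampler(lengths, batch_size=4)
indices = list(sampler)
for b in range(4):
    batch = indices[b * 4:(b + 1) * 4]
    print([lengths[i] for i in batch])  # lengths within a row are similar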
+class BatchSampler(Sampler):
+    r"""Wraps another sampler to yield a mini-batch of indices.
+    Args:
+        sampler (Sampler): Base sampler.
+        batch_size (int): Size of mini-batch.
+        drop_last (bool): If ``True``, the sampler will drop the last batch if
+            its size would be less than ``batch_size``.
+    Example:
+        >>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=False))
+        [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
+        >>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=True))
+        [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
+    """
+
+    def __init__(self, sampler, batch_size, drop_last):
+        if not isinstance(sampler, Sampler):
+            raise ValueError("sampler should be an instance of "
+                             "Sampler, but got sampler={}"
+                             .format(sampler))
+        if not isinstance(batch_size, int) or batch_size <= 0:
+            raise ValueError("batch_size should be a positive integer value, "
+                             "but got batch_size={}".format(batch_size))
+        if not isinstance(drop_last, bool):
+            raise ValueError("drop_last should be a boolean value, but got "
+                             "drop_last={}".format(drop_last))
+        self.sampler = sampler
+        self.batch_size = batch_size
+        self.drop_last = drop_last
+
+    def __iter__(self):
+        batch = []
+        for idx in self.sampler:
+            batch.append(idx)
+            if len(batch) == self.batch_size:
+                yield batch
+                batch = []
+        if len(batch) > 0 and not self.drop_last:
+            yield batch
+
+    def __len__(self):
+        if self.drop_last:
+            return len(self.sampler) // self.batch_size
+        else:
+            return (len(self.sampler) + self.batch_size - 1) // self.batch_size
\ No newline at end of file
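An illustrative end-to-end composition of the pieces in this patch (not part of it): a custom batch_sampler feeding the DataLoader, so each mini-batch groups utterances of similar length and needs less padding. This sketch assumes the LJSpeech data is available at LJSPEECH_ROOT and uses normalized-text length as a rough proxy for audio length.

# Length-bucketed batching via a custom batch_sampler.
from dataloader import DataLoader
from ljspeech import LJSpeech, fn
from sampler import BatchSampler, PartialyRandomizedSimilarTimeLengthSampler

ljspeech = LJSpeech()
lengths = [len(text) for text in ljspeech.metadata["normalized_text"]]
length_sampler = PartialyRandomizedSimilarTimeLengthSampler(lengths, batch_size=16)
loader = DataLoader(ljspeech, collate_fn=fn,
                    batch_sampler=BatchSampler(length_sampler, 16, drop_last=True))
for mag_batch, mel_batch, phoneme_batch in loader:
    print(mag_batch.shape, mel_batch.shape, phoneme_batch.shape)
    break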