From 13ab0bd608fe427feeb479bf4058d1b6ce7fd7c4 Mon Sep 17 00:00:00 2001 From: chenfeiyu Date: Thu, 10 Jun 2021 22:57:21 +0800 Subject: [PATCH] remove task-specific Datasets and add a general purpose DataTable --- parakeet/datasets/audio_dataset.py | 133 -------------------- parakeet/datasets/audio_mel_dataset.py | 161 ------------------------- parakeet/datasets/data_tabel.py | 150 +++++++++++++++++++++++ parakeet/datasets/mel_dataset.py | 132 -------------------- tests/test_data_table.py | 22 ++++ 5 files changed, 172 insertions(+), 426 deletions(-) delete mode 100644 parakeet/datasets/audio_dataset.py delete mode 100644 parakeet/datasets/audio_mel_dataset.py create mode 100644 parakeet/datasets/data_tabel.py delete mode 100644 parakeet/datasets/mel_dataset.py create mode 100644 tests/test_data_table.py diff --git a/parakeet/datasets/audio_dataset.py b/parakeet/datasets/audio_dataset.py deleted file mode 100644 index 9f6c51a..0000000 --- a/parakeet/datasets/audio_dataset.py +++ /dev/null @@ -1,133 +0,0 @@ -# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import Union, Optional, Callable, Tuple -from pathlib import Path -from multiprocessing import Manager - -import numpy as np -from paddle.io import Dataset -import logging - - -class AudioDataset(Dataset): - """Dataset to load audio. - - Parameters - ---------- - root_dir : Union[Path, str] - The root of the dataset. 
- audio_pattern : str - A pattern to recursively find all audio files, by default "*-wave.npy" - audio_length_threshold : int, optional - The minmimal length(number of samples) of the audio, by default None - audio_load_fn : Callable, optional - Function to load the audio, which takes a Path object or str as input, - by default np.load - return_utt_id : bool, optional - Whether to include utterance indentifier in the return value of - __getitem__, by default False - use_cache : bool, optional - Whether to cache seen examples while reading, by default False - """ - - def __init__( - self, - root_dir: Union[Path, str], - audio_pattern: str="*-wave.npy", - audio_length_threshold: Optional[int]=None, - audio_load_fn: Callable=np.load, - return_utt_id: bool=False, - use_cache: bool=False, ): - # allow str and Path that contains '~' - root_dir = Path(root_dir).expanduser() - - # recursively find all of audio files that match thr pattern - audio_files = sorted(list(root_dir.rglob(audio_pattern))) - - # filter by threshold - if audio_length_threshold is not None: - audio_lengths = [audio_load_fn(f).shape[0] for f in audio_files] - idxs = [ - idx for idx in range(len(audio_files)) - if audio_lengths[idx] > audio_length_threshold - ] - if len(audio_files) != len(idxs): - logging.warning( - f"some files are filtered by audio length threshold " - f"({len(audio_files)} -> {len(idxs)}).") - audio_files = [audio_files[idx] for idx in idxs] - - # assert the number of files - assert len( - audio_files) != 0, f"Not any audio files found in {root_dir}." 
- - self.audio_files = audio_files - self.audio_load_fn = audio_load_fn - self.return_utt_id = return_utt_id - # TODO(chenfeiyu): better strategy to get utterance id - if ".npy" in audio_pattern: - self.utt_ids = [ - f.name.replace("-wave.npy", "") for f in audio_files - ] - else: - self.utt_ids = [f.stem for f in audio_files] - self.use_cache = use_cache - if use_cache: - # use manager to share object between multiple processes - # avoid per-reader process caching - self.manager = Manager() - self.caches = self.manager.list() - self.caches += [None for _ in range(len(audio_files))] - - def __getitem__(self, idx: int) -> Tuple[str, np.ndarray]: - """Get an example given the index. - - Parameters - ---------- - idx : int - The index. - - Returns - ------- - utt_id : str - Utterance identifier. - audio : np.ndarray - Shape (n_samples, ), the audio. - """ - if self.use_cache and self.caches[idx] is not None: - return self.caches[idx] - - utt_id = self.utt_ids[idx] - audio = self.audio_load_fn(self.audio_files[idx]) - - if self.return_utt_id: - items = utt_id, audio - else: - items = audio - - if self.use_cache: - self.caches[idx] = items - - return items - - def __len__(self) -> int: - """Returns the size of the dataset. - - Returns - ------- - int - The length of the dataset - """ - return len(self.audio_files) diff --git a/parakeet/datasets/audio_mel_dataset.py b/parakeet/datasets/audio_mel_dataset.py deleted file mode 100644 index c3fd00f..0000000 --- a/parakeet/datasets/audio_mel_dataset.py +++ /dev/null @@ -1,161 +0,0 @@ -# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import Union, Optional, Callable, Tuple, Any -from pathlib import Path -from multiprocessing import Manager - -import numpy as np -from paddle.io import Dataset -import logging - - -class AudioMelDataset(Dataset): - """Dataset to laod audio and mel dataset. - - Parameters - ---------- - root_dir : Union[Path, str] - The root of the dataset. - audio_pattern : str, optional - A pattern to recursively find all audio files, by default - "*-wave.npy" - mel_pattern : str, optional - A pattern to recursively find all mel feature files, by default - "*-mel.npy" - audio_load_fn : Callable, optional - Function to load the audio, which takes a Path object or str as - input, by default np.load - mel_load_fn : Callable, optional - Function to load the mel features, which takes a Path object or - str as input, by default np.load - audio_length_threshold : Optional[int], optional - The minmimal length(number of samples) of the audio, by default None - mel_length_threshold : Optional[int], optional - The minmimal length(number of frames) of the audio, by default None - return_utt_id : bool, optional - Whether to include utterance indentifier in the return value of - __getitem__, by default False - use_cache : bool, optional - Whether to cache seen examples while reading, by default False - """ - - def __init__(self, - root_dir: Union[Path, str], - audio_pattern: str="*-wave.npy", - mel_pattern: str="*-mel.npy", - audio_load_fn: Callable=np.load, - mel_load_fn: Callable=np.load, - audio_length_threshold: Optional[int]=None, - mel_length_threshold: 
Optional[int]=None, - return_utt_id: bool=False, - use_cache: bool=False): - root_dir = Path(root_dir).expanduser() - # find all of audio and mel files - audio_files = sorted(list(root_dir.rglob(audio_pattern))) - mel_files = sorted(list(root_dir.rglob(mel_pattern))) - - # filter by threshold - if audio_length_threshold is not None: - audio_lengths = [audio_load_fn(f).shape[0] for f in audio_files] - idxs = [ - idx for idx in range(len(audio_files)) - if audio_lengths[idx] > audio_length_threshold - ] - if len(audio_files) != len(idxs): - logging.warning( - f"Some files are filtered by audio length threshold " - f"({len(audio_files)} -> {len(idxs)}).") - audio_files = [audio_files[idx] for idx in idxs] - mel_files = [mel_files[idx] for idx in idxs] - if mel_length_threshold is not None: - mel_lengths = [mel_load_fn(f).shape[1] for f in mel_files] - idxs = [ - idx for idx in range(len(mel_files)) - if mel_lengths[idx] > mel_length_threshold - ] - if len(mel_files) != len(idxs): - logging.warning( - f"Some files are filtered by mel length threshold " - f"({len(mel_files)} -> {len(idxs)}).") - audio_files = [audio_files[idx] for idx in idxs] - mel_files = [mel_files[idx] for idx in idxs] - - # assert the number of files - assert len( - audio_files) != 0, f"Not found any audio files in {root_dir}." 
- assert len(audio_files) == len(mel_files), \ - (f"Number of audio and mel files are different " - f"({len(audio_files)} vs {len(mel_files)}).") - - self.audio_files = audio_files - self.audio_load_fn = audio_load_fn - self.mel_load_fn = mel_load_fn - self.mel_files = mel_files - if ".npy" in audio_pattern: - self.utt_ids = [ - f.name.replace("-wave.npy", "") for f in audio_files - ] - else: - self.utt_ids = [f.stem for f in audio_files] - self.return_utt_id = return_utt_id - self.use_cache = use_cache - if use_cache: - self.manager = Manager() - self.caches = self.manager.list() - self.caches += [None for _ in range(len(audio_files))] - - def __getitem__(self, idx: int) -> Tuple: - """Get an example given the index. - - Parameters - ---------- - idx : int - The index of the example. - - Returns - ------- - utt_id : str - Utterance identifier. - audio : np.ndarray - Shape (n_samples, ), the audio. - mel: np.ndarray - Shape (n_mels, n_frames), the mel spectrogram. - """ - if self.use_cache and self.caches[idx] is not None: - return self.caches[idx] - - utt_id = self.utt_ids[idx] - audio = self.audio_load_fn(self.audio_files[idx]) - mel = self.mel_load_fn(self.mel_files[idx]) - - if self.return_utt_id: - items = utt_id, audio, mel - else: - items = audio, mel - - if self.use_cache: - self.caches[idx] = items - - return items - - def __len__(self): - """Returns the size of the dataset. - - Returns - ------- - int - The length of the dataset - """ - return len(self.audio_files) diff --git a/parakeet/datasets/data_tabel.py b/parakeet/datasets/data_tabel.py new file mode 100644 index 0000000..75d075d --- /dev/null +++ b/parakeet/datasets/data_tabel.py @@ -0,0 +1,150 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Union, Optional, Callable, Tuple, List, Dict, Any
+from pathlib import Path
+from multiprocessing import Manager
+
+import numpy as np
+from paddle.io import Dataset
+
+
+class DataTable(Dataset):
+    def __init__(self,
+                 data: List[Dict[str, Any]],
+                 fields: List[str]=None,
+                 converters: Dict[str, Callable]=None,
+                 use_cache: bool=False):
+        """Dataset to load and convert data for general purpose.
+
+        Parameters
+        ----------
+        data : List[Dict[str, Any]]
+            Metadata, a list of meta datum, each of which is composed of
+            several fields
+        fields : List[str], optional
+            Fields to use, if not specified, all the fields in the data are
+            used, by default None
+        converters : Dict[str, Callable], optional
+            Converters used to process each field, by default None
+        use_cache : bool, optional
+            Whether to use cache, by default False
+
+        Raises
+        ------
+        ValueError
+            If there is some field that does not exist in data.
+        ValueError
+            If there is some field in converters that does not exist in fields.
+        """
+        # metadata
+        self.data = data
+        assert len(data) > 0, "This dataset has no examples"
+
+        # peek at an example to get existing fields.
+        first_example = self.data[0]
+        fields_in_data = first_example.keys()
+
+        # check all the requested fields exist
+        if fields is None:
+            self.fields = fields_in_data
+        else:
+            for field in fields:
+                if field not in fields_in_data:
+                    raise ValueError(
+                        f"The requested field ({field}) is not found "
+                        f"in the data. 
Fields in the data are {fields_in_data}")
+            self.fields = fields
+
+        # check converters
+        if converters is None:
+            self.converters = {}
+        else:
+            for field in converters.keys():
+                if field not in self.fields:
+                    raise ValueError(
+                        f"The converter has a non-existent field ({field})")
+            self.converters = converters
+
+        self.use_cache = use_cache
+        if use_cache:
+            self._initialize_cache()
+
+    def _initialize_cache(self):
+        self.manager = Manager()
+        self.caches = self.manager.list()
+        self.caches += [None for _ in range(len(self))]
+
+    def _get_metadata(self, idx: int) -> Dict[str, Any]:
+        """Return a meta-datum given an index."""
+        return self.data[idx]
+
+    def _convert(self, meta_datum: Dict[str, Any]) -> Dict[str, Any]:
+        """Convert a meta datum to an example by applying the corresponding
+        converters to each field requested.
+
+        Parameters
+        ----------
+        meta_datum : Dict[str, Any]
+            Meta datum
+
+        Returns
+        -------
+        Dict[str, Any]
+            Converted example
+        """
+        example = {}
+        for field in self.fields:
+            converter = self.converters.get(field, None)
+            meta_datum_field = meta_datum[field]
+            if converter is not None:
+                converted_field = converter(meta_datum_field)
+            else:
+                converted_field = meta_datum_field
+            example[field] = converted_field
+        return example
+
+    def __getitem__(self, idx: int) -> Dict[str, Any]:
+        """Get an example given an index.
+
+        Parameters
+        ----------
+        idx : int
+            Index of the example to get
+
+        Returns
+        -------
+        Dict[str, Any]
+            A converted example
+        """
+        if self.use_cache and self.caches[idx] is not None:
+            return self.caches[idx]
+
+        meta_datum = self._get_metadata(idx)
+        example = self._convert(meta_datum)
+
+        if self.use_cache:
+            self.caches[idx] = example
+
+        return example
+
+    def __len__(self) -> int:
+        """Returns the size of the dataset. 
+ + Returns + ------- + int + The length of the dataset + """ + return len(self.data) diff --git a/parakeet/datasets/mel_dataset.py b/parakeet/datasets/mel_dataset.py deleted file mode 100644 index 038654c..0000000 --- a/parakeet/datasets/mel_dataset.py +++ /dev/null @@ -1,132 +0,0 @@ -# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import Union, Optional, Callable, Tuple -from pathlib import Path -from multiprocessing import Manager - -import numpy as np -from paddle.io import Dataset -import logging - - -class MelDataset(Dataset): - """Dataset to load mel-spectrograms. - - Parameters - ---------- - root_dir : Union[Path, str] - The root of the dataset. 
- mel_pattern : str, optional - A pattern to recursively find all mel feature files, by default - "*-feats.npy" - mel_length_threshold : Optional[int], optional - The minmimal length(number of frames) of the audio, by default None - mel_load_fn : Callable, optional - Function to load the audio, which takes a Path object or str as input, - by default np.load - return_utt_id : bool, optional - Whether to include utterance indentifier in the return value of - __getitem__, by default False - use_cahce : bool, optional - Whether to cache seen examples while reading, by default False - """ - - def __init__( - self, - root_dir: Union[Path, str], - mel_pattern: str="*-feats.npy", - mel_length_threshold: Optional[int]=None, - mel_load_fn: Callable=np.load, - return_utt_id: bool=False, - use_cahce: bool=False, ): - # allow str and Path that contains '~' - root_dir = Path(root_dir).expanduser() - - # find all of the mel files - mel_files = sorted(list(root_dir.rglob(mel_pattern))) - - # filter by threshold - if mel_length_threshold is not None: - mel_lengths = [mel_load_fn(f).shape[1] for f in mel_files] - idxs = [ - idx for idx in range(len(mel_files)) - if mel_lengths[idx] > mel_length_threshold - ] - if len(mel_files) != len(idxs): - logging.warning( - f"Some files are filtered by mel length threshold " - f"({len(mel_files)} -> {len(idxs)}).") - mel_files = [mel_files[idx] for idx in idxs] - - # assert the number of files - assert len(mel_files) != 0, f"Not found any mel files in {root_dir}." 
- - self.mel_files = mel_files - self.mel_load_fn = mel_load_fn - - # TODO(chenfeiyu): better strategy to get utterance id - if ".npy" in mel_pattern: - self.utt_ids = [ - f.name.replace("-feats.npy", "") for f in mel_files - ] - else: - self.utt_ids = [f.stem for f in mel_files] - self.return_utt_id = return_utt_id - self.use_cache = use_cahce - if use_cahce: - self.manager = Manager() - self.caches = self.manager.list() - self.caches += [None for _ in range(len(mel_files))] - - def __getitem__(self, idx): - """Get an example given the index. - - Parameters - ---------- - idx : int - The index - - Returns - ------- - utt_id : str - Utterance identifier. - audio : np.ndarray - Shape (n_mels, n_frames), the mel spectrogram. - """ - if self.use_cache and self.caches[idx] is not None: - return self.caches[idx] - - utt_id = self.utt_ids[idx] - mel = self.mel_load_fn(self.mel_files[idx]) - - if self.return_utt_id: - items = utt_id, mel - else: - items = mel - - if self.use_cache: - self.caches[idx] = items - - return items - - def __len__(self): - """Returns the size of the dataset. - - Returns - ------- - int - The length of the dataset - """ - return len(self.mel_files) diff --git a/tests/test_data_table.py b/tests/test_data_table.py new file mode 100644 index 0000000..aca0605 --- /dev/null +++ b/tests/test_data_table.py @@ -0,0 +1,22 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+from parakeet.datasets.data_tabel import DataTable
+
+
+def test_data_table():
+    metadata = [{'name': 'Sonic', 'v': 1000}, {'name': 'Prestol', 'v': 2000}]
+    converters = {'v': lambda x: x / 1000}
+    dataset = DataTable(metadata, fields=['v'], converters=converters)
+    assert dataset[0] == {'v': 1.0}