remove task-specific Datasets and add a general-purpose DataTable

This commit is contained in:
chenfeiyu 2021-06-10 22:57:21 +08:00
parent 3bf2e71734
commit 13ab0bd608
5 changed files with 172 additions and 426 deletions

View File

@ -1,133 +0,0 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Union, Optional, Callable, Tuple
from pathlib import Path
from multiprocessing import Manager
import numpy as np
from paddle.io import Dataset
import logging
class AudioDataset(Dataset):
"""Dataset to load audio.
Parameters
----------
root_dir : Union[Path, str]
The root of the dataset.
audio_pattern : str
A pattern to recursively find all audio files, by default "*-wave.npy"
audio_length_threshold : int, optional
The minimal length (number of samples) of the audio, by default None
audio_load_fn : Callable, optional
Function to load the audio, which takes a Path object or str as input,
by default np.load
return_utt_id : bool, optional
Whether to include the utterance identifier in the return value of
__getitem__, by default False
use_cache : bool, optional
Whether to cache seen examples while reading, by default False
"""
def __init__(
self,
root_dir: Union[Path, str],
audio_pattern: str="*-wave.npy",
audio_length_threshold: Optional[int]=None,
audio_load_fn: Callable=np.load,
return_utt_id: bool=False,
use_cache: bool=False, ):
# allow str and Path that contains '~'
root_dir = Path(root_dir).expanduser()
# recursively find all audio files that match the pattern
audio_files = sorted(list(root_dir.rglob(audio_pattern)))
# filter by threshold
if audio_length_threshold is not None:
audio_lengths = [audio_load_fn(f).shape[0] for f in audio_files]
idxs = [
idx for idx in range(len(audio_files))
if audio_lengths[idx] > audio_length_threshold
]
if len(audio_files) != len(idxs):
logging.warning(
f"some files are filtered by audio length threshold "
f"({len(audio_files)} -> {len(idxs)}).")
audio_files = [audio_files[idx] for idx in idxs]
# assert the number of files
assert len(
audio_files) != 0, f"No audio files found in {root_dir}."
self.audio_files = audio_files
self.audio_load_fn = audio_load_fn
self.return_utt_id = return_utt_id
# TODO(chenfeiyu): better strategy to get utterance id
if ".npy" in audio_pattern:
self.utt_ids = [
f.name.replace("-wave.npy", "") for f in audio_files
]
else:
self.utt_ids = [f.stem for f in audio_files]
self.use_cache = use_cache
if use_cache:
# use manager to share object between multiple processes
# avoid per-reader process caching
self.manager = Manager()
self.caches = self.manager.list()
self.caches += [None for _ in range(len(audio_files))]
def __getitem__(self, idx: int) -> Tuple[str, np.ndarray]:
"""Get an example given the index.
Parameters
----------
idx : int
The index.
Returns
-------
utt_id : str
Utterance identifier.
audio : np.ndarray
Shape (n_samples, ), the audio.
"""
if self.use_cache and self.caches[idx] is not None:
return self.caches[idx]
utt_id = self.utt_ids[idx]
audio = self.audio_load_fn(self.audio_files[idx])
if self.return_utt_id:
items = utt_id, audio
else:
items = audio
if self.use_cache:
self.caches[idx] = items
return items
def __len__(self) -> int:
"""Returns the size of the dataset.
Returns
-------
int
The length of the dataset
"""
return len(self.audio_files)
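
For context, a minimal sketch of how the removed AudioDataset was used; the dataset root and threshold below are hypothetical examples, not values from this repository:

# hypothetical usage of the removed AudioDataset
dataset = AudioDataset(
    root_dir="~/datasets/my_corpus",  # hypothetical path
    audio_pattern="*-wave.npy",
    audio_length_threshold=16000,  # keep clips longer than 16000 samples
    return_utt_id=True)
utt_id, audio = dataset[0]  # a (utt_id, audio) pair when return_utt_id=True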

View File

@ -1,161 +0,0 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Union, Optional, Callable, Tuple, Any
from pathlib import Path
from multiprocessing import Manager
import numpy as np
from paddle.io import Dataset
import logging
class AudioMelDataset(Dataset):
"""Dataset to laod audio and mel dataset.
Parameters
----------
root_dir : Union[Path, str]
The root of the dataset.
audio_pattern : str, optional
A pattern to recursively find all audio files, by default
"*-wave.npy"
mel_pattern : str, optional
A pattern to recursively find all mel feature files, by default
"*-mel.npy"
audio_load_fn : Callable, optional
Function to load the audio, which takes a Path object or str as
input, by default np.load
mel_load_fn : Callable, optional
Function to load the mel features, which takes a Path object or
str as input, by default np.load
audio_length_threshold : Optional[int], optional
The minimal length (number of samples) of the audio, by default None
mel_length_threshold : Optional[int], optional
The minimal length (number of frames) of the mel spectrogram, by default None
return_utt_id : bool, optional
Whether to include the utterance identifier in the return value of
__getitem__, by default False
use_cache : bool, optional
Whether to cache seen examples while reading, by default False
"""
def __init__(self,
root_dir: Union[Path, str],
audio_pattern: str="*-wave.npy",
mel_pattern: str="*-mel.npy",
audio_load_fn: Callable=np.load,
mel_load_fn: Callable=np.load,
audio_length_threshold: Optional[int]=None,
mel_length_threshold: Optional[int]=None,
return_utt_id: bool=False,
use_cache: bool=False):
root_dir = Path(root_dir).expanduser()
# find all of audio and mel files
audio_files = sorted(list(root_dir.rglob(audio_pattern)))
mel_files = sorted(list(root_dir.rglob(mel_pattern)))
# filter by threshold
if audio_length_threshold is not None:
audio_lengths = [audio_load_fn(f).shape[0] for f in audio_files]
idxs = [
idx for idx in range(len(audio_files))
if audio_lengths[idx] > audio_length_threshold
]
if len(audio_files) != len(idxs):
logging.warning(
f"Some files are filtered by audio length threshold "
f"({len(audio_files)} -> {len(idxs)}).")
audio_files = [audio_files[idx] for idx in idxs]
mel_files = [mel_files[idx] for idx in idxs]
if mel_length_threshold is not None:
mel_lengths = [mel_load_fn(f).shape[1] for f in mel_files]
idxs = [
idx for idx in range(len(mel_files))
if mel_lengths[idx] > mel_length_threshold
]
if len(mel_files) != len(idxs):
logging.warning(
f"Some files are filtered by mel length threshold "
f"({len(mel_files)} -> {len(idxs)}).")
audio_files = [audio_files[idx] for idx in idxs]
mel_files = [mel_files[idx] for idx in idxs]
# assert the number of files
assert len(
audio_files) != 0, f"No audio files found in {root_dir}."
assert len(audio_files) == len(mel_files), \
(f"The numbers of audio and mel files are different "
f"({len(audio_files)} vs {len(mel_files)}).")
self.audio_files = audio_files
self.audio_load_fn = audio_load_fn
self.mel_load_fn = mel_load_fn
self.mel_files = mel_files
if ".npy" in audio_pattern:
self.utt_ids = [
f.name.replace("-wave.npy", "") for f in audio_files
]
else:
self.utt_ids = [f.stem for f in audio_files]
self.return_utt_id = return_utt_id
self.use_cache = use_cache
if use_cache:
self.manager = Manager()
self.caches = self.manager.list()
self.caches += [None for _ in range(len(audio_files))]
def __getitem__(self, idx: int) -> Tuple:
"""Get an example given the index.
Parameters
----------
idx : int
The index of the example.
Returns
-------
utt_id : str
Utterance identifier.
audio : np.ndarray
Shape (n_samples, ), the audio.
mel: np.ndarray
Shape (n_mels, n_frames), the mel spectrogram.
"""
if self.use_cache and self.caches[idx] is not None:
return self.caches[idx]
utt_id = self.utt_ids[idx]
audio = self.audio_load_fn(self.audio_files[idx])
mel = self.mel_load_fn(self.mel_files[idx])
if self.return_utt_id:
items = utt_id, audio, mel
else:
items = audio, mel
if self.use_cache:
self.caches[idx] = items
return items
def __len__(self):
"""Returns the size of the dataset.
Returns
-------
int
The length of the dataset
"""
return len(self.audio_files)

View File

@ -0,0 +1,150 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Union, Optional, Callable, Tuple, List, Dict, Any
from pathlib import Path
from multiprocessing import Manager
import numpy as np
from paddle.io import Dataset
class DataTable(Dataset):
def __init__(self,
data: List[Dict[str, Any]],
fields: List[str]=None,
converters: Dict[str, Callable]=None,
use_cache: bool=False):
"""Dataset to load and convert data for general purpose.
Parameters
----------
data : List[Dict[str, Any]]
Metadata, a list of meta datums, each of which is composed of
several fields
fields : List[str], optional
Fields to use, if not specified, all the fields in the data are
used, by default None
converters : Dict[str, Callable], optional
Converters used to process each field, by default None
use_cache : bool, optional
Whether to use cache, by default False
Raises
------
ValueError
If a requested field does not exist in the data.
ValueError
If a field in converters does not exist in the requested fields.
"""
# metadata
self.data = data
assert len(data) > 0, "This dataset has no examples"
# peek at the first example to get the existing fields.
first_example = self.data[0]
fields_in_data = first_example.keys()
# check all the requested fields exist
if fields is None:
self.fields = list(fields_in_data)
else:
for field in fields:
if field not in fields_in_data:
raise ValueError(
f"The requested field ({field}) is not found"
f"in the data. Fields in the data is {fields_in_data}")
self.fields = fields
# check converters
if converters is None:
self.converters = {}
else:
for field in converters.keys():
if field not in self.fields:
raise ValueError(
f"The converter has a non existing field ({field})")
self.converters = converters
self.use_cache = use_cache
if use_cache:
self._initialize_cache()
def _initialize_cache(self):
self.manager = Manager()
self.caches = self.manager.list()
self.caches += [None for _ in range(len(self))]
def _get_metadata(self, idx: int) -> Dict[str, Any]:
"""Return a meta-datum given an index."""
return self.data[idx]
def _convert(self, meta_datum: Dict[str, Any]) -> Dict[str, Any]:
"""Convert a meta datum to an example by applying the corresponding
converters to each requested field.
Parameters
----------
meta_datum : Dict[str, Any]
Meta datum
Returns
-------
Dict[str, Any]
Converted example
"""
example = {}
for field in self.fields:
converter = self.converters.get(field, None)
meta_datum_field = meta_datum[field]
if converter is not None:
converted_field = converter(meta_datum_field)
else:
converted_field = meta_datum_field
example[field] = converted_field
return example
def __getitem__(self, idx: int) -> Dict[str, Any]:
"""Get an example given an index.
Parameters
----------
idx : int
Index of the example to get
Returns
-------
Dict[str, Any]
A converted example
"""
if self.use_cache and self.caches[idx] is not None:
return self.caches[idx]
meta_datum = self._get_metadata(idx)
example = self._convert(meta_datum)
if self.use_cache:
self.caches[idx] = example
return example
def __len__(self) -> int:
"""Returns the size of the dataset.
Returns
-------
int
The length of the dataset
"""
return len(self.data)
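
The new DataTable can reproduce the removed task-specific datasets. A minimal sketch that mirrors the old AudioMelDataset (the path and patterns are hypothetical, and the length filtering is omitted):

from pathlib import Path
import numpy as np

# collect (audio, mel) file pairs into a list of meta datums
root_dir = Path("~/datasets/my_corpus").expanduser()  # hypothetical path
wave_files = sorted(root_dir.rglob("*-wave.npy"))
mel_files = sorted(root_dir.rglob("*-mel.npy"))
metadata = [{"utt_id": w.name.replace("-wave.npy", ""), "wave": w, "mel": m}
            for w, m in zip(wave_files, mel_files)]
# np.load converts each stored Path into an array on access
dataset = DataTable(metadata,
                    fields=["wave", "mel"],
                    converters={"wave": np.load, "mel": np.load})
example = dataset[0]  # {"wave": np.ndarray, "mel": np.ndarray}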

View File

@ -1,132 +0,0 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Union, Optional, Callable, Tuple
from pathlib import Path
from multiprocessing import Manager
import numpy as np
from paddle.io import Dataset
import logging
class MelDataset(Dataset):
"""Dataset to load mel-spectrograms.
Parameters
----------
root_dir : Union[Path, str]
The root of the dataset.
mel_pattern : str, optional
A pattern to recursively find all mel feature files, by default
"*-feats.npy"
mel_length_threshold : Optional[int], optional
The minimal length (number of frames) of the mel spectrogram, by default None
mel_load_fn : Callable, optional
Function to load the mel features, which takes a Path object or str as input,
by default np.load
return_utt_id : bool, optional
Whether to include the utterance identifier in the return value of
__getitem__, by default False
use_cache : bool, optional
Whether to cache seen examples while reading, by default False
"""
def __init__(
self,
root_dir: Union[Path, str],
mel_pattern: str="*-feats.npy",
mel_length_threshold: Optional[int]=None,
mel_load_fn: Callable=np.load,
return_utt_id: bool=False,
use_cache: bool=False, ):
# allow str and Path that contains '~'
root_dir = Path(root_dir).expanduser()
# find all of the mel files
mel_files = sorted(list(root_dir.rglob(mel_pattern)))
# filter by threshold
if mel_length_threshold is not None:
mel_lengths = [mel_load_fn(f).shape[1] for f in mel_files]
idxs = [
idx for idx in range(len(mel_files))
if mel_lengths[idx] > mel_length_threshold
]
if len(mel_files) != len(idxs):
logging.warning(
f"Some files are filtered by mel length threshold "
f"({len(mel_files)} -> {len(idxs)}).")
mel_files = [mel_files[idx] for idx in idxs]
# assert the number of files
assert len(mel_files) != 0, f"No mel files found in {root_dir}."
self.mel_files = mel_files
self.mel_load_fn = mel_load_fn
# TODO(chenfeiyu): better strategy to get utterance id
if ".npy" in mel_pattern:
self.utt_ids = [
f.name.replace("-feats.npy", "") for f in mel_files
]
else:
self.utt_ids = [f.stem for f in mel_files]
self.return_utt_id = return_utt_id
self.use_cache = use_cache
if use_cache:
self.manager = Manager()
self.caches = self.manager.list()
self.caches += [None for _ in range(len(mel_files))]
def __getitem__(self, idx):
"""Get an example given the index.
Parameters
----------
idx : int
The index
Returns
-------
utt_id : str
Utterance identifier.
mel : np.ndarray
Shape (n_mels, n_frames), the mel spectrogram.
"""
if self.use_cache and self.caches[idx] is not None:
return self.caches[idx]
utt_id = self.utt_ids[idx]
mel = self.mel_load_fn(self.mel_files[idx])
if self.return_utt_id:
items = utt_id, mel
else:
items = mel
if self.use_cache:
self.caches[idx] = items
return items
def __len__(self):
"""Returns the size of the dataset.
Returns
-------
int
The length of the dataset
"""
return len(self.mel_files)

22
tests/test_data_table.py Normal file
View File

@ -0,0 +1,22 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from parakeet.datasets.data_table import DataTable
def test_data_table():
metadata = [{'name': 'Sonic', 'v': 1000}, {'name': 'Prestol', 'v': 2000}]
converters = {'v': lambda x: x / 1000}
dataset = DataTable(metadata, fields=['v'], converters=converters)
assert dataset[0] == {'v': 1.0}
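
A possible follow-up test (not part of this commit, names are hypothetical) covering the default fields and the Manager-backed cache:

def test_data_table_default_fields_and_cache():
    metadata = [{'name': 'Sonic', 'v': 1000}]
    dataset = DataTable(metadata, use_cache=True)
    # with fields=None, every field is returned unconverted
    assert dataset[0] == {'name': 'Sonic', 'v': 1000}
    # the first access stored the example in the shared cache
    assert dataset.caches[0] == {'name': 'Sonic', 'v': 1000}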