diff --git a/.gitignore b/.gitignore index 25f2656..7906666 100644 --- a/.gitignore +++ b/.gitignore @@ -4,9 +4,6 @@ *.udb *.ann -# data -datasets/ - # Byte-compiled / optimized / DLL files __pycache__/ *.py[cod] diff --git a/parakeet/datasets/__init__.py b/parakeet/datasets/__init__.py new file mode 100644 index 0000000..de7be70 --- /dev/null +++ b/parakeet/datasets/__init__.py @@ -0,0 +1,2 @@ +from parakeet.datasets.common import * +from parakeet.datasets.ljspeech import * \ No newline at end of file diff --git a/parakeet/datasets/common.py b/parakeet/datasets/common.py new file mode 100644 index 0000000..f923086 --- /dev/null +++ b/parakeet/datasets/common.py @@ -0,0 +1,21 @@ +from paddle.io import Dataset +import os +import librosa + +class AudioFolderDataset(Dataset): + def __init__(self, path, sample_rate, extension="wav"): + self.root = os.path.expanduser(path) + self.sample_rate = sample_rate + self.extension = extension + self.file_names = [ + os.path.join(self.root, x) for x in os.listdir(self.root) \ + if os.path.splitext(x)[-1] == self.extension] + self.length = len(self.file_names) + + def __len__(self): + return self.length + + def __getitem__(self, i): + file_name = self.file_names[i] + y, _ = librosa.load(file_name, sr=self.sample_rate) # pylint: disable=unused-variable + return y diff --git a/parakeet/datasets/ljspeech.py b/parakeet/datasets/ljspeech.py new file mode 100644 index 0000000..7011063 --- /dev/null +++ b/parakeet/datasets/ljspeech.py @@ -0,0 +1,23 @@ +from paddle.io import Dataset +from pathlib import Path + +class LJSpeechMetaData(Dataset): + def __init__(self, root): + self.root = Path(root).expanduser() + wav_dir = self.root / "wavs" + csv_path = self.root / "metadata.csv" + records = [] + speaker_name = "ljspeech" + with open(str(csv_path), 'rt') as f: + for line in f: + filename, _, normalized_text = line.strip().split("|") + filename = str(wav_dir / (filename + ".wav")) + records.append([filename, normalized_text, speaker_name]) + self.records = records + + def __getitem__(self, i): + return self.records[i] + + def __len__(self): + return len(self.records) +