Parakeet/examples/deepvoice3/data.py

111 lines
3.8 KiB
Python
Raw Normal View History

2020-07-10 20:22:43 +08:00
import numpy as np
2020-02-13 10:24:34 +08:00
import os
import csv
import pandas as pd
2020-07-10 20:22:43 +08:00
import paddle
from paddle import fluid
from paddle.fluid import dygraph as dg
from paddle.fluid.dataloader import Dataset, BatchSampler
from paddle.fluid.io import DataLoader
2020-02-13 10:24:34 +08:00
2020-07-10 20:22:43 +08:00
from parakeet.data import DatasetMixin, DataCargo, PartialyRandomizedSimilarTimeLengthSampler
from parakeet.g2p import en
2020-02-13 10:24:34 +08:00
2020-07-10 20:22:43 +08:00
class LJSpeech(DatasetMixin):
    """Preprocessed LJSpeech dataset indexed by ``<root>/metadata.csv``.

    The metadata file is pipe-separated, has no header row, and each row
    holds ``num_frames|spec_name|mel_name|text``. The spectrogram and mel
    file names are resolved relative to ``root`` and loaded with
    ``np.load`` on every access (nothing is cached).
    """

    def __init__(self, root):
        self._root = root
        metadata_path = os.path.join(root, "metadata.csv")
        column_names = ["num_frames", "spec_name", "mel_name", "text"]
        column_types = {
            "num_frames": np.int64,
            "spec_name": str,
            "mel_name": str,
            "text": str,
        }
        # QUOTE_NONE: transcripts may contain quote characters that must be
        # kept verbatim rather than interpreted as CSV quoting.
        self._table = pd.read_csv(
            metadata_path,
            sep="|",
            encoding="utf-8",
            quoting=csv.QUOTE_NONE,
            header=None,
            names=column_names,
            dtype=column_types,
        )

    def num_frames(self):
        """Return the per-utterance frame counts, in metadata order, as a list."""
        frame_column = self._table["num_frames"]
        return frame_column.tolist()

    def get_example(self, i):
        """Load the ``i``-th utterance.

        Returns:
            tuple ``(text, spec, mel, num_frames)``, where ``spec`` and ``mel``
            are the arrays stored in the referenced files, expected to be
            ``(T_frame, C_spec)`` and ``(T_frame, C_mel)`` respectively.
        """
        frames, spec_file, mel_file, transcript = self._table.iloc[i]
        spec = np.load(os.path.join(self._root, spec_file))
        mel = np.load(os.path.join(self._root, mel_file))
        return (transcript, spec, mel, frames)

    def __len__(self):
        """Number of rows (utterances) in the metadata table."""
        return self._table.shape[0]
class DataCollector(object):
    """Collate ``LJSpeech`` examples into zero-padded numpy batches.

    Each example is the ``(text, spec, mel, num_frames)`` tuple returned by
    ``LJSpeech.get_example``.
    """

    def __init__(self, p_pronunciation):
        # Forwarded unchanged to ``en.text_to_sequence``; presumably the
        # probability of substituting a dictionary pronunciation — TODO confirm.
        self.p_pronunciation = p_pronunciation

    def __call__(self, examples):
        """Build one batch from a non-empty list of examples.

        Returns (shape, dtype):
            text_seqs     (B, T_text)           int64, right-padded with 0
            text_lengths  (B,)                  int64
            specs         (B, T_frame, C_spec)  dtype of the loaded arrays
                                                (expected float32)
            mels          (B, T_frame, C_mel)   dtype of the loaded arrays
                                                (expected float32)
            num_frames    (B,)                  int64, taken from metadata

        ``T_frame`` is the largest metadata ``num_frames`` in the batch.

        Raises:
            ValueError: if a spec or mel array has more frames than
                ``T_frame``, i.e. the metadata disagrees with the array files.
        """
        num_frames = np.array([example[3] for example in examples], dtype=np.int64)
        max_frames = np.max(num_frames)

        specs = []
        mels = []
        for index, (_, spec, mel, _) in enumerate(examples):
            longest = max(spec.shape[0], mel.shape[0])
            if longest > max_frames:
                # np.pad would otherwise fail with an opaque
                # "negative pad width" error.
                raise ValueError(
                    f"example {index} has {longest} frames, but the batch is "
                    f"padded to {max_frames} (the largest metadata num_frames); "
                    f"metadata and spectrogram files disagree")
            specs.append(np.pad(spec, [(0, max_frames - spec.shape[0]), (0, 0)]))
            mels.append(np.pad(mel, [(0, max_frames - mel.shape[0]), (0, 0)]))
        specs = np.stack(specs)
        mels = np.stack(mels)

        text_seqs = [
            en.text_to_sequence(text, self.p_pronunciation)
            for text, _, _, _ in examples
        ]
        text_lengths = np.array([len(seq) for seq in text_seqs], dtype=np.int64)
        # Zero-filled buffer works whether a sequence is a list or an ndarray.
        padded_text = np.zeros((len(text_seqs), np.max(text_lengths)), dtype=np.int64)
        for row, seq in zip(padded_text, text_seqs):
            row[:len(seq)] = seq
        return padded_text, text_lengths, specs, mels, num_frames
if __name__ == "__main__":
    # Smoke test: load the preprocessed LJSpeech data and iterate over every
    # batch once (with a progress bar) to exercise loading and collation.
    import argparse

    import tqdm
    from ruamel import yaml

    parser = argparse.ArgumentParser(description="load the preprocessed ljspeech dataset")
    parser.add_argument("--config", type=str, required=True, help="config file")
    parser.add_argument("--input", type=str, required=True, help="data path of the preprocessed data")
    args = parser.parse_args()

    # NOTE(review): ruamel.yaml deprecated the module-level ``safe_load`` and
    # removed it in 0.18; ``yaml.YAML(typ="safe").load(f)`` is the replacement
    # if the installed version is upgraded.
    with open(args.config, 'rt') as f:
        config = yaml.safe_load(f)

    print("========= Command Line Arguments ========")
    for k, v in vars(args).items():
        print("{}: {}".format(k, v))
    print("=========== Configurations ==============")
    for k in ["p_pronunciation", "batch_size"]:
        print("{}: {}".format(k, config[k]))

    ljspeech = LJSpeech(args.input)
    collate_fn = DataCollector(config["p_pronunciation"])

    dg.enable_dygraph(fluid.CPUPlace())
    # Groups utterances of similar length into batches to reduce padding.
    sampler = PartialyRandomizedSimilarTimeLengthSampler(ljspeech.num_frames())
    cargo = DataCargo(ljspeech, collate_fn,
                      batch_size=config["batch_size"], sampler=sampler)
    loader = DataLoader\
        .from_generator(capacity=5, return_list=True)\
        .set_batch_generator(cargo)

    # Batches are discarded; only iteration speed and errors matter here.
    for _batch in tqdm.tqdm(loader):
        pass