add examples: transformer_tts, waveflow, wavenet
parent b82217f50f
commit 28fbc60737
@@ -0,0 +1,55 @@
from yacs.config import CfgNode as CN

_C = CN()
_C.data = CN(
    dict(
        batch_size=16,  # batch size
        valid_size=64,  # the first N examples are reserved for validation
        sample_rate=22050,  # Hz, sample rate
        n_fft=1024,  # fft frame size
        win_length=1024,  # window size
        hop_length=256,  # hop size between adjacent frames
        f_max=8000,  # Hz, max frequency when converting to mel
        d_mel=80,  # mel bands
        padding_idx=0,  # text embedding's padding index
        mel_start_value=0.5,  # value for the starting frame
        mel_end_value=-0.5,  # value for the ending frame
    )
)

_C.model = CN(
    dict(
        d_encoder=512,  # embedding & encoder's internal size
        d_decoder=256,  # decoder's internal size
        n_heads=4,  # actually it can differ at each layer
        d_ffn=1024,  # encoder_d_ffn & decoder_d_ffn
        encoder_layers=4,  # number of transformer encoder layers
        decoder_layers=4,  # number of transformer decoder layers
        d_prenet=256,  # decoder prenet's hidden size (d_mel => d_prenet => d_decoder)
        d_postnet=256,  # decoder postnet (cnn)'s internal channels
        postnet_layers=5,  # decoder postnet (cnn)'s layers
        postnet_kernel_size=5,  # decoder postnet (cnn)'s kernel size
        max_reduction_factor=10,  # max reduction factor
        dropout=0.1,  # global dropout probability
        stop_loss_scale=8.0,  # scaler for stop_loss
        decoder_prenet_dropout=0.5,  # decoder prenet dropout probability
    )
)

_C.training = CN(
    dict(
        lr=1e-4,  # learning rate
        drop_n_heads=[[0, 0], [15000, 1]],
        reduction_factor=[[0, 10], [80000, 4], [200000, 2]],
        plot_interval=1000,  # plot attention and spectrogram
        valid_interval=1000,  # validation
        save_interval=10000,  # checkpoint
        max_iteration=900000,  # max iteration to train
    )
)

def get_cfg_defaults():
    """Get a yacs CfgNode object with default values for my_project."""
    # Return a clone so that the defaults will not be altered
    # This is for the "local variable" use pattern
    return _C.clone()
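A minimal sketch of how this config module is consumed, assuming only yacs is installed; the override keys below are taken from the defaults above:

# load the defaults defined above, then override selected keys
from config import get_cfg_defaults

config = get_cfg_defaults()
config.merge_from_list(["data.batch_size", "32", "training.lr", "5e-5"])
config.freeze()  # make the config read-only, as the scripts in this commit do
print(config.data.batch_size)  # 32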
@@ -0,0 +1,88 @@
import os
from pathlib import Path
import pickle
import numpy as np
from paddle.io import Dataset, DataLoader

from parakeet.data.batch import batch_spec, batch_text_id
from parakeet.data import dataset

class LJSpeech(Dataset):
    """A simple dataset adaptor for the processed ljspeech dataset."""
    def __init__(self, root):
        self.root = Path(root).expanduser()
        records = []
        with open(self.root / "metadata.pkl", 'rb') as f:
            metadata = pickle.load(f)
        for mel_name, text, phonemes, ids in metadata:
            mel_name = self.root / "mel" / (mel_name + ".npy")
            records.append((mel_name, text, phonemes, ids))
        self.records = records

    def __getitem__(self, i):
        mel_name, _, _, ids = self.records[i]
        mel = np.load(mel_name)
        return ids, mel

    def __len__(self):
        return len(self.records)


# decorate mel & create stop probability
class Transform(object):
    def __init__(self, start_value, end_value):
        self.start_value = start_value
        self.end_value = end_value

    def __call__(self, example):
        ids, mel = example  # ids already have <s> and </s>
        ids = np.array(ids, dtype=np.int64)
        # add start and end frames
        mel = np.pad(mel,
                     [(0, 0), (1, 1)],
                     mode='constant',
                     constant_values=[(0, 0), (self.start_value, self.end_value)])
        stop_labels = np.ones([mel.shape[1]], dtype=np.int64)
        stop_labels[-1] = 2
        # actually this could also be done within the model
        return ids, mel, stop_labels


class LJSpeechCollector(object):
    """A simple callable to batch LJSpeech examples."""
    def __init__(self, padding_idx=0, padding_value=0.):
        self.padding_idx = padding_idx
        self.padding_value = padding_value

    def __call__(self, examples):
        ids = [example[0] for example in examples]
        mels = [example[1] for example in examples]
        stop_probs = [example[2] for example in examples]

        ids = batch_text_id(ids, pad_id=self.padding_idx)
        mels = batch_spec(mels, pad_value=self.padding_value)
        stop_probs = batch_text_id(stop_probs, pad_id=self.padding_idx)
        return ids, np.transpose(mels, [0, 2, 1]), stop_probs


def create_dataloader(config, source_path):
    lj = LJSpeech(source_path)
    transform = Transform(config.data.mel_start_value, config.data.mel_end_value)
    lj = dataset.TransformDataset(lj, transform)

    valid_set, train_set = dataset.split(lj, config.data.valid_size)
    data_collator = LJSpeechCollector(padding_idx=config.data.padding_idx)
    train_loader = DataLoader(
        train_set,
        batch_size=config.data.batch_size,
        shuffle=True,
        drop_last=True,
        collate_fn=data_collator)
    valid_loader = DataLoader(
        valid_set,
        batch_size=config.data.batch_size,
        shuffle=False,
        drop_last=False,
        collate_fn=data_collator)
    return train_loader, valid_loader
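To see what Transform does to one example, here is a self-contained numpy check with dummy shapes (the real mels come from the preprocessing script below):

import numpy as np

mel = np.zeros([80, 10], dtype=np.float32)  # dummy (d_mel, T) spectrogram
padded = np.pad(mel, [(0, 0), (1, 1)], mode='constant',
                constant_values=[(0, 0), (0.5, -0.5)])
stop_labels = np.ones([padded.shape[1]], dtype=np.int64)
stop_labels[-1] = 2  # 2 marks the appended end frame

assert padded.shape == (80, 12)  # one start and one end frame appended
assert padded[0, 0] == 0.5 and padded[0, -1] == -0.5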
@@ -0,0 +1,82 @@
import os
import tqdm
import pickle
import argparse
import numpy as np
from pathlib import Path

from parakeet.datasets import LJSpeechMetaData
from parakeet.audio import AudioProcessor, LogMagnitude
from parakeet.frontend import English

from config import get_cfg_defaults

def create_dataset(config, source_path, target_path, verbose=False):
    # create output dir
    target_path = Path(target_path).expanduser()
    mel_path = target_path / "mel"
    os.makedirs(mel_path, exist_ok=True)

    meta_data = LJSpeechMetaData(source_path)
    frontend = English()
    processor = AudioProcessor(
        sample_rate=config.data.sample_rate,
        n_fft=config.data.n_fft,
        n_mels=config.data.d_mel,
        win_length=config.data.win_length,
        hop_length=config.data.hop_length,
        f_max=config.data.f_max)
    normalizer = LogMagnitude()

    records = []
    for (fname, text, _) in tqdm.tqdm(meta_data):
        wav = processor.read_wav(fname)
        mel = processor.mel_spectrogram(wav)
        mel = normalizer.transform(mel)
        phonemes = frontend.phoneticize(text)
        ids = frontend.numericalize(phonemes)
        mel_name = os.path.splitext(os.path.basename(fname))[0]

        # save mel spectrogram
        records.append((mel_name, text, phonemes, ids))
        np.save(mel_path / mel_name, mel)
    if verbose:
        print("saved mel spectrograms into {}".format(mel_path))

    # save metadata as a pickle archive
    with open(target_path / "metadata.pkl", 'wb') as f:
        pickle.dump(records, f)
    if verbose:
        print("saved metadata into {}".format(target_path / "metadata.pkl"))

    # also save metadata in text format for inspection
    with open(target_path / "metadata.txt", 'wt') as f:
        for mel_name, text, phonemes, _ in records:
            phoneme_str = "|".join(phonemes)
            f.write("{}\t{}\t{}\n".format(mel_name, text, phoneme_str))
    if verbose:
        print("saved metadata into {}".format(target_path / "metadata.txt"))

    print("Done.")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="create dataset")
    parser.add_argument("--config", type=str, metavar="FILE", help="extra config to overwrite the default config")
    parser.add_argument("--input", type=str, help="path of the ljspeech dataset")
    parser.add_argument("--output", type=str, help="path to save the output dataset")
    parser.add_argument("--opts", nargs=argparse.REMAINDER,
        help="options to overwrite --config file and the default config, passing in KEY VALUE pairs"
    )
    parser.add_argument("-v", "--verbose", action="store_true", help="print messages")

    config = get_cfg_defaults()
    args = parser.parse_args()
    if args.config:
        config.merge_from_file(args.config)
    if args.opts:
        config.merge_from_list(args.opts)
    config.freeze()
    print(config.data)

    create_dataset(config, args.input, args.output, args.verbose)
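The mel features are log-compressed before saving. A hedged sketch of what the normalizer is assumed to do (a floored natural log; the exact floor lives in parakeet.audio.LogMagnitude):

import numpy as np

def log_magnitude(x, floor=1e-5):
    # assumed behavior: clip to a small floor, then take the natural log
    return np.log(np.maximum(x, floor))

mel = np.random.rand(80, 100).astype(np.float32)  # dummy magnitudes in [0, 1)
log_mel = log_magnitude(mel)
assert log_mel.max() < 0.0  # values below 1 map to negative log-magnitudes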
@@ -0,0 +1,64 @@
import argparse
import time
from pathlib import Path
import numpy as np
import paddle

import parakeet
from parakeet.frontend import English
from parakeet.models.transformer_tts import TransformerTTS
from parakeet.utils import scheduler
from parakeet.training.cli import default_argument_parser
from parakeet.utils.display import add_attention_plots

from config import get_cfg_defaults

@paddle.fluid.dygraph.no_grad
def main(config, args):
    paddle.set_device(args.device)

    # model
    frontend = English()
    model = TransformerTTS.from_pretrained(
        frontend, config, args.checkpoint_path)
    model.eval()

    # inputs
    input_path = Path(args.input).expanduser()
    with open(input_path, "rt") as f:
        sentences = f.readlines()

    output_dir = Path(args.output).expanduser()
    output_dir.mkdir(parents=True, exist_ok=True)

    for i, sentence in enumerate(sentences):
        outputs = model.predict(sentence, verbose=args.verbose)
        mel_output = outputs["mel_output"]
        # cross_attention_weights = outputs["cross_attention_weights"]
        mel_output = mel_output.T  # (C, T)
        np.save(str(output_dir / f"sentence_{i}"), mel_output)
        if args.verbose:
            print("spectrogram saved at {}".format(output_dir / f"sentence_{i}.npy"))

if __name__ == "__main__":
    config = get_cfg_defaults()

    parser = argparse.ArgumentParser(description="generate mel spectrograms with TransformerTTS.")
    parser.add_argument("--config", type=str, metavar="FILE", help="extra config to overwrite the default config")
    parser.add_argument("--checkpoint_path", type=str, help="path of the checkpoint to load.")
    parser.add_argument("--input", type=str, help="path of the text sentences")
    parser.add_argument("--output", type=str, help="path to save outputs")
    parser.add_argument("--device", type=str, default="cpu", help="device type to use.")
    parser.add_argument("--opts", nargs=argparse.REMAINDER, help="options to overwrite --config file and the default config, passing in KEY VALUE pairs")
    parser.add_argument("-v", "--verbose", action="store_true", help="print messages")

    args = parser.parse_args()
    if args.config:
        config.merge_from_file(args.config)
    if args.opts:
        config.merge_from_list(args.opts)
    config.freeze()
    print(config)
    print(args)

    main(config, args)
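A hypothetical invocation of this script and a check of its output (all paths are placeholders, not from the source):

# python synthesize.py --config=tts.yaml --checkpoint_path=ckpt/step-100000 \
#        --input=sentences.txt --output=mels/ --device=gpu
import numpy as np

mel = np.load("mels/sentence_0.npy")
print(mel.shape)  # (d_mel, T): the script transposes to channel-first before saving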
@@ -0,0 +1,202 @@
import time
import logging
from pathlib import Path
import numpy as np
import paddle
from paddle import distributed as dist
from paddle.io import DataLoader, DistributedBatchSampler
from tensorboardX import SummaryWriter
from collections import defaultdict

import parakeet
from parakeet.data import dataset
from parakeet.frontend import English
from parakeet.models.transformer_tts import TransformerTTS, TransformerTTSLoss
from parakeet.utils import scheduler, checkpoint, mp_tools, display
from parakeet.training.cli import default_argument_parser
from parakeet.training.experiment import ExperimentBase

from config import get_cfg_defaults
from ljspeech import LJSpeech, LJSpeechCollector, Transform

class Experiment(ExperimentBase):
    def setup_model(self):
        config = self.config
        frontend = English()
        model = TransformerTTS(
            frontend,
            d_encoder=config.model.d_encoder,
            d_decoder=config.model.d_decoder,
            d_mel=config.data.d_mel,
            n_heads=config.model.n_heads,
            d_ffn=config.model.d_ffn,
            encoder_layers=config.model.encoder_layers,
            decoder_layers=config.model.decoder_layers,
            d_prenet=config.model.d_prenet,
            d_postnet=config.model.d_postnet,
            postnet_layers=config.model.postnet_layers,
            postnet_kernel_size=config.model.postnet_kernel_size,
            max_reduction_factor=config.model.max_reduction_factor,
            decoder_prenet_dropout=config.model.decoder_prenet_dropout,
            dropout=config.model.dropout)
        if self.parallel:
            model = paddle.DataParallel(model)
        optimizer = paddle.optimizer.Adam(
            learning_rate=config.training.lr,
            beta1=0.9,
            beta2=0.98,
            epsilon=1e-9,
            parameters=model.parameters()
        )
        criterion = TransformerTTSLoss(config.model.stop_loss_scale)
        drop_n_heads = scheduler.StepWise(config.training.drop_n_heads)
        reduction_factor = scheduler.StepWise(config.training.reduction_factor)

        self.model = model
        self.optimizer = optimizer
        self.criterion = criterion
        self.drop_n_heads = drop_n_heads
        self.reduction_factor = reduction_factor

    def setup_dataloader(self):
        args = self.args
        config = self.config

        ljspeech_dataset = LJSpeech(args.data)
        transform = Transform(config.data.mel_start_value, config.data.mel_end_value)
        ljspeech_dataset = dataset.TransformDataset(ljspeech_dataset, transform)
        valid_set, train_set = dataset.split(ljspeech_dataset, config.data.valid_size)
        batch_fn = LJSpeechCollector(padding_idx=config.data.padding_idx)

        if not self.parallel:
            train_loader = DataLoader(
                train_set,
                batch_size=config.data.batch_size,
                shuffle=True,
                drop_last=True,
                collate_fn=batch_fn)
        else:
            sampler = DistributedBatchSampler(
                train_set,
                batch_size=config.data.batch_size,
                num_replicas=dist.get_world_size(),
                rank=dist.get_rank(),
                shuffle=True,
                drop_last=True)
            train_loader = DataLoader(
                train_set, batch_sampler=sampler, collate_fn=batch_fn)

        valid_loader = DataLoader(
            valid_set, batch_size=config.data.batch_size, collate_fn=batch_fn)

        self.train_loader = train_loader
        self.valid_loader = valid_loader

    def compute_outputs(self, text, mel, stop_label):
        model_core = self.model._layers if self.parallel else self.model
        model_core.set_constants(
            self.reduction_factor(self.iteration),
            self.drop_n_heads(self.iteration))

        # TODO(chenfeiyu): we can combine these 2 slices
        mel_input = mel[:, :-1, :]
        reduced_mel_input = mel_input[:, ::model_core.r, :]
        outputs = self.model(text, reduced_mel_input)
        return outputs

    def compute_losses(self, inputs, outputs):
        _, mel, stop_label = inputs
        mel_target = mel[:, 1:, :]
        stop_label_target = stop_label[:, 1:]

        mel_output = outputs["mel_output"]
        mel_intermediate = outputs["mel_intermediate"]
        stop_logits = outputs["stop_logits"]

        time_steps = mel_target.shape[1]
        losses = self.criterion(
            mel_output[:, :time_steps, :],
            mel_intermediate[:, :time_steps, :],
            mel_target,
            stop_logits[:, :time_steps, :],
            stop_label_target)
        return losses

    def train_batch(self):
        start = time.time()
        batch = self.read_batch()
        data_loader_time = time.time() - start

        self.optimizer.clear_grad()
        self.model.train()
        text, mel, stop_label = batch
        outputs = self.compute_outputs(text, mel, stop_label)
        losses = self.compute_losses(batch, outputs)
        loss = losses["loss"]
        loss.backward()
        self.optimizer.step()
        iteration_time = time.time() - start

        losses_np = {k: float(v) for k, v in losses.items()}
        # logging
        msg = "Rank: {}, ".format(dist.get_rank())
        msg += "step: {}, ".format(self.iteration)
        msg += "time: {:>.3f}s/{:>.3f}s, ".format(data_loader_time, iteration_time)
        msg += ', '.join('{}: {:>.6f}'.format(k, v) for k, v in losses_np.items())
        self.logger.info(msg)

        if dist.get_rank() == 0:
            for k, v in losses_np.items():
                self.visualizer.add_scalar(f"train_loss/{k}", v, self.iteration)

    @mp_tools.rank_zero_only
    @paddle.no_grad()
    def valid(self):
        valid_losses = defaultdict(list)
        for i, batch in enumerate(self.valid_loader):
            text, mel, stop_label = batch
            outputs = self.compute_outputs(text, mel, stop_label)
            losses = self.compute_losses(batch, outputs)
            for k, v in losses.items():
                valid_losses[k].append(float(v))

            if i < 2:
                attention_weights = outputs["cross_attention_weights"]
                display.add_multi_attention_plots(
                    self.visualizer,
                    f"valid_sentence_{i}_cross_attention_weights",
                    attention_weights,
                    self.iteration)

        # write visual log
        valid_losses = {k: np.mean(v) for k, v in valid_losses.items()}
        for k, v in valid_losses.items():
            self.visualizer.add_scalar(f"valid/{k}", v, self.iteration)


def main_sp(config, args):
    exp = Experiment(config, args)
    exp.setup()
    exp.run()


def main(config, args):
    if args.nprocs > 1 and args.device == "gpu":
        dist.spawn(main_sp, args=(config, args), nprocs=args.nprocs)
    else:
        main_sp(config, args)


if __name__ == "__main__":
    config = get_cfg_defaults()
    parser = default_argument_parser()
    args = parser.parse_args()
    if args.config:
        config.merge_from_file(args.config)
    if args.opts:
        config.merge_from_list(args.opts)
    config.freeze()
    print(config)
    print(args)

    main(config, args)
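drop_n_heads and reduction_factor in the config are step-wise schedules of [boundary, value] pairs. A pure-Python sketch of what scheduler.StepWise is assumed to compute:

def stepwise(schedule, step):
    # assumed semantics: return the value of the last boundary not exceeding `step`
    value = schedule[0][1]
    for boundary, v in schedule:
        if step >= boundary:
            value = v
    return value

reduction_factor = [[0, 10], [80000, 4], [200000, 2]]
assert stepwise(reduction_factor, 0) == 10
assert stepwise(reduction_factor, 120000) == 4
assert stepwise(reduction_factor, 250000) == 2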
@@ -0,0 +1,43 @@
from yacs.config import CfgNode as CN

_C = CN()
_C.data = CN(
    dict(
        batch_size=8,  # batch size
        valid_size=16,  # the first N examples are reserved for validation
        sample_rate=22050,  # Hz, sample rate
        n_fft=1024,  # fft frame size
        win_length=1024,  # window size
        hop_length=256,  # hop size between adjacent frames
        f_max=8000,  # Hz, max frequency when converting to mel
        n_mels=80,  # mel bands
        clip_frames=65,  # mel clip frames
    )
)

_C.model = CN(
    dict(
        upsample_factors=[16, 16],
        n_flows=8,  # number of flows in WaveFlow
        n_layers=8,  # number of conv blocks in each flow
        n_group=16,  # folding factor of audio and spectrogram
        channels=128,  # residual channels in each flow
        kernel_size=[3, 3],  # kernel size in each conv block
        sigma=1.0,  # stddev of the random noise
    )
)

_C.training = CN(
    dict(
        lr=2e-4,  # learning rate
        valid_interval=1000,  # validation
        save_interval=10000,  # checkpoint
        max_iteration=3000000,  # max iteration to train
    )
)

def get_cfg_defaults():
    """Get a yacs CfgNode object with default values for my_project."""
    # Return a clone so that the defaults will not be altered
    # This is for the "local variable" use pattern
    return _C.clone()
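One consistency constraint worth noting (an inference from how frame-level vocoder conditioning works, not stated in the source): the upsample factors should multiply to the hop size, so that one mel frame expands to exactly one hop of audio.

import numpy as np

upsample_factors = [16, 16]
hop_length = 256
assert int(np.prod(upsample_factors)) == hop_length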
@@ -0,0 +1,78 @@
import os
from pathlib import Path
import pickle
import numpy as np
import pandas
from paddle.io import Dataset, DataLoader

from parakeet.data.batch import batch_spec, batch_wav
from parakeet.data import dataset
from parakeet.audio import AudioProcessor

class LJSpeech(Dataset):
    """A simple dataset adaptor for the processed ljspeech dataset."""
    def __init__(self, root):
        self.root = Path(root).expanduser()
        meta_data = pandas.read_csv(
            str(self.root / "metadata.csv"),
            sep="\t",
            header=None,
            names=["fname", "frames", "samples"]
        )

        records = []
        for row in meta_data.itertuples():
            mel_path = str(self.root / "mel" / (row.fname + ".npy"))
            wav_path = str(self.root / "wav" / (row.fname + ".npy"))
            records.append((mel_path, wav_path))
        self.records = records

    def __getitem__(self, i):
        mel_name, wav_name = self.records[i]
        mel = np.load(mel_name)
        wav = np.load(wav_name)
        return mel, wav

    def __len__(self):
        return len(self.records)


class LJSpeechCollector(object):
    """A simple callable to batch LJSpeech examples."""
    def __init__(self, padding_value=0.):
        self.padding_value = padding_value

    def __call__(self, examples):
        mels = [example[0] for example in examples]
        wavs = [example[1] for example in examples]
        mels = batch_spec(mels, pad_value=self.padding_value)
        wavs = batch_wav(wavs, pad_value=self.padding_value)
        return mels, wavs


class LJSpeechClipCollector(object):
    def __init__(self, clip_frames=65, hop_length=256):
        self.clip_frames = clip_frames
        self.hop_length = hop_length

    def __call__(self, examples):
        mels = []
        wavs = []
        for example in examples:
            mel_clip, wav_clip = self.clip(example)
            mels.append(mel_clip)
            wavs.append(wav_clip)
        mels = np.stack(mels)
        wavs = np.stack(wavs)
        return mels, wavs

    def clip(self, example):
        mel, wav = example
        frames = mel.shape[-1]
        start = np.random.randint(0, frames - self.clip_frames)
        mel_clip = mel[:, start: start + self.clip_frames]
        wav_clip = wav[start * self.hop_length: (start + self.clip_frames) * self.hop_length]
        return mel_clip, wav_clip
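A quick numpy check of the clip alignment, with dummy arrays shaped like the preprocessed data:

import numpy as np

clip_frames, hop_length = 65, 256
mel = np.zeros([80, 400], dtype=np.float32)
wav = np.zeros([400 * hop_length], dtype=np.float32)

start = np.random.randint(0, mel.shape[-1] - clip_frames)
mel_clip = mel[:, start: start + clip_frames]
wav_clip = wav[start * hop_length: (start + clip_frames) * hop_length]

assert mel_clip.shape[-1] == clip_frames
assert wav_clip.size == clip_frames * hop_length  # the audio clip spans the same frames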
@@ -0,0 +1,138 @@
import os
import tqdm
import csv
import argparse
import numpy as np
import librosa
from pathlib import Path
import pandas as pd

from paddle.io import Dataset
from parakeet.data import batch_spec, batch_wav
from parakeet.datasets import LJSpeechMetaData
from parakeet.audio import AudioProcessor, LogMagnitude

from config import get_cfg_defaults


class Transform(object):
    def __init__(self, sample_rate, n_fft, win_length, hop_length, n_mels):
        self.sample_rate = sample_rate
        self.n_fft = n_fft
        self.win_length = win_length
        self.hop_length = hop_length
        self.n_mels = n_mels

        self.spec_normalizer = LogMagnitude(min=1e-5)

    def __call__(self, example):
        wav_path, _, _ = example

        sr = self.sample_rate
        n_fft = self.n_fft
        win_length = self.win_length
        hop_length = self.hop_length
        n_mels = self.n_mels

        wav, loaded_sr = librosa.load(wav_path, sr=None)
        assert loaded_sr == sr, "sample rate of the audio file does not match the config"

        # Pad audio to the right size.
        frames = int(np.ceil(float(wav.size) / hop_length))
        fft_padding = (n_fft - hop_length) // 2  # samples needed on each side
        desired_length = frames * hop_length + fft_padding * 2
        pad_amount = (desired_length - wav.size) // 2

        if wav.size % 2 == 0:
            wav = np.pad(wav, (pad_amount, pad_amount), mode='reflect')
        else:
            wav = np.pad(wav, (pad_amount, pad_amount + 1), mode='reflect')

        # Normalize audio.
        wav = wav / np.abs(wav).max() * 0.999

        # Compute the spectrogram.
        # Turn center to False to prevent internal padding.
        spectrogram = librosa.core.stft(
            wav,
            hop_length=hop_length,
            win_length=win_length,
            n_fft=n_fft,
            center=False)
        spectrogram_magnitude = np.abs(spectrogram)

        # Compute the mel-spectrogram.
        mel_filter_bank = librosa.filters.mel(sr=sr,
                                              n_fft=n_fft,
                                              n_mels=n_mels)
        mel_spectrogram = np.dot(mel_filter_bank, spectrogram_magnitude)

        # Log-scale the mel-spectrogram.
        mel_spectrogram = self.spec_normalizer.transform(mel_spectrogram)

        # Extract the center of audio that corresponds to mel spectrograms.
        audio = wav[fft_padding:-fft_padding]
        assert mel_spectrogram.shape[1] * hop_length == audio.size

        # there is no clipping here
        return audio, mel_spectrogram


def create_dataset(config, input_dir, output_dir, verbose=True):
    input_dir = Path(input_dir).expanduser()
    dataset = LJSpeechMetaData(input_dir)

    output_dir = Path(output_dir).expanduser()
    output_dir.mkdir(exist_ok=True)

    transform = Transform(
        config.sample_rate,
        config.n_fft,
        config.win_length,
        config.hop_length,
        config.n_mels)
    file_names = []

    for example in tqdm.tqdm(dataset):
        fname, _, _ = example
        base_name = os.path.splitext(os.path.basename(fname))[0]
        wav_dir = output_dir / "wav"
        mel_dir = output_dir / "mel"
        wav_dir.mkdir(exist_ok=True)
        mel_dir.mkdir(exist_ok=True)

        audio, mel = transform(example)
        np.save(str(wav_dir / base_name), audio)
        np.save(str(mel_dir / base_name), mel)

        file_names.append((base_name, mel.shape[-1], audio.shape[-1]))

    meta_data = pd.DataFrame.from_records(file_names)
    meta_data.to_csv(str(output_dir / "metadata.csv"), sep="\t", index=None, header=None)
    print("saved metadata into {}".format(os.path.join(output_dir, "metadata.csv")))

    print("Done!")

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="create dataset")
    parser.add_argument("--config", type=str, metavar="FILE", help="extra config to overwrite the default config")
    parser.add_argument("--input", type=str, help="path of the ljspeech dataset")
    parser.add_argument("--output", type=str, help="path to save the output dataset")
    parser.add_argument("--opts", nargs=argparse.REMAINDER,
        help="options to overwrite --config file and the default config, passing in KEY VALUE pairs"
    )
    parser.add_argument("-v", "--verbose", action="store_true", help="print messages")

    config = get_cfg_defaults()
    args = parser.parse_args()
    if args.config:
        config.merge_from_file(args.config)
    if args.opts:
        config.merge_from_list(args.opts)
    config.freeze()
    if args.verbose:
        print(config.data)
        print(args)

    create_dataset(config.data, args.input, args.output, args.verbose)
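The padding arithmetic above guarantees an exact frame-to-sample correspondence. A worked check with this config's defaults (n_fft=1024, hop_length=256):

import numpy as np

n_fft, hop_length = 1024, 256
wav_size = 22050  # one second of audio at 22050 Hz (even length)

frames = int(np.ceil(float(wav_size) / hop_length))      # 87
fft_padding = (n_fft - hop_length) // 2                  # 384 samples per side
desired_length = frames * hop_length + fft_padding * 2   # 23040
pad_amount = (desired_length - wav_size) // 2

assert wav_size + 2 * pad_amount == desired_length
# trimming fft_padding from both ends leaves exactly frames * hop_length samples
assert desired_length - 2 * fft_padding == frames * hop_length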
@@ -0,0 +1,52 @@
import argparse
import numpy as np
import soundfile as sf
import os
from pathlib import Path
import paddle
import parakeet
from parakeet.models.waveflow import UpsampleNet, WaveFlow, ConditionalWaveFlow
from parakeet.utils import layer_tools, checkpoint


from config import get_cfg_defaults

def main(config, args):
    paddle.set_device(args.device)
    model = ConditionalWaveFlow.from_pretrained(config, args.checkpoint_path)
    layer_tools.recursively_remove_weight_norm(model)
    model.eval()

    mel_dir = Path(args.input).expanduser()
    output_dir = Path(args.output).expanduser()
    output_dir.mkdir(parents=True, exist_ok=True)
    for file_path in mel_dir.iterdir():
        mel = np.load(str(file_path))
        audio = model.predict(mel)
        audio_path = output_dir / (os.path.splitext(file_path.name)[0] + ".wav")
        sf.write(audio_path, audio, config.data.sample_rate)
        print("[synthesize] {} -> {}".format(file_path, audio_path))


if __name__ == "__main__":
    config = get_cfg_defaults()

    parser = argparse.ArgumentParser(description="synthesize audio from mel spectrograms with WaveFlow.")
    parser.add_argument("--config", type=str, metavar="FILE", help="extra config to overwrite the default config")
    parser.add_argument("--checkpoint_path", type=str, help="path of the checkpoint to load.")
    parser.add_argument("--input", type=str, help="path of the directory containing mel spectrograms (in .npy format)")
    parser.add_argument("--output", type=str, help="path to save outputs")
    parser.add_argument("--device", type=str, default="cpu", help="device type to use.")
    parser.add_argument("--opts", nargs=argparse.REMAINDER, help="options to overwrite --config file and the default config, passing in KEY VALUE pairs")
    parser.add_argument("-v", "--verbose", action="store_true", help="print messages")

    args = parser.parse_args()
    if args.config:
        config.merge_from_file(args.config)
    if args.opts:
        config.merge_from_list(args.opts)
    config.freeze()
    print(config)
    print(args)

    main(config, args)
@@ -0,0 +1,147 @@
import time
from pathlib import Path
import numpy as np
import paddle
from paddle import distributed as dist
from paddle.io import DataLoader, DistributedBatchSampler
from tensorboardX import SummaryWriter
from collections import defaultdict

import parakeet
from parakeet.data import dataset
from parakeet.models.waveflow import UpsampleNet, WaveFlow, ConditionalWaveFlow, WaveFlowLoss
from parakeet.audio import AudioProcessor
from parakeet.utils import scheduler, mp_tools
from parakeet.training.cli import default_argument_parser
from parakeet.training.experiment import ExperimentBase
from parakeet.utils.mp_tools import rank_zero_only

from config import get_cfg_defaults
from ljspeech import LJSpeech, LJSpeechClipCollector, LJSpeechCollector


class Experiment(ExperimentBase):
    def setup_model(self):
        config = self.config
        model = ConditionalWaveFlow(
            upsample_factors=config.model.upsample_factors,
            n_flows=config.model.n_flows,
            n_layers=config.model.n_layers,
            n_group=config.model.n_group,
            channels=config.model.channels,
            n_mels=config.data.n_mels,
            kernel_size=config.model.kernel_size)

        if self.parallel:
            model = paddle.DataParallel(model)
        optimizer = paddle.optimizer.Adam(config.training.lr, parameters=model.parameters())
        criterion = WaveFlowLoss(sigma=config.model.sigma)

        self.model = model
        self.optimizer = optimizer
        self.criterion = criterion

    def setup_dataloader(self):
        config = self.config
        args = self.args

        ljspeech_dataset = LJSpeech(args.data)
        valid_set, train_set = dataset.split(ljspeech_dataset, config.data.valid_size)

        batch_fn = LJSpeechClipCollector(config.data.clip_frames, config.data.hop_length)

        if not self.parallel:
            train_loader = DataLoader(
                train_set,
                batch_size=config.data.batch_size,
                shuffle=True,
                drop_last=True,
                collate_fn=batch_fn)
        else:
            sampler = DistributedBatchSampler(
                train_set,
                batch_size=config.data.batch_size,
                num_replicas=dist.get_world_size(),
                rank=dist.get_rank(),
                shuffle=True,
                drop_last=True)
            train_loader = DataLoader(
                train_set, batch_sampler=sampler, collate_fn=batch_fn)

        valid_batch_fn = LJSpeechCollector()
        valid_loader = DataLoader(
            valid_set, batch_size=1, collate_fn=valid_batch_fn)

        self.train_loader = train_loader
        self.valid_loader = valid_loader

    def compute_outputs(self, mel, wav):
        # model_core = model._layers if isinstance(model, paddle.DataParallel) else model
        z, log_det_jacobian = self.model(wav, mel)
        return z, log_det_jacobian

    def compute_losses(self, outputs):
        loss = self.criterion(outputs)
        return loss

    def train_batch(self):
        start = time.time()
        batch = self.read_batch()
        data_loader_time = time.time() - start

        self.model.train()
        self.optimizer.clear_grad()
        mel, wav = batch
        outputs = self.compute_outputs(mel, wav)
        loss = self.compute_losses(outputs)
        loss.backward()
        self.optimizer.step()
        iteration_time = time.time() - start

        loss_value = float(loss)
        msg = "Rank: {}, ".format(dist.get_rank())
        msg += "step: {}, ".format(self.iteration)
        msg += "time: {:>.3f}s/{:>.3f}s, ".format(data_loader_time, iteration_time)
        msg += "loss: {:>.6f}".format(loss_value)
        self.logger.info(msg)
        self.visualizer.add_scalar("train/loss", loss_value, global_step=self.iteration)

    @mp_tools.rank_zero_only
    @paddle.no_grad()
    def valid(self):
        valid_iterator = iter(self.valid_loader)
        valid_losses = []
        mel, wav = next(valid_iterator)
        outputs = self.compute_outputs(mel, wav)
        loss = self.compute_losses(outputs)
        valid_losses.append(float(loss))
        valid_loss = np.mean(valid_losses)
        self.visualizer.add_scalar("valid/loss", valid_loss, global_step=self.iteration)


def main_sp(config, args):
    exp = Experiment(config, args)
    exp.setup()
    exp.run()


def main(config, args):
    if args.nprocs > 1 and args.device == "gpu":
        dist.spawn(main_sp, args=(config, args), nprocs=args.nprocs)
    else:
        main_sp(config, args)


if __name__ == "__main__":
    config = get_cfg_defaults()
    parser = default_argument_parser()
    args = parser.parse_args()
    if args.config:
        config.merge_from_file(args.config)
    if args.opts:
        config.merge_from_list(args.opts)
    config.freeze()
    print(config)
    print(args)

    main(config, args)
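WaveFlowLoss is the maximum-likelihood objective of a normalizing flow. A hedged numpy sketch of the quantity being minimized, assuming a zero-mean Gaussian prior with stddev sigma and dropping additive constants:

import numpy as np

def flow_nll(z, log_det_jacobian, sigma=1.0):
    # per-sample negative log-likelihood: Gaussian prior term on the latent z
    # minus the log-determinant of the flow's Jacobian
    return (np.sum(z ** 2) / (2 * sigma ** 2) - np.sum(log_det_jacobian)) / z.size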
@@ -0,0 +1,47 @@
from yacs.config import CfgNode as CN

_C = CN()
_C.data = CN(
    dict(
        batch_size=8,  # batch size
        valid_size=16,  # the first N examples are reserved for validation
        sample_rate=22050,  # Hz, sample rate
        n_fft=2048,  # fft frame size
        win_length=1024,  # window size
        hop_length=256,  # hop size between adjacent frames
        # f_max=8000,  # Hz, max frequency when converting to mel
        n_mels=80,  # mel bands
        train_clip_seconds=0.5,  # audio clip length (in seconds)
    )
)

_C.model = CN(
    dict(
        upsample_factors=[16, 16],
        n_stack=3,
        n_loop=10,
        filter_size=2,
        residual_channels=128,  # residual channels in each conv block
        loss_type="mog",
        output_dim=3,  # single gaussian
        log_scale_min=-9.0,
    )
)

_C.training = CN(
    dict(
        lr=1e-3,  # learning rate
        anneal_rate=0.5,  # learning rate decay rate
        anneal_interval=200000,  # decrease lr by anneal_rate every anneal_interval steps
        valid_interval=1000,  # validation
        save_interval=10000,  # checkpoint
        max_iteration=3000000,  # max iteration to train
        gradient_max_norm=100.0,  # global norm of gradients
    )
)

def get_cfg_defaults():
    """Get a yacs CfgNode object with default values for my_project."""
    # Return a clone so that the defaults will not be altered
    # This is for the "local variable" use pattern
    return _C.clone()
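The learning-rate schedule is plain exponential step decay: multiply by anneal_rate every anneal_interval steps. A quick check of what the rate becomes after 600k steps with these defaults:

lr, anneal_rate, anneal_interval = 1e-3, 0.5, 200000

def lr_at(step):
    return lr * anneal_rate ** (step // anneal_interval)

assert abs(lr_at(600000) - 1.25e-4) < 1e-12  # halved three times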
@@ -0,0 +1,138 @@
import os
from pathlib import Path
import pickle
import numpy as np
import pandas
from paddle.io import Dataset, DataLoader

from parakeet.data.batch import batch_spec, batch_wav
from parakeet.data import dataset
from parakeet.audio import AudioProcessor

class LJSpeech(Dataset):
    """A simple dataset adaptor for the processed ljspeech dataset."""
    def __init__(self, root):
        self.root = Path(root).expanduser()
        meta_data = pandas.read_csv(
            str(self.root / "metadata.csv"),
            sep="\t",
            header=None,
            names=["fname", "frames", "samples"]
        )

        records = []
        for row in meta_data.itertuples():
            mel_path = str(self.root / "mel" / (row.fname + ".npy"))
            wav_path = str(self.root / "wav" / (row.fname + ".npy"))
            records.append((mel_path, wav_path))
        self.records = records

    def __getitem__(self, i):
        mel_name, wav_name = self.records[i]
        mel = np.load(mel_name)
        wav = np.load(wav_name)
        return mel, wav

    def __len__(self):
        return len(self.records)


class LJSpeechCollector(object):
    """A simple callable to batch LJSpeech examples."""
    def __init__(self, padding_value=0.):
        self.padding_value = padding_value

    def __call__(self, examples):
        batch_size = len(examples)
        mels = [example[0] for example in examples]
        wavs = [example[1] for example in examples]
        mels = batch_spec(mels, pad_value=self.padding_value)
        wavs = batch_wav(wavs, pad_value=self.padding_value)
        audio_starts = np.zeros((batch_size,), dtype=np.int64)
        return mels, wavs, audio_starts


class LJSpeechClipCollector(object):
    def __init__(self, clip_frames=65, hop_length=256):
        self.clip_frames = clip_frames
        self.hop_length = hop_length

    def __call__(self, examples):
        mels = []
        wavs = []
        starts = []
        for example in examples:
            mel, wav_clip, start = self.clip(example)
            mels.append(mel)
            wavs.append(wav_clip)
            starts.append(start)
        mels = batch_spec(mels)
        wavs = np.stack(wavs)
        starts = np.array(starts, dtype=np.int64)
        return mels, wavs, starts

    def clip(self, example):
        mel, wav = example
        frames = mel.shape[-1]
        start = np.random.randint(0, frames - self.clip_frames)
        wav_clip = wav[start * self.hop_length: (start + self.clip_frames) * self.hop_length]
        return mel, wav_clip, start


class DataCollector(object):
    def __init__(self,
                 context_size,
                 sample_rate,
                 hop_length,
                 train_clip_seconds,
                 valid=False):
        frames_per_second = sample_rate // hop_length
        train_clip_frames = int(
            np.ceil(train_clip_seconds * frames_per_second))
        context_frames = context_size // hop_length
        self.num_frames = train_clip_frames + context_frames

        self.sample_rate = sample_rate
        self.hop_length = hop_length
        self.valid = valid

    def random_crop(self, sample):
        audio, mel_spectrogram = sample
        audio_frames = int(audio.size) // self.hop_length
        max_start_frame = audio_frames - self.num_frames
        assert max_start_frame >= 0, "audio is too short to be cropped"

        frame_start = np.random.randint(0, max_start_frame)
        # frame_start = 0  # no random
        frame_end = frame_start + self.num_frames

        audio_start = frame_start * self.hop_length
        audio_end = frame_end * self.hop_length

        audio = audio[audio_start:audio_end]
        return audio, mel_spectrogram, audio_start

    def __call__(self, samples):
        # transform them first
        if self.valid:
            samples = [(audio, mel_spectrogram, 0)
                       for audio, mel_spectrogram in samples]
        else:
            samples = [self.random_crop(sample) for sample in samples]
        # batch them
        audios = [sample[0] for sample in samples]
        audio_starts = [sample[2] for sample in samples]
        mels = [sample[1] for sample in samples]

        mels = batch_spec(mels)

        if self.valid:
            audios = batch_wav(audios, dtype=np.float32)
        else:
            audios = np.array(audios, dtype=np.float32)
        audio_starts = np.array(audio_starts, dtype=np.int64)
        return audios, mels, audio_starts
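With the wavenet defaults (sample_rate 22050, hop_length 256, train_clip_seconds 0.5), DataCollector's clip size works out as follows; context_size here is copied from the receptive-field computation in the training script:

import numpy as np

sample_rate, hop_length, train_clip_seconds = 22050, 256, 0.5
context_size = 3070  # receptive field of the default WaveNet, see the training script

frames_per_second = sample_rate // hop_length  # 86
train_clip_frames = int(np.ceil(train_clip_seconds * frames_per_second))  # 43
context_frames = context_size // hop_length  # 11
num_frames = train_clip_frames + context_frames  # 54 frames per training clip
print(num_frames * hop_length)  # 13824 audio samples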
@@ -0,0 +1,139 @@
import os
import tqdm
import csv
import argparse
import numpy as np
import librosa
from pathlib import Path
import pandas as pd

from paddle.io import Dataset
from parakeet.data import batch_spec, batch_wav
from parakeet.datasets import LJSpeechMetaData
from parakeet.audio import AudioProcessor
from parakeet.audio.spec_normalizer import UnitMagnitude

from config import get_cfg_defaults


class Transform(object):
    def __init__(self, sample_rate, n_fft, win_length, hop_length, n_mels):
        self.sample_rate = sample_rate
        self.n_fft = n_fft
        self.win_length = win_length
        self.hop_length = hop_length
        self.n_mels = n_mels

        self.spec_normalizer = UnitMagnitude(min=1e-5)

    def __call__(self, example):
        wav_path, _, _ = example

        sr = self.sample_rate
        n_fft = self.n_fft
        win_length = self.win_length
        hop_length = self.hop_length
        n_mels = self.n_mels

        wav, loaded_sr = librosa.load(wav_path, sr=None)
        assert loaded_sr == sr, "sample rate of the audio file does not match the config"

        # Pad audio to the right size.
        frames = int(np.ceil(float(wav.size) / hop_length))
        fft_padding = (n_fft - hop_length) // 2  # samples needed on each side
        desired_length = frames * hop_length + fft_padding * 2
        pad_amount = (desired_length - wav.size) // 2

        if wav.size % 2 == 0:
            wav = np.pad(wav, (pad_amount, pad_amount), mode='reflect')
        else:
            wav = np.pad(wav, (pad_amount, pad_amount + 1), mode='reflect')

        # Normalize audio.
        wav = wav / np.abs(wav).max() * 0.999

        # Compute the spectrogram.
        # Turn center to False to prevent internal padding.
        spectrogram = librosa.core.stft(
            wav,
            hop_length=hop_length,
            win_length=win_length,
            n_fft=n_fft,
            center=False)
        spectrogram_magnitude = np.abs(spectrogram)

        # Compute the mel-spectrogram.
        mel_filter_bank = librosa.filters.mel(sr=sr,
                                              n_fft=n_fft,
                                              n_mels=n_mels)
        mel_spectrogram = np.dot(mel_filter_bank, spectrogram_magnitude)

        # Normalize the mel-spectrogram.
        mel_spectrogram = self.spec_normalizer.transform(mel_spectrogram)

        # Extract the center of audio that corresponds to mel spectrograms.
        audio = wav[fft_padding:-fft_padding]
        assert mel_spectrogram.shape[1] * hop_length == audio.size

        # there is no clipping here
        return audio, mel_spectrogram


def create_dataset(config, input_dir, output_dir, verbose=True):
    input_dir = Path(input_dir).expanduser()
    dataset = LJSpeechMetaData(input_dir)

    output_dir = Path(output_dir).expanduser()
    output_dir.mkdir(exist_ok=True)

    transform = Transform(
        config.sample_rate,
        config.n_fft,
        config.win_length,
        config.hop_length,
        config.n_mels)
    file_names = []

    for example in tqdm.tqdm(dataset):
        fname, _, _ = example
        base_name = os.path.splitext(os.path.basename(fname))[0]
        wav_dir = output_dir / "wav"
        mel_dir = output_dir / "mel"
        wav_dir.mkdir(exist_ok=True)
        mel_dir.mkdir(exist_ok=True)

        audio, mel = transform(example)
        np.save(str(wav_dir / base_name), audio)
        np.save(str(mel_dir / base_name), mel)

        file_names.append((base_name, mel.shape[-1], audio.shape[-1]))

    meta_data = pd.DataFrame.from_records(file_names)
    meta_data.to_csv(str(output_dir / "metadata.csv"), sep="\t", index=None, header=None)
    print("saved metadata into {}".format(os.path.join(output_dir, "metadata.csv")))

    print("Done!")

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="create dataset")
    parser.add_argument("--config", type=str, metavar="FILE", help="extra config to overwrite the default config")
    parser.add_argument("--input", type=str, help="path of the ljspeech dataset")
    parser.add_argument("--output", type=str, help="path to save the output dataset")
    parser.add_argument("--opts", nargs=argparse.REMAINDER,
        help="options to overwrite --config file and the default config, passing in KEY VALUE pairs"
    )
    parser.add_argument("-v", "--verbose", action="store_true", help="print messages")

    config = get_cfg_defaults()
    args = parser.parse_args()
    if args.config:
        config.merge_from_file(args.config)
    if args.opts:
        config.merge_from_list(args.opts)
    config.freeze()
    if args.verbose:
        print(config.data)
        print(args)

    create_dataset(config.data, args.input, args.output, args.verbose)
@@ -0,0 +1,51 @@
import argparse
import numpy as np
import soundfile as sf
import os
from pathlib import Path
import paddle
import parakeet
from parakeet.models.wavenet import UpsampleNet, WaveNet, ConditionalWaveNet
from parakeet.utils import layer_tools, checkpoint

from config import get_cfg_defaults

def main(config, args):
    paddle.set_device(args.device)
    model = ConditionalWaveNet.from_pretrained(config, args.checkpoint_path)
    layer_tools.recursively_remove_weight_norm(model)
    model.eval()

    mel_dir = Path(args.input).expanduser()
    output_dir = Path(args.output).expanduser()
    output_dir.mkdir(parents=True, exist_ok=True)
    for file_path in mel_dir.iterdir():
        mel = np.load(str(file_path))
        audio = model.predict(mel)
        audio_path = output_dir / (os.path.splitext(file_path.name)[0] + ".wav")
        sf.write(audio_path, audio, config.data.sample_rate)
        print("[synthesize] {} -> {}".format(file_path, audio_path))


if __name__ == "__main__":
    config = get_cfg_defaults()

    parser = argparse.ArgumentParser(description="synthesize audio from mel spectrograms with WaveNet.")
    parser.add_argument("--config", type=str, metavar="FILE", help="extra config to overwrite the default config")
    parser.add_argument("--checkpoint_path", type=str, help="path of the checkpoint to load.")
    parser.add_argument("--input", type=str, help="path of the directory containing mel spectrograms (in .npy format)")
    parser.add_argument("--output", type=str, help="path to save outputs")
    parser.add_argument("--device", type=str, default="cpu", help="device type to use.")
    parser.add_argument("--opts", nargs=argparse.REMAINDER, help="options to overwrite --config file and the default config, passing in KEY VALUE pairs")
    parser.add_argument("-v", "--verbose", action="store_true", help="print messages")

    args = parser.parse_args()
    if args.config:
        config.merge_from_file(args.config)
    if args.opts:
        config.merge_from_list(args.opts)
    config.freeze()
    print(config)
    print(args)

    main(config, args)
@@ -0,0 +1,157 @@
import time
from pathlib import Path
import math
import numpy as np
import paddle
from paddle import distributed as dist
from paddle.io import DataLoader, DistributedBatchSampler
from tensorboardX import SummaryWriter
from collections import defaultdict

import parakeet
from parakeet.data import dataset
from parakeet.models.wavenet import UpsampleNet, WaveNet, ConditionalWaveNet
from parakeet.audio import AudioProcessor
from parakeet.utils import scheduler, mp_tools
from parakeet.training.cli import default_argument_parser
from parakeet.training.experiment import ExperimentBase
from parakeet.utils.mp_tools import rank_zero_only

from config import get_cfg_defaults
from ljspeech import LJSpeech, LJSpeechClipCollector, LJSpeechCollector


class Experiment(ExperimentBase):
    def setup_model(self):
        config = self.config
        model = ConditionalWaveNet(
            upsample_factors=config.model.upsample_factors,
            n_stack=config.model.n_stack,
            n_loop=config.model.n_loop,
            residual_channels=config.model.residual_channels,
            output_dim=config.model.output_dim,
            n_mels=config.data.n_mels,
            filter_size=config.model.filter_size,
            loss_type=config.model.loss_type,
            log_scale_min=config.model.log_scale_min)

        if self.parallel:
            model = paddle.DataParallel(model)

        lr_scheduler = paddle.optimizer.lr.StepDecay(
            config.training.lr,
            config.training.anneal_interval,
            config.training.anneal_rate)
        optimizer = paddle.optimizer.Adam(
            lr_scheduler,
            parameters=model.parameters(),
            grad_clip=paddle.nn.ClipGradByGlobalNorm(config.training.gradient_max_norm))

        self.model = model
        self.model_core = model._layers if self.parallel else model
        self.optimizer = optimizer

    def setup_dataloader(self):
        config = self.config
        args = self.args

        ljspeech_dataset = LJSpeech(args.data)
        valid_set, train_set = dataset.split(ljspeech_dataset, config.data.valid_size)

        # convolutional net's causal padding size
        context_size = config.model.n_stack \
            * sum([(config.model.filter_size - 1) * 2**i for i in range(config.model.n_loop)]) \
            + 1
        context_frames = context_size // config.data.hop_length

        # frames used to compute loss
        frames_per_second = config.data.sample_rate // config.data.hop_length
        train_clip_frames = math.ceil(config.data.train_clip_seconds * frames_per_second)

        num_frames = train_clip_frames + context_frames
        batch_fn = LJSpeechClipCollector(num_frames, config.data.hop_length)
        if not self.parallel:
            train_loader = DataLoader(
                train_set,
                batch_size=config.data.batch_size,
                shuffle=True,
                drop_last=True,
                collate_fn=batch_fn)
        else:
            sampler = DistributedBatchSampler(
                train_set,
                batch_size=config.data.batch_size,
                shuffle=True,
                drop_last=True)
            train_loader = DataLoader(
                train_set, batch_sampler=sampler, collate_fn=batch_fn)

        valid_batch_fn = LJSpeechCollector()
        valid_loader = DataLoader(
            valid_set, batch_size=1, collate_fn=valid_batch_fn)

        self.train_loader = train_loader
        self.valid_loader = valid_loader

    def train_batch(self):
        start = time.time()
        batch = self.read_batch()
        data_loader_time = time.time() - start

        self.model.train()
        self.optimizer.clear_grad()
        mel, wav, audio_starts = batch

        y = self.model(wav, mel, audio_starts)
        loss = self.model.loss(y, wav)
        loss.backward()
        self.optimizer.step()
        iteration_time = time.time() - start

        loss_value = float(loss)
        msg = "Rank: {}, ".format(dist.get_rank())
        msg += "step: {}, ".format(self.iteration)
        msg += "time: {:>.3f}s/{:>.3f}s, ".format(data_loader_time, iteration_time)
        msg += "loss: {:>.6f}".format(loss_value)
        self.logger.info(msg)
        self.visualizer.add_scalar("train/loss", loss_value, global_step=self.iteration)

    @mp_tools.rank_zero_only
    @paddle.no_grad()
    def valid(self):
        valid_iterator = iter(self.valid_loader)
        valid_losses = []
        mel, wav, audio_starts = next(valid_iterator)
        y = self.model(wav, mel, audio_starts)
        loss = self.model.loss(y, wav)
        valid_losses.append(float(loss))
        valid_loss = np.mean(valid_losses)
        self.visualizer.add_scalar("valid/loss", valid_loss, global_step=self.iteration)


def main_sp(config, args):
    exp = Experiment(config, args)
    exp.setup()
    exp.run()


def main(config, args):
    if args.nprocs > 1 and args.device == "gpu":
        dist.spawn(main_sp, args=(config, args), nprocs=args.nprocs)
    else:
        main_sp(config, args)


if __name__ == "__main__":
    config = get_cfg_defaults()
    parser = default_argument_parser()
    args = parser.parse_args()
    if args.config:
        config.merge_from_file(args.config)
    if args.opts:
        config.merge_from_list(args.opts)
    config.freeze()
    print(config)
    print(args)

    main(config, args)
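The context_size expression in setup_dataloader is the causal receptive field of the stacked dilated convolutions. Worked out with the defaults (n_stack=3, n_loop=10, filter_size=2):

n_stack, n_loop, filter_size = 3, 10, 2

# each loop doubles the dilation: 1, 2, 4, ..., 2**(n_loop - 1)
per_stack = sum((filter_size - 1) * 2 ** i for i in range(n_loop))  # 1023
context_size = n_stack * per_stack + 1  # 3070
context_frames = context_size // 256  # 11 mel frames of left context
assert context_size == 3070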