Merge pull request #62 from lfchener/develop

add example for tacotron2
This commit is contained in:
Li Fuchen 2020-12-18 20:00:44 +08:00 committed by GitHub
commit cf43f2cf03
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 575 additions and 0 deletions

View File

@ -0,0 +1,70 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from yacs.config import CfgNode as CN
_C = CN()
_C.data = CN(
dict(
batch_size=32, # batch size
valid_size=64, # the first N examples are reserved for validation
sample_rate=22050, # Hz, sample rate
n_fft=1024, # fft frame size
win_length=1024, # window size
hop_length=256, # hop size between ajacent frame
f_max=8000, # Hz, max frequency when converting to mel
f_min=0, # Hz, min frequency when converting to mel
d_mels=80, # mel bands
padding_idx=0, # text embedding's padding index
))
_C.model = CN(
dict(
reduction_factor=1, # reduction factor
d_encoder=512, # embedding & encoder's internal size
encoder_conv_layers=3, # number of conv layer in tacotron2 encoder
encoder_kernel_size=5, # kernel size of conv layers in tacotron2 encoder
d_prenet=256, # hidden size of decoder prenet
d_attention_rnn=1024, # hidden size of the first rnn layer in tacotron2 decoder
d_decoder_rnn=1024, #hidden size of the second rnn layer in tacotron2 decoder
d_attention=128, # hidden size of decoder location linear layer
attention_filters=32, # number of filter in decoder location conv layer
attention_kernel_size=31, # kernel size of decoder location conv layer
d_postnet=512, # hidden size of decoder postnet
postnet_kernel_size=5, # kernel size of conv layers in postnet
postnet_conv_layers=5, # number of conv layer in decoder postnet
p_encoder_dropout=0.5, # droput probability in encoder
p_prenet_dropout=0.5, # droput probability in decoder prenet
p_attention_dropout=0.1, # droput probability of first rnn layer in decoder
p_decoder_dropout=0.1, # droput probability of second rnn layer in decoder
p_postnet_dropout=0.5, #droput probability in decoder postnet
))
_C.training = CN(
dict(
lr=1e-3, # learning rate
weight_decay=1e-6, # the coeff of weight decay
grad_clip_thresh=1.0, # the clip norm of grad clip.
plot_interval=1000, # plot attention and spectrogram
valid_interval=1000, # validation
save_interval=1000, # checkpoint
max_iteration=500000, # max iteration to train
))
def get_cfg_defaults():
"""Get a yacs CfgNode object with default values for my_project."""
# Return a clone so that the defaults will not be altered
# This is for the "local variable" use pattern
return _C.clone()

View File

@ -0,0 +1,106 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from pathlib import Path
import pickle
import numpy as np
from paddle.io import Dataset, DataLoader
from parakeet.data.batch import batch_spec, batch_text_id
from parakeet.data import dataset
class LJSpeech(Dataset):
"""A simple dataset adaptor for the processed ljspeech dataset."""
def __init__(self, root):
self.root = Path(root).expanduser()
records = []
with open(self.root / "metadata.pkl", 'rb') as f:
metadata = pickle.load(f)
for mel_name, text, ids in metadata:
mel_name = self.root / "mel" / (mel_name + ".npy")
records.append((mel_name, text, ids))
self.records = records
def __getitem__(self, i):
mel_name, _, ids = self.records[i]
mel = np.load(mel_name)
return ids, mel
def __len__(self):
return len(self.records)
class LJSpeechCollector(object):
"""A simple callable to batch LJSpeech examples."""
def __init__(self, padding_idx=0, padding_value=0.,
padding_stop_token=1.0):
self.padding_idx = padding_idx
self.padding_value = padding_value
self.padding_stop_token = padding_stop_token
def __call__(self, examples):
texts = []
mels = []
text_lens = []
mel_lens = []
stop_tokens = []
for data in examples:
text, mel = data
text = np.array(text, dtype=np.int64)
text_lens.append(len(text))
mels.append(mel)
texts.append(text)
mel_lens.append(mel.shape[1])
stop_token = np.zeros([mel.shape[1] - 1], dtype=np.float32)
stop_tokens.append(np.append(stop_token, 1.0))
# Sort by text_len in descending order
texts = [
i
for i, _ in sorted(
zip(texts, text_lens), key=lambda x: x[1], reverse=True)
]
mels = [
i
for i, _ in sorted(
zip(mels, text_lens), key=lambda x: x[1], reverse=True)
]
mel_lens = [
i
for i, _ in sorted(
zip(mel_lens, text_lens), key=lambda x: x[1], reverse=True)
]
stop_tokens = [
i
for i, _ in sorted(
zip(stop_tokens, text_lens), key=lambda x: x[1], reverse=True)
]
text_lens = sorted(text_lens, reverse=True)
# Pad sequence with largest len of the batch
texts = batch_text_id(texts, pad_id=self.padding_idx)
mels = np.transpose(
batch_spec(
mels, pad_value=self.padding_value), axes=(0, 2, 1))
stop_tokens = batch_text_id(
stop_tokens, pad_id=self.padding_stop_token, dtype=mels[0].dtype)
return (texts, mels, text_lens, mel_lens, stop_tokens)

View File

@ -0,0 +1,99 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import tqdm
import pickle
import argparse
import numpy as np
from pathlib import Path
from parakeet.datasets import LJSpeechMetaData
from parakeet.audio import AudioProcessor, LogMagnitude
from parakeet.frontend import EnglishCharacter
from config import get_cfg_defaults
def create_dataset(config, source_path, target_path, verbose=False):
# create output dir
target_path = Path(target_path).expanduser()
mel_path = target_path / "mel"
os.makedirs(mel_path, exist_ok=True)
meta_data = LJSpeechMetaData(source_path)
frontend = EnglishCharacter()
processor = AudioProcessor(
sample_rate=config.data.sample_rate,
n_fft=config.data.n_fft,
n_mels=config.data.d_mels,
win_length=config.data.win_length,
hop_length=config.data.hop_length,
f_max=config.data.f_max,
f_min=config.data.f_min)
normalizer = LogMagnitude()
records = []
for (fname, text, _) in tqdm.tqdm(meta_data):
wav = processor.read_wav(fname)
mel = processor.mel_spectrogram(wav)
mel = normalizer.transform(mel)
ids = frontend(text)
mel_name = os.path.splitext(os.path.basename(fname))[0]
# save mel spectrogram
records.append((mel_name, text, ids))
np.save(mel_path / mel_name, mel)
if verbose:
print("save mel spectrograms into {}".format(mel_path))
# save meta data as pickle archive
with open(target_path / "metadata.pkl", 'wb') as f:
pickle.dump(records, f)
if verbose:
print("saved metadata into {}".format(target_path /
"metadata.pkl"))
print("Done.")
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="create dataset")
parser.add_argument(
"--config",
type=str,
metavar="FILE",
help="extra config to overwrite the default config")
parser.add_argument(
"--input", type=str, help="path of the ljspeech dataset")
parser.add_argument(
"--output", type=str, help="path to save output dataset")
parser.add_argument(
"--opts",
nargs=argparse.REMAINDER,
help="options to overwrite --config file and the default config, passing in KEY VALUE pairs"
)
parser.add_argument(
"-v", "--verbose", action="store_true", help="print msg")
config = get_cfg_defaults()
args = parser.parse_args()
if args.config:
config.merge_from_file(args.config)
if args.opts:
config.merge_from_list(args.opts)
config.freeze()
print(config.data)
create_dataset(config, args.input, args.output, args.verbose)

View File

@ -0,0 +1,89 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
from pathlib import Path
import numpy as np
import paddle
import parakeet
from parakeet.frontend import EnglishCharacter
from parakeet.models.tacotron2 import Tacotron2
from config import get_cfg_defaults
def main(config, args):
paddle.set_device(args.device)
# model
frontend = EnglishCharacter()
model = Tacotron2.from_pretrained(frontend, config, args.checkpoint_path)
model.eval()
# inputs
input_path = Path(args.input).expanduser()
with open(input_path, "rt") as f:
sentences = f.readlines()
if args.output is None:
output_dir = input_path.parent / "synthesis"
else:
output_dir = Path(args.output).expanduser()
output_dir.mkdir(exist_ok=True)
for i, sentence in enumerate(sentences):
mel_output, _ = model.predict(sentence)
mel_output = mel_output.T
np.save(str(output_dir / f"sentence_{i}"), mel_output)
if args.verbose:
print("spectrogram saved at {}".format(output_dir /
f"sentence_{i}.npy"))
if __name__ == "__main__":
config = get_cfg_defaults()
parser = argparse.ArgumentParser(
description="generate mel spectrogram with TransformerTTS.")
parser.add_argument(
"--config",
type=str,
metavar="FILE",
help="extra config to overwrite the default config")
parser.add_argument(
"--checkpoint_path", type=str, help="path of the checkpoint to load.")
parser.add_argument("--input", type=str, help="path of the text sentences")
parser.add_argument("--output", type=str, help="path to save outputs")
parser.add_argument(
"--device", type=str, default="cpu", help="device type to use.")
parser.add_argument(
"--opts",
nargs=argparse.REMAINDER,
help="options to overwrite --config file and the default config, passing in KEY VALUE pairs"
)
parser.add_argument(
"-v", "--verbose", action="store_true", help="print msg")
args = parser.parse_args()
if args.config:
config.merge_from_file(args.config)
if args.opts:
config.merge_from_list(args.opts)
config.freeze()
print(config)
print(args)
main(config, args)

211
examples/tacotron2/train.py Normal file
View File

@ -0,0 +1,211 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import time
from collections import defaultdict
import numpy as np
import paddle
from paddle import distributed as dist
from paddle.io import DataLoader, DistributedBatchSampler
import parakeet
from parakeet.data import dataset
from parakeet.frontend import EnglishCharacter
from parakeet.training.cli import default_argument_parser
from parakeet.training.experiment import ExperimentBase
from parakeet.utils import display, mp_tools
from parakeet.models.tacotron2 import Tacotron2, Tacotron2Loss
from config import get_cfg_defaults
from ljspeech import LJSpeech, LJSpeechCollector
class Experiment(ExperimentBase):
def compute_losses(self, inputs, outputs):
_, mel_targets, _, _, stop_tokens = inputs
mel_outputs = outputs["mel_output"]
mel_outputs_postnet = outputs["mel_outputs_postnet"]
stop_logits = outputs["stop_logits"]
losses = self.criterion(mel_outputs, mel_outputs_postnet, stop_logits,
mel_targets, stop_tokens)
return losses
def train_batch(self):
start = time.time()
batch = self.read_batch()
data_loader_time = time.time() - start
self.optimizer.clear_grad()
self.model.train()
texts, mels, text_lens, output_lens, stop_tokens = batch
outputs = self.model(texts, mels, text_lens, output_lens)
losses = self.compute_losses(batch, outputs)
loss = losses["loss"]
loss.backward()
self.optimizer.step()
iteration_time = time.time() - start
losses_np = {k: float(v) for k, v in losses.items()}
# logging
msg = "Rank: {}, ".format(dist.get_rank())
msg += "step: {}, ".format(self.iteration)
msg += "time: {:>.3f}s/{:>.3f}s, ".format(data_loader_time,
iteration_time)
msg += ', '.join('{}: {:>.6f}'.format(k, v)
for k, v in losses_np.items())
self.logger.info(msg)
if dist.get_rank() == 0:
for k, v in losses_np.items():
self.visualizer.add_scalar(f"train_loss/{k}", v,
self.iteration)
@mp_tools.rank_zero_only
@paddle.no_grad()
def valid(self):
valid_losses = defaultdict(list)
for i, batch in enumerate(self.valid_loader):
texts, mels, text_lens, output_lens, stop_tokens = batch
outputs = self.model(texts, mels, text_lens, output_lens)
losses = self.compute_losses(batch, outputs)
for k, v in losses.items():
valid_losses[k].append(float(v))
attention_weights = outputs["alignments"]
display.add_attention_plots(self.visualizer,
f"valid_sentence_{i}_alignments",
attention_weights[0], self.iteration)
display.add_spectrogram_plots(
self.visualizer, f"valid_sentence_{i}_target_spectrogram",
mels[0], self.iteration)
display.add_spectrogram_plots(
self.visualizer, f"valid_sentence_{i}_predicted_spectrogram",
outputs['mel_outputs_postnet'][0], self.iteration)
# write visual log
valid_losses = {k: np.mean(v) for k, v in valid_losses.items()}
# logging
msg = "Valid: "
msg += "step: {}, ".format(self.iteration)
msg += ', '.join('{}: {:>.6f}'.format(k, v)
for k, v in valid_losses.items())
self.logger.info(msg)
for k, v in valid_losses.items():
self.visualizer.add_scalar(f"valid/{k}", v, self.iteration)
def setup_model(self):
config = self.config
frontend = EnglishCharacter()
model = Tacotron2(
frontend,
d_mels=config.data.d_mels,
d_encoder=config.model.d_encoder,
encoder_conv_layers=config.model.encoder_conv_layers,
encoder_kernel_size=config.model.encoder_kernel_size,
d_prenet=config.model.d_prenet,
d_attention_rnn=config.model.d_attention_rnn,
d_decoder_rnn=config.model.d_decoder_rnn,
attention_filters=config.model.attention_filters,
attention_kernel_size=config.model.attention_kernel_size,
d_attention=config.model.d_attention,
d_postnet=config.model.d_postnet,
postnet_kernel_size=config.model.postnet_kernel_size,
postnet_conv_layers=config.model.postnet_conv_layers,
reduction_factor=config.model.reduction_factor,
p_encoder_dropout=config.model.p_encoder_dropout,
p_prenet_dropout=config.model.p_prenet_dropout,
p_attention_dropout=config.model.p_attention_dropout,
p_decoder_dropout=config.model.p_decoder_dropout,
p_postnet_dropout=config.model.p_postnet_dropout)
if self.parallel:
model = paddle.DataParallel(model)
grad_clip = paddle.nn.ClipGradByGlobalNorm(
config.training.grad_clip_thresh)
optimizer = paddle.optimizer.Adam(
learning_rate=config.training.lr,
parameters=model.parameters(),
weight_decay=paddle.regularizer.L2Decay(
config.training.weight_decay),
grad_clip=grad_clip)
criterion = Tacotron2Loss()
self.model = model
self.optimizer = optimizer
self.criterion = criterion
def setup_dataloader(self):
args = self.args
config = self.config
ljspeech_dataset = LJSpeech(args.data)
valid_set, train_set = dataset.split(ljspeech_dataset,
config.data.valid_size)
batch_fn = LJSpeechCollector(padding_idx=config.data.padding_idx)
if not self.parallel:
self.train_loader = DataLoader(
train_set,
batch_size=config.data.batch_size,
shuffle=True,
drop_last=True,
collate_fn=batch_fn)
else:
sampler = DistributedBatchSampler(
train_set,
batch_size=config.data.batch_size,
shuffle=True,
drop_last=True)
self.train_loader = DataLoader(
train_set, batch_sampler=sampler, collate_fn=batch_fn)
self.valid_loader = DataLoader(
valid_set,
batch_size=config.data.batch_size,
shuffle=False,
drop_last=False,
collate_fn=batch_fn)
def main_sp(config, args):
exp = Experiment(config, args)
exp.setup()
exp.run()
def main(config, args):
if args.nprocs > 1 and args.device == "gpu":
dist.spawn(main_sp, args=(config, args), nprocs=args.nprocs)
else:
main_sp(config, args)
if __name__ == "__main__":
config = get_cfg_defaults()
parser = default_argument_parser()
args = parser.parse_args()
if args.config:
config.merge_from_file(args.config)
if args.opts:
config.merge_from_list(args.opts)
config.freeze()
print(config)
print(args)
main(config, args)