add fastspeech2 example

This commit is contained in:
TianYuan 2021-07-19 06:31:52 +00:00
parent 6553d1d723
commit 474bc4c06a
6 changed files with 565 additions and 69 deletions

View File: batch_fn.py

@ -0,0 +1,47 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np

from parakeet.data.batch import batch_sequences


def collate_baker_examples(examples):
    # fields = ["text", "text_lengths", "speech", "speech_lengths", "durations", "pitch", "energy"]
    text = [np.array(item["text"], dtype=np.int64) for item in examples]
    speech = [np.array(item["speech"], dtype=np.float32) for item in examples]
    pitch = [np.array(item["pitch"], dtype=np.float32) for item in examples]
    energy = [np.array(item["energy"], dtype=np.float32) for item in examples]
    durations = [
        np.array(item["durations"], dtype=np.int64) for item in examples
    ]
    text_lengths = np.array([item["text_lengths"] for item in examples])
    speech_lengths = np.array([item["speech_lengths"] for item in examples])

    # pad every variable-length field to the longest sequence in the batch
    text = batch_sequences(text)
    pitch = batch_sequences(pitch)
    speech = batch_sequences(speech)
    durations = batch_sequences(durations)
    energy = batch_sequences(energy)

    batch = {
        "text": text,
        "text_lengths": text_lengths,
        "durations": durations,
        "speech": speech,
        "speech_lengths": speech_lengths,
        "pitch": pitch,
        "energy": energy
    }
    return batch
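
A quick sanity check for the collate function above (a usage sketch with toy stand-in arrays; it assumes batch_sequences zero-pads each field to the longest sequence in the batch):

# toy stand-ins for two Baker examples of different lengths
examples = [
    {"text": [5, 3, 9], "text_lengths": 3,
     "speech": np.zeros((12, 80), dtype=np.float32), "speech_lengths": 12,
     "durations": [4, 5, 3],
     "pitch": np.zeros((4, 1), dtype=np.float32),
     "energy": np.zeros((4, 1), dtype=np.float32)},
    {"text": [7, 2], "text_lengths": 2,
     "speech": np.zeros((7, 80), dtype=np.float32), "speech_lengths": 7,
     "durations": [3, 4],
     "pitch": np.zeros((3, 1), dtype=np.float32),
     "energy": np.zeros((3, 1), dtype=np.float32)},
]
batch = collate_baker_examples(examples)
print(batch["text"].shape)    # (2, 3), padded to the longest text
print(batch["speech"].shape)  # (2, 12, 80), padded to the longest spectrogram
print(batch["text_lengths"])  # [3 2], true lengths kept for masking later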

View File: conf/default.yaml

@ -0,0 +1,104 @@
###########################################################
#                FEATURE EXTRACTION SETTING               #
###########################################################
fs: 24000            # Sampling rate (Hz).
n_fft: 2048          # FFT size.
n_shift: 300         # Hop size.
win_length: 1200     # Window length. If set to null, it will be the same as n_fft.
window: "hann"       # Window function.

# Only used for feats_type != raw
fmin: 80             # Minimum frequency of Mel basis.
fmax: 7600           # Maximum frequency of Mel basis.
n_mels: 80           # The number of mel basis.

# Only used for models that use pitch features (e.g. FastSpeech2)
f0min: 80            # Minimum f0 for pitch extraction.
f0max: 400           # Maximum f0 for pitch extraction.

###########################################################
#                       DATA SETTING                      #
###########################################################
batch_size: 64
num_workers: 4

###########################################################
#                       MODEL SETTING                     #
###########################################################
model:
    adim: 384                         # attention dimension
    aheads: 2                         # number of attention heads
    elayers: 4                        # number of encoder layers
    eunits: 1536                      # number of encoder ff units
    dlayers: 4                        # number of decoder layers
    dunits: 1536                      # number of decoder ff units
    positionwise_layer_type: conv1d   # type of position-wise layer
    positionwise_conv_kernel_size: 3  # kernel size of position-wise conv layer
    duration_predictor_layers: 2      # number of layers of duration predictor
    duration_predictor_chans: 256     # number of channels of duration predictor
    duration_predictor_kernel_size: 3 # filter size of duration predictor
    postnet_layers: 5                 # number of layers of postnet
    postnet_filts: 5                  # filter size of conv layers in postnet
    postnet_chans: 256                # number of channels of conv layers in postnet
    use_masking: True                 # whether to apply masking for padded part in loss calculation
    use_scaled_pos_enc: True          # whether to use scaled positional encoding
    encoder_normalize_before: True    # whether to perform layer normalization before encoder blocks
    decoder_normalize_before: True    # whether to perform layer normalization before decoder blocks
    reduction_factor: 1               # reduction factor
    init_type: xavier_uniform         # initialization type
    init_enc_alpha: 1.0               # initial value of alpha of encoder scaled position encoding
    init_dec_alpha: 1.0               # initial value of alpha of decoder scaled position encoding
    transformer_enc_dropout_rate: 0.2            # dropout rate for transformer encoder layer
    transformer_enc_positional_dropout_rate: 0.2 # dropout rate for transformer encoder positional encoding
    transformer_enc_attn_dropout_rate: 0.2       # dropout rate for transformer encoder attention layer
    transformer_dec_dropout_rate: 0.2            # dropout rate for transformer decoder layer
    transformer_dec_positional_dropout_rate: 0.2 # dropout rate for transformer decoder positional encoding
    transformer_dec_attn_dropout_rate: 0.2       # dropout rate for transformer decoder attention layer
    pitch_predictor_layers: 5         # number of conv layers in pitch predictor
    pitch_predictor_chans: 256        # number of channels of conv layers in pitch predictor
    pitch_predictor_kernel_size: 5    # kernel size of conv layers in pitch predictor
    pitch_predictor_dropout: 0.5      # dropout rate in pitch predictor
    pitch_embed_kernel_size: 1        # kernel size of conv embedding layer for pitch
    pitch_embed_dropout: 0.0          # dropout rate after conv embedding layer for pitch
    stop_gradient_from_pitch_predictor: true   # whether to stop the gradient from pitch predictor to encoder
    energy_predictor_layers: 2        # number of conv layers in energy predictor
    energy_predictor_chans: 256       # number of channels of conv layers in energy predictor
    energy_predictor_kernel_size: 3   # kernel size of conv layers in energy predictor
    energy_predictor_dropout: 0.5     # dropout rate in energy predictor
    energy_embed_kernel_size: 1       # kernel size of conv embedding layer for energy
    energy_embed_dropout: 0.0         # dropout rate after conv embedding layer for energy
    stop_gradient_from_energy_predictor: false # whether to stop the gradient from energy predictor to encoder

###########################################################
#                      UPDATER SETTING                    #
###########################################################
updater:
    use_masking: True    # whether to apply masking for padded part in loss calculation

###########################################################
#                     OPTIMIZER SETTING                   #
###########################################################
optimizer:
    optim: adam          # optimizer type
    learning_rate: 0.001 # learning rate

###########################################################
#                     TRAINING SETTING                    #
###########################################################
max_epoch: 1000
num_snapshots: 5

###########################################################
#                       OTHER SETTING                     #
###########################################################
seed: 10086
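
The STFT settings above pin down the frame timing: at fs = 24000, a hop of 300 samples is 12.5 ms and a window of 1200 samples is 50 ms. A small Python sketch to verify the arithmetic:

fs, n_shift, win_length = 24000, 300, 1200
print(n_shift / fs * 1000)     # 12.5  -> hop size in ms
print(win_length / fs * 1000)  # 50.0  -> window length in ms
print(2 * fs // n_shift)       # 160   -> mel frames in a 2-second clip, roughly, ignoring padding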

View File: config.py

@ -0,0 +1,28 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import yaml
from yacs.config import CfgNode as Configuration

# load the default configuration once at import time
with open("conf/default.yaml", 'rt') as f:
    _C = yaml.safe_load(f)
_C = Configuration(_C)


def get_cfg_default():
    # hand out a copy so callers can merge overrides
    # without mutating the shared default
    config = _C.clone()
    return config
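
A usage sketch for this module (the override file name is hypothetical; merge_from_file is yacs' standard override mechanism, which train.py below wires to --config):

config = get_cfg_default()
print(config.batch_size)   # 64, from conf/default.yaml
print(config.model.adim)   # nested sections are attribute-accessible
# overriding from another yaml, e.g. one containing "batch_size: 32":
config.merge_from_file("conf/override.yaml")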

View File: fastspeech2_updater.py

@ -0,0 +1,120 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from parakeet.models.fastspeech2_new import FastSpeech2Loss
from parakeet.training.extensions.evaluator import StandardEvaluator
from parakeet.training.reporter import report
from parakeet.training.updaters.standard_updater import StandardUpdater


class FastSpeech2Updater(StandardUpdater):
    def __init__(self,
                 model,
                 optimizer,
                 dataloader,
                 init_state=None,
                 use_masking=False,
                 use_weighted_masking=False):
        # forward the caller's initial state instead of discarding it
        super().__init__(model, optimizer, dataloader, init_state=init_state)
        self.use_masking = use_masking
        self.use_weighted_masking = use_weighted_masking

    def update_core(self, batch):
        before_outs, after_outs, d_outs, p_outs, e_outs, ys, ilens, olens = self.model(
            text=batch["text"],
            text_lengths=batch["text_lengths"],
            speech=batch["speech"],
            speech_lengths=batch["speech_lengths"],
            durations=batch["durations"],
            pitch=batch["pitch"],
            energy=batch["energy"])

        criterion = FastSpeech2Loss(
            use_masking=self.use_masking,
            use_weighted_masking=self.use_weighted_masking)
        l1_loss, duration_loss, pitch_loss, energy_loss = criterion(
            after_outs=after_outs,
            before_outs=before_outs,
            d_outs=d_outs,
            p_outs=p_outs,
            e_outs=e_outs,
            ys=ys,
            ds=batch["durations"],
            ps=batch["pitch"],
            es=batch["energy"],
            ilens=ilens,
            olens=olens)
        loss = l1_loss + duration_loss + pitch_loss + energy_loss

        optimizer = self.optimizer
        optimizer.clear_grad()
        loss.backward()
        optimizer.step()

        report("train/loss", float(loss))
        report("train/l1_loss", float(l1_loss))
        report("train/duration_loss", float(duration_loss))
        report("train/pitch_loss", float(pitch_loss))
        report("train/energy_loss", float(energy_loss))


class FastSpeech2Evaluator(StandardEvaluator):
    def __init__(self,
                 model,
                 dataloader,
                 use_masking=False,
                 use_weighted_masking=False):
        super().__init__(model, dataloader)
        self.use_masking = use_masking
        self.use_weighted_masking = use_weighted_masking

    def evaluate_core(self, batch):
        # same forward pass and loss as training, but without the optimizer step
        before_outs, after_outs, d_outs, p_outs, e_outs, ys, ilens, olens = self.model(
            text=batch["text"],
            text_lengths=batch["text_lengths"],
            speech=batch["speech"],
            speech_lengths=batch["speech_lengths"],
            durations=batch["durations"],
            pitch=batch["pitch"],
            energy=batch["energy"])

        criterion = FastSpeech2Loss(
            use_masking=self.use_masking,
            use_weighted_masking=self.use_weighted_masking)
        l1_loss, duration_loss, pitch_loss, energy_loss = criterion(
            after_outs=after_outs,
            before_outs=before_outs,
            d_outs=d_outs,
            p_outs=p_outs,
            e_outs=e_outs,
            ys=ys,
            ds=batch["durations"],
            ps=batch["pitch"],
            es=batch["energy"],
            ilens=ilens,
            olens=olens)
        loss = l1_loss + duration_loss + pitch_loss + energy_loss

        report("eval/loss", float(loss))
        report("eval/l1_loss", float(l1_loss))
        report("eval/duration_loss", float(duration_loss))
        report("eval/pitch_loss", float(pitch_loss))
        report("eval/energy_loss", float(energy_loss))

View File: train.py

@ -0,0 +1,228 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
import os
from pathlib import Path

import jsonlines
import numpy as np
import paddle
import yaml
from paddle import DataParallel
from paddle import distributed as dist
from paddle import nn
from paddle.io import DataLoader, DistributedBatchSampler
from visualdl import LogWriter

from parakeet.datasets.data_table import DataTable
from parakeet.models.fastspeech2_new import FastSpeech2
from parakeet.training.extensions.snapshot import Snapshot
from parakeet.training.extensions.visualizer import VisualDL
from parakeet.training.seeding import seed_everything
from parakeet.training.trainer import Trainer

from batch_fn import collate_baker_examples
from config import get_cfg_default
from fastspeech2_updater import FastSpeech2Evaluator, FastSpeech2Updater
optim_classes = dict(
    adadelta=paddle.optimizer.Adadelta,
    adagrad=paddle.optimizer.Adagrad,
    adam=paddle.optimizer.Adam,
    adamax=paddle.optimizer.Adamax,
    adamw=paddle.optimizer.AdamW,
    lamb=paddle.optimizer.Lamb,
    momentum=paddle.optimizer.Momentum,
    rmsprop=paddle.optimizer.RMSProp,
    sgd=paddle.optimizer.SGD)


def build_optimizers(model: nn.Layer,
                     optim='adadelta',
                     learning_rate=0.01) -> paddle.optimizer.Optimizer:
    optim_class = optim_classes.get(optim)
    if optim_class is None:
        raise ValueError(f"optim must be one of {list(optim_classes)}: {optim}")
    optimizer = optim_class(
        parameters=model.parameters(), learning_rate=learning_rate)
    return optimizer
def train_sp(args, config):
    # decide device type and whether to run in parallel,
    # and set up the running environment accordingly
    if not paddle.is_compiled_with_cuda():
        paddle.set_device("cpu")
    else:
        paddle.set_device("gpu")
    world_size = paddle.distributed.get_world_size()
    if world_size > 1:
        paddle.distributed.init_parallel_env()

    # set the random seed, it is a must for multiprocess training
    seed_everything(config.seed)
    print(
        f"rank: {dist.get_rank()}, pid: {os.getpid()}, parent_pid: {os.getppid()}"
    )
    # the DataLoader logger is too verbose
    logging.getLogger("DataLoader").disabled = True

    # construct datasets for training and validation
    with jsonlines.open(args.train_metadata, 'r') as reader:
        train_metadata = list(reader)
    train_dataset = DataTable(
        data=train_metadata,
        fields=[
            "text", "text_lengths", "speech", "speech_lengths", "durations",
            "pitch", "energy"
        ],
        converters={
            "speech": np.load,
            "pitch": np.load,
            "energy": np.load,
        })
    with jsonlines.open(args.dev_metadata, 'r') as reader:
        dev_metadata = list(reader)
    dev_dataset = DataTable(
        data=dev_metadata,
        fields=[
            "text", "text_lengths", "speech", "speech_lengths", "durations",
            "pitch", "energy"
        ],
        converters={
            "speech": np.load,
            "pitch": np.load,
            "energy": np.load,
        })

    # collate function and dataloaders
    train_sampler = DistributedBatchSampler(
        train_dataset,
        batch_size=config.batch_size,
        shuffle=False,
        drop_last=True)
    print("samplers done!")

    train_dataloader = DataLoader(
        train_dataset,
        batch_sampler=train_sampler,
        collate_fn=collate_baker_examples,
        num_workers=config.num_workers)
    dev_dataloader = DataLoader(
        dev_dataset,
        shuffle=False,
        drop_last=False,
        batch_size=config.batch_size,
        collate_fn=collate_baker_examples,
        num_workers=config.num_workers)
    print("dataloaders done!")

    vocab_size = 202  # TODO: derive the vocab size from the data instead of hard-coding it
    odim = config.n_mels
    model = FastSpeech2(idim=vocab_size, odim=odim, **config["model"])
    if world_size > 1:
        model = DataParallel(model)
    print("model done!")

    optimizer = build_optimizers(model, **config["optimizer"])
    print("optimizer done!")

    updater = FastSpeech2Updater(
        model=model,
        optimizer=optimizer,
        dataloader=train_dataloader,
        **config["updater"])
    output_dir = Path(args.output_dir)
    trainer = Trainer(updater, (config.max_epoch, 'epoch'), output_dir)

    evaluator = FastSpeech2Evaluator(model, dev_dataloader,
                                     **config["updater"])
    if dist.get_rank() == 0:
        trainer.extend(evaluator, trigger=(1, "epoch"))
        writer = LogWriter(str(output_dir))
        trainer.extend(VisualDL(writer), trigger=(1, "iteration"))
        trainer.extend(
            Snapshot(max_size=config.num_snapshots), trigger=(1, 'epoch'))
    print(trainer.extensions)
    trainer.run()
def main():
    # parse args and config and redirect to train_sp
    parser = argparse.ArgumentParser(
        description="Train a FastSpeech2 model with the Baker Mandarin TTS dataset.")
    parser.add_argument(
        "--config", type=str, help="config file to overwrite default config")
    parser.add_argument("--train-metadata", type=str, help="training data")
    parser.add_argument("--dev-metadata", type=str, help="dev data")
    parser.add_argument("--output-dir", type=str, help="output dir")
    parser.add_argument(
        "--device", type=str, default="gpu", help="device type to use")
    parser.add_argument(
        "--nprocs", type=int, default=1, help="number of processes")
    parser.add_argument("--verbose", type=int, default=1, help="verbose")

    args = parser.parse_args()
    if args.device == "cpu" and args.nprocs > 1:
        raise RuntimeError("Multiprocess training on CPU is not supported.")
    config = get_cfg_default()
    if args.config:
        config.merge_from_file(args.config)

    print("========Args========")
    print(yaml.safe_dump(vars(args)))
    print("========Config========")
    print(config)
    print(
        f"master sees the world size: {dist.get_world_size()}, from pid: {os.getpid()}"
    )

    # dispatch
    if args.nprocs > 1:
        dist.spawn(train_sp, (args, config), nprocs=args.nprocs)
    else:
        train_sp(args, config)


if __name__ == "__main__":
    main()
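
A typical launch of this script (a hypothetical invocation; all paths are placeholders, and the metadata files are the jsonl outputs of the preprocessing stage):

python3 train.py \
    --config=conf/default.yaml \
    --train-metadata=dump/train/metadata.jsonl \
    --dev-metadata=dump/dev/metadata.jsonl \
    --output-dir=exp/default \
    --nprocs=1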

View File: parakeet/models/fastspeech2_new.py

@ -126,6 +126,9 @@ class FastSpeech2(nn.Layer):
# use idx 0 as padding idx
self.padding_idx = 0
# initialize parameters
initialize(self, init_type)
# get positional encoding class
pos_enc_class = (ScaledPositionalEncoding
if self.use_scaled_pos_enc else PositionalEncoding)
@ -135,7 +138,7 @@ class FastSpeech2(nn.Layer):
num_embeddings=idim,
embedding_dim=adim,
padding_idx=self.padding_idx)
print("encoder_type:", encoder_type)
if encoder_type == "transformer":
self.encoder = TransformerEncoder(
idim=idim,
@ -233,7 +236,6 @@ class FastSpeech2(nn.Layer):
use_batch_norm=use_batch_norm,
dropout_rate=postnet_dropout_rate, ))
# initialize parameters
self._reset_parameters(
init_type=init_type,
init_enc_alpha=init_enc_alpha,
@ -250,11 +252,8 @@ class FastSpeech2(nn.Layer):
speech: paddle.Tensor,
speech_lengths: paddle.Tensor,
durations: paddle.Tensor,
durations_lengths: paddle.Tensor,
pitch: paddle.Tensor,
pitch_lengths: paddle.Tensor,
energy: paddle.Tensor,
energy_lengths: paddle.Tensor, ) -> Tuple[paddle.Tensor, Dict[
energy: paddle.Tensor, ) -> Tuple[paddle.Tensor, Dict[
str, paddle.Tensor], paddle.Tensor]:
"""Calculate forward propagation.
@ -270,33 +269,33 @@ class FastSpeech2(nn.Layer):
Batch of the lengths of each target (B,).
durations : LongTensor
Batch of padded durations (B, Tmax + 1).
durations_lengths : LongTensor
Batch of duration lengths (B, Tmax + 1).
pitch : Tensor
Batch of padded token-averaged pitch (B, Tmax + 1, 1).
pitch_lengths : LongTensor
Batch of pitch lengths (B, Tmax + 1).
energy : Tensor
Batch of padded token-averaged energy (B, Tmax + 1, 1).
energy_lengths : LongTensor
Batch of energy lengths (B, Tmax + 1).
Returns
----------
Tensor
    Loss scalar value.
Dict
    Statistics to be monitored.
Tensor
    mel outs before postnet
Tensor
    mel outs after postnet
Tensor
    duration predictor's output
Tensor
    pitch predictor's output
Tensor
    energy predictor's output
Tensor
    speech
Tensor
    real text_lengths
Tensor
    speech_lengths, modified if reduction_factor > 1
"""
text = text[:, :text_lengths.max()] # for data-parallel
speech = speech[:, :speech_lengths.max()] # for data-parallel
durations = durations[:, :durations_lengths.max()] # for data-parallel
pitch = pitch[:, :pitch_lengths.max()] # for data-parallel
energy = energy[:, :energy_lengths.max()] # for data-parallel
batch_size = text.shape[0]
# Add eos at the last of sequence
# xs = F.pad(text, [0, 1], "constant", self.padding_idx)
xs = np.pad(text.numpy(),
pad_width=((0, 0), (0, 1)),
mode="constant",
@ -319,39 +318,8 @@ class FastSpeech2(nn.Layer):
])
max_olen = max(olens)
ys = ys[:, :max_olen]
# calculate loss
if self.postnet is None:
after_outs = None
# calculate loss
l1_loss, duration_loss, pitch_loss, energy_loss = self.criterion(
after_outs=after_outs,
before_outs=before_outs,
d_outs=d_outs,
p_outs=p_outs,
e_outs=e_outs,
ys=ys,
ds=ds,
ps=ps,
es=es,
ilens=ilens,
olens=olens, )
loss = l1_loss + duration_loss + pitch_loss + energy_loss
stats = dict(
l1_loss=l1_loss.item(),
duration_loss=duration_loss.item(),
pitch_loss=pitch_loss.item(),
energy_loss=energy_loss.item(),
loss=loss.item(), )
# report extra information
if self.encoder_type == "transformer" and self.use_scaled_pos_enc:
stats.update(encoder_alpha=self.encoder.embed[-1].alpha.item(), )
if self.decoder_type == "transformer" and self.use_scaled_pos_enc:
stats.update(decoder_alpha=self.decoder.embed[-1].alpha.item(), )
# loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
return loss, stats
return before_outs, after_outs, d_outs, p_outs, e_outs, ys, ilens, olens
def _forward(
self,
@ -383,7 +351,6 @@ class FastSpeech2(nn.Layer):
if is_inference:
d_outs = self.duration_predictor.inference(hs,
d_masks) # (B, Tmax)
# print("d_outs:",d_outs)
# use prediction in inference
# (B, Tmax, 1)
@ -396,7 +363,6 @@ class FastSpeech2(nn.Layer):
else:
d_outs = self.duration_predictor(hs, d_masks)
# use groundtruth in training
print("ps.shape:", ps.shape)
p_embs = self.pitch_embed(ps.transpose((0, 2, 1))).transpose(
(0, 2, 1))
e_embs = self.energy_embed(es.transpose((0, 2, 1))).transpose(
@ -534,8 +500,6 @@ class FastSpeech2(nn.Layer):
init_type: str,
init_enc_alpha: float,
init_dec_alpha: float):
# initialize parameters
initialize(self, init_type)
# initialize alpha in scaled positional encoding
if self.encoder_type == "transformer" and self.use_scaled_pos_enc:
@ -637,19 +601,24 @@ class FastSpeech2Loss(nn.Layer):
"""
# apply mask to remove padded part
if self.use_masking:
out_masks = make_non_pad_mask(olens).unsqueeze(-1).to(ys.device)
before_outs = before_outs.masked_select(out_masks)
out_masks = make_non_pad_mask(olens).unsqueeze(-1)
before_outs = before_outs.masked_select(
out_masks.broadcast_to(before_outs.shape))
if after_outs is not None:
after_outs = after_outs.masked_select(out_masks)
ys = ys.masked_select(out_masks)
duration_masks = make_non_pad_mask(ilens).to(ys.device)
d_outs = d_outs.masked_select(duration_masks)
ds = ds.masked_select(duration_masks)
pitch_masks = make_non_pad_mask(ilens).unsqueeze(-1).to(ys.device)
p_outs = p_outs.masked_select(pitch_masks)
e_outs = e_outs.masked_select(pitch_masks)
ps = ps.masked_select(pitch_masks)
es = es.masked_select(pitch_masks)
after_outs = after_outs.masked_select(
out_masks.broadcast_to(after_outs.shape))
ys = ys.masked_select(out_masks.broadcast_to(ys.shape))
duration_masks = make_non_pad_mask(ilens)
d_outs = d_outs.masked_select(
duration_masks.broadcast_to(d_outs.shape))
ds = ds.masked_select(duration_masks.broadcast_to(ds.shape))
pitch_masks = make_non_pad_mask(ilens).unsqueeze(-1)
p_outs = p_outs.masked_select(
pitch_masks.broadcast_to(p_outs.shape))
e_outs = e_outs.masked_select(
pitch_masks.broadcast_to(e_outs.shape))
ps = ps.masked_select(pitch_masks.broadcast_to(ps.shape))
es = es.masked_select(pitch_masks.broadcast_to(es.shape))
# calculate loss
l1_loss = self.l1_criterion(before_outs, ys)
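
The broadcast_to calls added above are worth a note: the new code expands the (B, T, 1) non-pad mask to the full shape of each tensor before masked_select, apparently because paddle's masked_select does not broadcast the mask against the input the way torch's does. A minimal sketch of the pattern (paddle 2.x assumed; shapes are toy values):

import paddle

x = paddle.rand([2, 3, 4])  # e.g. (B, T, odim) decoder outputs
mask = paddle.to_tensor(
    [[True, True, False], [True, False, False]]).unsqueeze(-1)  # (B, T, 1)
# expand the mask over the feature axis, then keep only the valid entries
flat = paddle.masked_select(x, mask.broadcast_to(x.shape))
print(flat.shape)  # [12]: 3 valid frames x 4 feature dims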