complete the example for speedyspeech; fix several bugs in training module

chenfeiyu 2021-07-12 15:19:22 +08:00
parent 6c21d80025
commit 4a7888b8c6
9 changed files with 263 additions and 36 deletions

View File

@@ -17,21 +17,27 @@ from parakeet.data.batch import batch_sequences
def collate_baker_examples(examples):
# fields = ["phones", "tones", "num_phones", "num_frames", "feats"]
# fields = ["phones", "tones", "num_phones", "num_frames", "feats", "durations"]
phones = [np.array(item["phones"], dtype=np.int64) for item in examples]
tones = [np.array(item["tones"], dtype=np.int64) for item in examples]
feats = [np.array(item["feats"], dtype=np.float32) for item in examples]
durations = [
np.array(
item["durations"], dtype=np.int64) for item in examples
]
num_phones = np.array([item["num_phones"] for item in examples])
num_frames = np.array([item["num_frames"] for item in examples])
phones = batch_sequences(phones)
tones = batch_sequences(tones)
feats = batch_sequences(feats)
durations = batch_sequences(durations)
batch = {
"phones": phones,
"tones": tones,
"num_phones": num_phones,
"num_frames": num_frames,
"feats": feats,
"durations": durations,
}
return batch
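Note: a minimal sketch (not part of the commit) of what this collate function is expected to produce, assuming batch_sequences zero-pads each ragged field to the longest item in the batch; pad_to_longest below is a hypothetical stand-in for parakeet.data.batch.batch_sequences.

import numpy as np

def pad_to_longest(arrays, pad_value=0):
    # hypothetical stand-in for batch_sequences (zero-padding assumed)
    max_len = max(a.shape[0] for a in arrays)
    out = np.full((len(arrays), max_len) + arrays[0].shape[1:], pad_value,
                  dtype=arrays[0].dtype)
    for i, a in enumerate(arrays):
        out[i, :a.shape[0]] = a
    return out

examples = [
    {"phones": [3, 5, 7], "tones": [1, 2, 0], "num_phones": 3, "num_frames": 4,
     "feats": np.zeros((4, 80), dtype=np.float32), "durations": [1, 2, 1]},
    {"phones": [3, 5], "tones": [1, 2], "num_phones": 2, "num_frames": 3,
     "feats": np.zeros((3, 80), dtype=np.float32), "durations": [2, 1]},
]
phones = pad_to_longest([np.array(e["phones"], dtype=np.int64) for e in examples])
feats = pad_to_longest([np.array(e["feats"], dtype=np.float32) for e in examples])
print(phones.shape, feats.shape)  # (2, 3) (2, 4, 80)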

View File

@@ -20,8 +20,8 @@ trim_hop_length: 512 # Hop size in trimming.(in samples)
###########################################################
# DATA SETTING #
###########################################################
batch_size: 16
num_workers: 0
batch_size: 32
num_workers: 4
@@ -30,8 +30,8 @@ num_workers: 0
# MODEL SETTING #
###########################################################
model:
vocab_size: 68
tone_size: 6
vocab_size: 101 # 99 + 2
tone_size: 8 # 6 + 2
encoder_hidden_size: 128
encoder_kernel_size: 3
encoder_dilations: [1, 3, 9, 27, 1, 3, 9, 27, 1, 1]
@@ -39,7 +39,7 @@ model:
decoder_hidden_size: 128
decoder_output_size: 80
decoder_kernel_size: 3
decoder_dilations: [1, 3, 9, 27, 1, 3, 9, 27, 1, 1]
decoder_dilations: [1, 3, 9, 27, 1, 3, 9, 27, 1, 3, 9, 27, 1, 3, 9, 27, 1, 1]
###########################################################
@@ -47,6 +47,12 @@ model:
###########################################################
###########################################################
# TRAINING SETTING #
###########################################################
max_epoch: 300
num_snapshots: 5
###########################################################
# OTHER SETTING #

View File

@@ -17,7 +17,7 @@ from paddle.nn import functional as F
from paddle.fluid.layers import huber_loss
from parakeet.modules.ssim import ssim
from parakeet.modules.modules.losses import masked_l1_loss, weighted_mean
from parakeet.modules.losses import masked_l1_loss, weighted_mean
from parakeet.training.reporter import report
from parakeet.training.updaters.standard_updater import StandardUpdater
from parakeet.training.extensions.evaluator import StandardEvaluator
@@ -27,23 +27,25 @@ from parakeet.models.speedyspeech import SpeedySpeech
class SpeedySpeechUpdater(StandardUpdater):
def update_core(self, batch):
decoded, predicted_durations = self.model(
text=batch["phonemes"],
text=batch["phones"],
tones=batch["tones"],
plens=batch["phoneme_lenghts"],
durations=batch["phoneme_durations"])
plens=batch["num_phones"],
durations=batch["durations"])
target_mel = batch["mel"]
target_mel = batch["feats"]
spec_mask = F.sequence_mask(
batch["num_frames"], dtype=target_mel.dtype).unsqueeze(-1)
text_mask = F.sequence_mask(
batch["phoneme_lenghts"], dtype=predicted_durations.dtype)
batch["num_phones"], dtype=predicted_durations.dtype)
# spec loss
l1_loss = masked_l1_loss(decoded, target_mel, spec_mask)
# duration loss
target_durations = batch["phoneme_durations"]
target_durations = paddle.clip(target_durations, min=1.0)
target_durations = batch["durations"]
target_durations = paddle.maximum(
target_durations.astype(predicted_durations.dtype),
paddle.to_tensor([1.0]))
duration_loss = weighted_mean(
huber_loss(
predicted_durations, paddle.log(target_durations), delta=1.0),
@@ -53,13 +55,57 @@ class SpeedySpeechUpdater(StandardUpdater):
ssim_loss = 1.0 - ssim((decoded * spec_mask).unsqueeze(1),
(target_mel * spec_mask).unsqueeze(1))
loss = l1_loss + duration_loss + ssim_loss
loss = l1_loss + ssim_loss + duration_loss
optimizer = self.optimizer
optimizer.clear_grad()
loss.backward()
optimizer.step()
# import pdb; pdb.set_trace()
report("train/loss", float(loss))
report("train/l1_loss", float(l1_loss))
report("train/duration_loss", float(duration_loss))
report("train/ssim_loss", float(ssim_loss))
class SpeedySpeechEvaluator(StandardEvaluator):
def evaluate_core(self, batch):
print("fire")
decoded, predicted_durations = self.model(
text=batch["phones"],
tones=batch["tones"],
plens=batch["num_phones"],
durations=batch["durations"])
target_mel = batch["feats"]
spec_mask = F.sequence_mask(
batch["num_frames"], dtype=target_mel.dtype).unsqueeze(-1)
text_mask = F.sequence_mask(
batch["num_phones"], dtype=predicted_durations.dtype)
# spec loss
l1_loss = masked_l1_loss(decoded, target_mel, spec_mask)
# duration loss
target_durations = batch["durations"]
target_durations = paddle.maximum(
target_durations.astype(predicted_durations.dtype),
paddle.to_tensor([1.0]))
duration_loss = weighted_mean(
huber_loss(
predicted_durations, paddle.log(target_durations), delta=1.0),
text_mask, )
# ssim loss
ssim_loss = 1.0 - ssim((decoded * spec_mask).unsqueeze(1),
(target_mel * spec_mask).unsqueeze(1))
loss = l1_loss + ssim_loss + duration_loss
# import pdb; pdb.set_trace()
report("eval/loss", float(loss))
report("eval/l1_loss", float(l1_loss))
report("eval/duration_loss", float(duration_loss))
report("eval/ssim_loss", float(ssim_loss))

View File

@@ -0,0 +1,131 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
import logging
import argparse
import dataclasses
from pathlib import Path
import yaml
import jsonlines
import paddle
import numpy as np
import soundfile as sf
from paddle import nn
from paddle.nn import functional as F
from paddle import distributed as dist
from paddle.io import DataLoader, DistributedBatchSampler
from paddle.optimizer import Adam # No RAdam
from paddle.optimizer.lr import StepDecay
from paddle import DataParallel
from visualdl import LogWriter
from parakeet.datasets.data_table import DataTable
from parakeet.models.speedyspeech import SpeedySpeech
from parakeet.models.parallel_wavegan import PWGGenerator
from parakeet.training.updater import UpdaterBase
from parakeet.training.trainer import Trainer
from parakeet.training.reporter import report
from parakeet.training import extension
from parakeet.training.extensions.snapshot import Snapshot
from parakeet.training.extensions.visualizer import VisualDL
from parakeet.training.seeding import seed_everything
from batch_fn import collate_baker_examples
from speedyspeech_updater import SpeedySpeechUpdater, SpeedySpeechEvaluator
from config import get_cfg_default
def evaluate(args, config):
# dataloader has been too verbose
logging.getLogger("DataLoader").disabled = True
# construct dataset for evaluation
with jsonlines.open(args.test_metadata, 'r') as reader:
test_metadata = list(reader)
test_dataset = DataTable(
data=test_metadata, fields=["utt_id", "phones", "tones"])
model = SpeedySpeech(**config["model"])
model.set_state_dict(paddle.load(args.checkpoint)["main_params"])
model.eval()
vocoder_config = yaml.safe_load(
open("../../parallelwave_gan/baker/conf/default.yaml"))
vocoder = PWGGenerator(**vocoder_config["generator_params"])
vocoder.set_state_dict(
paddle.load("../../parallelwave_gan/baker/converted.pdparams"))
vocoder.remove_weight_norm()
vocoder.eval()
# print(model)
print("model done!")
stat = np.load("../../speedyspeech/baker/dump/train/stats.npy")
mu, std = stat
mu = paddle.to_tensor(mu)
std = paddle.to_tensor(std)
stat2 = np.load("../../parallelwave_gan/baker/dump/train/stats.npy")
mu2, std2 = stat2
mu2 = paddle.to_tensor(mu2)
std2 = paddle.to_tensor(std2)
output_dir = Path(args.output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
for datum in test_dataset:
utt_id = datum["utt_id"]
phones = paddle.to_tensor(datum["phones"])
tones = paddle.to_tensor(datum["tones"])
mel, _ = model.inference(phones, tones)
mel = mel * std + mu
mel = (mel - mu2) / std2
wav = vocoder.inference(mel)
sf.write(
output_dir / (utt_id + ".wav"), wav.numpy(), samplerate=config.sr)
print(f"{utt_id} done!")
def main():
# parse args and config and redirect to train_sp
parser = argparse.ArgumentParser(description="Synthesize with a SpeedySpeech "
"model trained on the Baker Mandarin TTS dataset.")
parser.add_argument(
"--config", type=str, help="config file to overwrite default config")
parser.add_argument("--checkpoint", type=str, help="checkpoint to load.")
parser.add_argument("--test-metadata", type=str, help="training data")
parser.add_argument("--output-dir", type=str, help="output dir")
parser.add_argument(
"--device", type=str, default="gpu", help="device type to use")
parser.add_argument("--verbose", type=int, default=1, help="verbose")
args = parser.parse_args()
config = get_cfg_default()
if args.config:
config.merge_from_file(args.config)
print("========Args========")
print(yaml.safe_dump(vars(args)))
print("========Config========")
print(config)
evaluate(args, config)
if __name__ == "__main__":
main()
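Note: in the synthesis loop above, the acoustic model emits normalized features, so the script first denormalizes them with the SpeedySpeech training statistics (mu, std) and then renormalizes with the vocoder's statistics (mu2, std2). A minimal numpy sketch of that step with stand-in values (the real values come from the two stats.npy files).

import numpy as np

mu, std = np.zeros(80, dtype=np.float32), np.ones(80, dtype=np.float32)    # acoustic model stats (stand-ins)
mu2, std2 = np.zeros(80, dtype=np.float32), np.ones(80, dtype=np.float32)  # vocoder stats (stand-ins)

normalized_mel = np.random.randn(100, 80).astype(np.float32)  # acoustic model output
mel = normalized_mel * std + mu        # back to the raw feature scale
mel_for_vocoder = (mel - mu2) / std2   # into the vocoder's normalized space
print(mel_for_vocoder.shape)           # (100, 80)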

View File

@@ -44,6 +44,7 @@ from parakeet.training.extensions.visualizer import VisualDL
from parakeet.training.seeding import seed_everything
from batch_fn import collate_baker_examples
from speedyspeech_updater import SpeedySpeechUpdater, SpeedySpeechEvaluator
from config import get_cfg_default
@@ -73,26 +74,29 @@ def train_sp(args, config):
train_metadata = list(reader)
train_dataset = DataTable(
data=train_metadata,
fields=["phones", "tones", "num_phones", "num_frames", "feats"],
fields=[
"phones", "tones", "num_phones", "num_frames", "feats", "durations"
],
converters={"feats": np.load, }, )
with jsonlines.open(args.dev_metadata, 'r') as reader:
dev_metadata = list(reader)
dev_dataset = DataTable(
data=dev_metadata,
fields=["phones", "tones", "num_phones", "num_frames", "feats"],
fields=[
"phones", "tones", "num_phones", "num_frames", "feats", "durations"
],
converters={"feats": np.load, }, )
# collate function and dataloader
train_sampler = DistributedBatchSampler(
train_dataset,
batch_size=config.batch_size,
shuffle=True,
drop_last=True)
dev_sampler = DistributedBatchSampler(
dev_dataset,
batch_size=config.batch_size,
shuffle=False,
drop_last=False)
drop_last=True)
# dev_sampler = DistributedBatchSampler(dev_dataset,
# batch_size=config.batch_size,
# shuffle=False,
# drop_last=False)
print("samplers done!")
train_dataloader = DataLoader(
@@ -102,16 +106,43 @@ def train_sp(args, config):
num_workers=config.num_workers)
dev_dataloader = DataLoader(
dev_dataset,
batch_sampler=dev_sampler,
shuffle=False,
drop_last=False,
batch_size=config.batch_size,
collate_fn=collate_baker_examples,
num_workers=config.num_workers)
print("dataloaders done!")
# batch = collate_baker_examples([train_dataset[i] for i in range(10)])
# batch = collate_baker_examples([dev_dataset[i] for i in range(10)])
# # batch = collate_baker_examples([dev_dataset[i] for i in range(10)])
# import pdb; pdb.set_trace()
model = SpeedySpeech(**config["model"])
print(model)
if world_size > 1:
model = DataParallel(model) # TODO, do not use vocab size from config
# print(model)
print("model done!")
optimizer = Adam(
0.001,
parameters=model.parameters(),
grad_clip=nn.ClipGradByGlobalNorm(5.0))
print("optimizer done!")
updater = SpeedySpeechUpdater(
model=model, optimizer=optimizer, dataloader=train_dataloader)
output_dir = Path(args.output_dir)
trainer = Trainer(updater, (config.max_epoch, 'epoch'), output_dir)
evaluator = SpeedySpeechEvaluator(model, dev_dataloader)
if dist.get_rank() == 0:
trainer.extend(evaluator, trigger=(1, "epoch"))
writer = LogWriter(str(output_dir))
trainer.extend(VisualDL(writer), trigger=(1, "iteration"))
trainer.extend(
Snapshot(max_size=config.num_snapshots), trigger=(1, 'epoch'))
print(trainer.extensions)
trainer.run()
def main():

View File

@@ -189,7 +189,7 @@ class SpeedySpeech(nn.Layer):
# decode
# remove positional encoding here
_, t_dec, feature_size = encodings.shpae
_, t_dec, feature_size = encodings.shape
encodings += sinusoid_position_encoding(t_dec, feature_size)
decoded = self.decoder(encodings)
return decoded, pred_durations
@@ -211,4 +211,4 @@ class SpeedySpeech(nn.Layer):
t_dec, feature_size = shape[1], shape[2]
encodings += sinusoid_position_encoding(t_dec, feature_size)
decoded = self.decoder(encodings)
return decoded, pred_durations
return decoded[0], pred_durations[0]

View File

@@ -123,8 +123,6 @@ class Trainer(object):
update = self.updater.update # training step
stop_trigger = self.stop_trigger
print(self.updater.state)
# display only one progress bar
max_iteration = None
if isinstance(stop_trigger, LimitTrigger):

View File

@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from copy import deepcopy
class IntervalTrigger(object):
"""A Predicate to do something every N cycle."""
@@ -23,9 +25,16 @@ class IntervalTrigger(object):
raise ValueError("period should be a positive integer.")
self.period = period
self.unit = unit
self.last_index = None
def __call__(self, trainer):
state = trainer.updater.state
index = getattr(state, self.unit)
fire = index % self.period == 0
if self.last_index is None:
last_index = getattr(trainer.updater.state, self.unit)
self.last_index = last_index
last_index = self.last_index
index = getattr(trainer.updater.state, self.unit)
fire = index // self.period != last_index // self.period
self.last_index = index
return fire
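Note: this change fixes the firing behavior. The trigger is evaluated once per update, so with a unit of "epoch" the old predicate index % period == 0 stays true for every iteration inside a qualifying epoch and fires the extension repeatedly; the new predicate fires only when index // period crosses a boundary. A small self-contained sketch (not the library class) contrasting the two rules:

class ModuloTrigger:
    # old rule: fire whenever the counter is a multiple of period
    def __init__(self, period):
        self.period = period

    def __call__(self, index):
        return index % self.period == 0

class BoundaryTrigger:
    # new rule: fire only when index // period changes between queries
    def __init__(self, period):
        self.period = period
        self.last_index = None

    def __call__(self, index):
        if self.last_index is None:
            self.last_index = index
        fire = index // self.period != self.last_index // self.period
        self.last_index = index
        return fire

# epoch counter observed at successive iterations; it advances slower than the queries
observed_epochs = [1, 1, 2, 2, 2, 3]
print([ModuloTrigger(1)(e) for e in observed_epochs])    # [True, True, True, True, True, True]
print([BoundaryTrigger(1)(e) for e in observed_epochs])  # [False, False, True, False, False, True]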

View File

@@ -106,8 +106,8 @@ class StandardUpdater(UpdaterBase):
self.update_core(batch)
self.state.iteration += 1
if self.updaters_per_epoch is not None:
if self.state.iteration % self.updaters_per_epoch == 0:
if self.updates_per_epoch is not None:
if self.state.iteration % self.updates_per_epoch == 0:
self.state.epoch += 1
def update_core(self, batch):
@@ -139,7 +139,7 @@ class StandardUpdater(UpdaterBase):
self.optimizer.update()
@property
def updaters_per_epoch(self):
def updates_per_epoch(self):
"""Number of updater per epoch, determined by the length of the
dataloader."""
length_of_dataloader = None
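Note: the renamed updates_per_epoch property (determined by the dataloader length when it is finite) is what advances the epoch counter in the update loop above. A tiny sketch of that bookkeeping with a made-up batch count:

updates_per_epoch = 250  # made-up stand-in for len(dataloader)
iteration, epoch = 0, 0
for _ in range(3 * updates_per_epoch):
    iteration += 1
    if updates_per_epoch is not None and iteration % updates_per_epoch == 0:
        epoch += 1
print(iteration, epoch)  # 750 3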