complete the example for speedyspeech; fix several bugs in training module

chenfeiyu 2021-07-12 15:19:22 +08:00
parent 6c21d80025
commit 4a7888b8c6
9 changed files with 263 additions and 36 deletions

View File

@@ -17,21 +17,27 @@ from parakeet.data.batch import batch_sequences
 def collate_baker_examples(examples):
-    # fields = ["phones", "tones", "num_phones", "num_frames", "feats"]
+    # fields = ["phones", "tones", "num_phones", "num_frames", "feats", "durations"]
     phones = [np.array(item["phones"], dtype=np.int64) for item in examples]
     tones = [np.array(item["tones"], dtype=np.int64) for item in examples]
     feats = [np.array(item["feats"], dtype=np.float32) for item in examples]
+    durations = [
+        np.array(
+            item["durations"], dtype=np.int64) for item in examples
+    ]
     num_phones = np.array([item["num_phones"] for item in examples])
     num_frames = np.array([item["num_frames"] for item in examples])
     phones = batch_sequences(phones)
     tones = batch_sequences(tones)
     feats = batch_sequences(feats)
+    durations = batch_sequences(durations)

     batch = {
         "phones": phones,
         "tones": tones,
         "num_phones": num_phones,
         "num_frames": num_frames,
         "feats": feats,
+        "durations": durations,
     }
     return batch
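The new "durations" field goes through the same batching path as the other sequence fields. For readers unfamiliar with batch_sequences, here is a minimal stand-in (a hypothetical pad_sequences helper, not the Parakeet implementation) that right-pads variable-length arrays to a common length:

import numpy as np

def pad_sequences(seqs, pad_value=0):
    # hypothetical stand-in for parakeet.data.batch.batch_sequences:
    # right-pad arrays (1-D or [T, D]) to the longest length in the batch
    max_len = max(len(s) for s in seqs)
    padded = np.full((len(seqs), max_len) + seqs[0].shape[1:], pad_value,
                     dtype=seqs[0].dtype)
    for i, s in enumerate(seqs):
        padded[i, :len(s)] = s
    return padded

# two toy examples with different numbers of phones
durations = [np.array([2, 3, 1], dtype=np.int64),
             np.array([4, 2], dtype=np.int64)]
print(pad_sequences(durations))
# [[2 3 1]
#  [4 2 0]]

Padding with zeros is safe here because the losses in the updater are masked with num_phones / num_frames, so padded positions never contribute.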

View File

@@ -20,8 +20,8 @@ trim_hop_length: 512 # Hop size in trimming. (in samples)
 ###########################################################
 # DATA SETTING                                            #
 ###########################################################
-batch_size: 16
-num_workers: 0
+batch_size: 32
+num_workers: 4

@@ -30,8 +30,8 @@ num_workers: 0
 ###########################################################
 # MODEL SETTING                                           #
 ###########################################################
 model:
-  vocab_size: 68
-  tone_size: 6
+  vocab_size: 101  # 99 + 2
+  tone_size: 8     # 6 + 2
   encoder_hidden_size: 128
   encoder_kernel_size: 3
   encoder_dilations: [1, 3, 9, 27, 1, 3, 9, 27, 1, 1]

@@ -39,7 +39,7 @@ model:
   decoder_hidden_size: 128
   decoder_output_size: 80
   decoder_kernel_size: 3
-  decoder_dilations: [1, 3, 9, 27, 1, 3, 9, 27, 1, 1]
+  decoder_dilations: [1, 3, 9, 27, 1, 3, 9, 27, 1, 3, 9, 27, 1, 3, 9, 27, 1, 1]
 ###########################################################

@@ -47,6 +47,12 @@ model:
 ###########################################################
+###########################################################
+# TRAINING SETTING                                         #
+###########################################################
+max_epoch: 300
+num_snapshots: 5
+
 ###########################################################
 # OTHER SETTING                                            #

View File

@@ -17,7 +17,7 @@ from paddle.nn import functional as F
 from paddle.fluid.layers import huber_loss
 from parakeet.modules.ssim import ssim
-from parakeet.modules.modules.losses import masked_l1_loss, weighted_mean
+from parakeet.modules.losses import masked_l1_loss, weighted_mean
 from parakeet.training.reporter import report
 from parakeet.training.updaters.standard_updater import StandardUpdater
 from parakeet.training.extensions.evaluator import StandardEvaluator
@@ -27,23 +27,25 @@ from parakeet.models.speedyspeech import SpeedySpeech
 class SpeedySpeechUpdater(StandardUpdater):
     def update_core(self, batch):
         decoded, predicted_durations = self.model(
-            text=batch["phonemes"],
+            text=batch["phones"],
             tones=batch["tones"],
-            plens=batch["phoneme_lenghts"],
-            durations=batch["phoneme_durations"])
+            plens=batch["num_phones"],
+            durations=batch["durations"])

-        target_mel = batch["mel"]
+        target_mel = batch["feats"]
         spec_mask = F.sequence_mask(
             batch["num_frames"], dtype=target_mel.dtype).unsqueeze(-1)
         text_mask = F.sequence_mask(
-            batch["phoneme_lenghts"], dtype=predicted_durations.dtype)
+            batch["num_phones"], dtype=predicted_durations.dtype)

         # spec loss
         l1_loss = masked_l1_loss(decoded, target_mel, spec_mask)

         # duration loss
-        target_durations = batch["phoneme_durations"]
-        target_durations = paddle.clip(target_durations, min=1.0)
+        target_durations = batch["durations"]
+        target_durations = paddle.maximum(
+            target_durations.astype(predicted_durations.dtype),
+            paddle.to_tensor([1.0]))
         duration_loss = weighted_mean(
             huber_loss(
                 predicted_durations, paddle.log(target_durations), delta=1.0),
@@ -53,13 +55,57 @@ class SpeedySpeechUpdater(StandardUpdater):
         ssim_loss = 1.0 - ssim((decoded * spec_mask).unsqueeze(1),
                                (target_mel * spec_mask).unsqueeze(1))

-        loss = l1_loss + duration_loss + ssim_loss
+        loss = l1_loss + ssim_loss + duration_loss

         optimizer = self.optimizer
         optimizer.clear_grad()
         loss.backward()
         optimizer.step()

+        report("train/loss", float(loss))
         report("train/l1_loss", float(l1_loss))
         report("train/duration_loss", float(duration_loss))
         report("train/ssim_loss", float(ssim_loss))
+
+
+class SpeedySpeechEvaluator(StandardEvaluator):
+    def evaluate_core(self, batch):
+        decoded, predicted_durations = self.model(
+            text=batch["phones"],
+            tones=batch["tones"],
+            plens=batch["num_phones"],
+            durations=batch["durations"])
+
+        target_mel = batch["feats"]
+        spec_mask = F.sequence_mask(
+            batch["num_frames"], dtype=target_mel.dtype).unsqueeze(-1)
+        text_mask = F.sequence_mask(
+            batch["num_phones"], dtype=predicted_durations.dtype)
+
+        # spec loss
+        l1_loss = masked_l1_loss(decoded, target_mel, spec_mask)
+
+        # duration loss
+        target_durations = batch["durations"]
+        target_durations = paddle.maximum(
+            target_durations.astype(predicted_durations.dtype),
+            paddle.to_tensor([1.0]))
+        duration_loss = weighted_mean(
+            huber_loss(
+                predicted_durations, paddle.log(target_durations), delta=1.0),
+            text_mask, )
+
+        # ssim loss
+        ssim_loss = 1.0 - ssim((decoded * spec_mask).unsqueeze(1),
+                               (target_mel * spec_mask).unsqueeze(1))
+
+        loss = l1_loss + ssim_loss + duration_loss
+
+        report("eval/loss", float(loss))
+        report("eval/l1_loss", float(l1_loss))
+        report("eval/duration_loss", float(duration_loss))
+        report("eval/ssim_loss", float(ssim_loss))

View File

@@ -0,0 +1,131 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import sys
import logging
import argparse
import dataclasses
from pathlib import Path

import yaml
import jsonlines
import paddle
import numpy as np
import soundfile as sf
from paddle import nn
from paddle.nn import functional as F
from paddle import distributed as dist
from paddle.io import DataLoader, DistributedBatchSampler
from paddle.optimizer import Adam  # No RAdam
from paddle.optimizer.lr import StepDecay
from paddle import DataParallel
from visualdl import LogWriter

from parakeet.datasets.data_table import DataTable
from parakeet.models.speedyspeech import SpeedySpeech
from parakeet.models.parallel_wavegan import PWGGenerator
from parakeet.training.updater import UpdaterBase
from parakeet.training.trainer import Trainer
from parakeet.training.reporter import report
from parakeet.training import extension
from parakeet.training.extensions.snapshot import Snapshot
from parakeet.training.extensions.visualizer import VisualDL
from parakeet.training.seeding import seed_everything

from batch_fn import collate_baker_examples
from speedyspeech_updater import SpeedySpeechUpdater, SpeedySpeechEvaluator
from config import get_cfg_default


def evaluate(args, config):
    # the DataLoader logger is too verbose
    logging.getLogger("DataLoader").disabled = True

    # construct dataset for evaluation
    with jsonlines.open(args.test_metadata, 'r') as reader:
        test_metadata = list(reader)
    test_dataset = DataTable(
        data=test_metadata, fields=["utt_id", "phones", "tones"])

    model = SpeedySpeech(**config["model"])
    model.set_state_dict(paddle.load(args.checkpoint)["main_params"])
    model.eval()

    vocoder_config = yaml.safe_load(
        open("../../parallelwave_gan/baker/conf/default.yaml"))
    vocoder = PWGGenerator(**vocoder_config["generator_params"])
    vocoder.set_state_dict(
        paddle.load("../../parallelwave_gan/baker/converted.pdparams"))
    vocoder.remove_weight_norm()
    vocoder.eval()
    print("model done!")

    # statistics of the acoustic model's training mels
    stat = np.load("../../speedyspeech/baker/dump/train/stats.npy")
    mu, std = stat
    mu = paddle.to_tensor(mu)
    std = paddle.to_tensor(std)

    # statistics of the vocoder's training mels
    stat2 = np.load("../../parallelwave_gan/baker/dump/train/stats.npy")
    mu2, std2 = stat2
    mu2 = paddle.to_tensor(mu2)
    std2 = paddle.to_tensor(std2)

    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    for datum in test_dataset:
        utt_id = datum["utt_id"]
        phones = paddle.to_tensor(datum["phones"])
        tones = paddle.to_tensor(datum["tones"])

        mel, _ = model.inference(phones, tones)
        # denormalize with the acoustic model's stats, renormalize for the vocoder
        mel = mel * std + mu
        mel = (mel - mu2) / std2

        wav = vocoder.inference(mel)
        sf.write(
            output_dir / (utt_id + ".wav"), wav.numpy(), samplerate=config.sr)
        print(f"{utt_id} done!")


def main():
    # parse args and config and redirect to evaluate
    parser = argparse.ArgumentParser(
        description="Synthesize with speedyspeech & parallel wavegan.")
    parser.add_argument(
        "--config", type=str, help="config file to overwrite default config")
    parser.add_argument("--checkpoint", type=str, help="checkpoint to load.")
    parser.add_argument("--test-metadata", type=str, help="test metadata")
    parser.add_argument("--output-dir", type=str, help="output dir")
    parser.add_argument(
        "--device", type=str, default="gpu", help="device type to use")
    parser.add_argument("--verbose", type=int, default=1, help="verbose")

    args = parser.parse_args()
    config = get_cfg_default()
    if args.config:
        config.merge_from_file(args.config)

    print("========Args========")
    print(yaml.safe_dump(vars(args)))
    print("========Config========")
    print(config)

    evaluate(args, config)


if __name__ == "__main__":
    main()
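The inner loop above converts between two normalization spaces: the acoustic model predicts mels normalized with its own training statistics (mu, std), while the vocoder expects mels normalized with its statistics (mu2, std2), so the script denormalizes and then renormalizes. A small numpy sketch with made-up statistics:

import numpy as np

# toy per-bin statistics (hypothetical values, shape [num_mels])
am_mu, am_std = np.array([-6.0, -5.5]), np.array([2.0, 1.8])    # acoustic-model training stats
voc_mu, voc_std = np.array([-6.2, -5.4]), np.array([2.1, 1.7])  # vocoder training stats

mel_normalized = np.array([[0.1, -0.3], [0.5, 0.2]])  # model output, in AM-normalized space

mel = mel_normalized * am_std + am_mu        # back to the log-mel domain
mel_for_vocoder = (mel - voc_mu) / voc_std   # into the vocoder's normalized space
print(mel_for_vocoder)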

View File

@@ -44,6 +44,7 @@ from parakeet.training.extensions.visualizer import VisualDL
 from parakeet.training.seeding import seed_everything

 from batch_fn import collate_baker_examples
+from speedyspeech_updater import SpeedySpeechUpdater, SpeedySpeechEvaluator
 from config import get_cfg_default
@@ -73,26 +74,29 @@ def train_sp(args, config):
         train_metadata = list(reader)
     train_dataset = DataTable(
         data=train_metadata,
-        fields=["phones", "tones", "num_phones", "num_frames", "feats"],
+        fields=[
+            "phones", "tones", "num_phones", "num_frames", "feats", "durations"
+        ],
         converters={"feats": np.load, }, )
     with jsonlines.open(args.dev_metadata, 'r') as reader:
         dev_metadata = list(reader)
     dev_dataset = DataTable(
         data=dev_metadata,
-        fields=["phones", "tones", "num_phones", "num_frames", "feats"],
+        fields=[
+            "phones", "tones", "num_phones", "num_frames", "feats", "durations"
+        ],
         converters={"feats": np.load, }, )
     # collate function and dataloader
     train_sampler = DistributedBatchSampler(
         train_dataset,
         batch_size=config.batch_size,
-        shuffle=False,
-        drop_last=False)
-    # dev_sampler = DistributedBatchSampler(dev_dataset,
-    # batch_size=config.batch_size,
-    # shuffle=False,
-    # drop_last=False)
+        shuffle=True,
+        drop_last=True)
+    dev_sampler = DistributedBatchSampler(
+        dev_dataset,
+        batch_size=config.batch_size,
+        shuffle=False,
+        drop_last=True)
     print("samplers done!")

     train_dataloader = DataLoader(
@@ -102,16 +106,43 @@ def train_sp(args, config):
         num_workers=config.num_workers)
     dev_dataloader = DataLoader(
         dev_dataset,
-        batch_sampler=dev_sampler,
+        shuffle=False,
+        drop_last=False,
+        batch_size=config.batch_size,
         collate_fn=collate_baker_examples,
         num_workers=config.num_workers)
     print("dataloaders done!")

     model = SpeedySpeech(**config["model"])
-    print(model)
+    if world_size > 1:
+        model = DataParallel(model)  # TODO: do not use vocab size from config
+    print("model done!")
+
+    optimizer = Adam(
+        0.001,
+        parameters=model.parameters(),
+        grad_clip=nn.ClipGradByGlobalNorm(5.0))
+    print("optimizer done!")
+
+    updater = SpeedySpeechUpdater(
+        model=model, optimizer=optimizer, dataloader=train_dataloader)
+
+    output_dir = Path(args.output_dir)
+    trainer = Trainer(updater, (config.max_epoch, 'epoch'), output_dir)
+
+    evaluator = SpeedySpeechEvaluator(model, dev_dataloader)
+
+    if dist.get_rank() == 0:
+        trainer.extend(evaluator, trigger=(1, "epoch"))
+        writer = LogWriter(str(output_dir))
+        trainer.extend(VisualDL(writer), trigger=(1, "iteration"))
+        trainer.extend(
+            Snapshot(max_size=config.num_snapshots), trigger=(1, 'epoch'))
+    print(trainer.extensions)
+    trainer.run()

 def main():

View File

@@ -189,7 +189,7 @@ class SpeedySpeech(nn.Layer):
         # decode
         # remove positional encoding here
-        _, t_dec, feature_size = encodings.shpae
+        _, t_dec, feature_size = encodings.shape
         encodings += sinusoid_position_encoding(t_dec, feature_size)
         decoded = self.decoder(encodings)
         return decoded, pred_durations

@@ -211,4 +211,4 @@ class SpeedySpeech(nn.Layer):
         t_dec, feature_size = shape[1], shape[2]
         encodings += sinusoid_position_encoding(t_dec, feature_size)
         decoded = self.decoder(encodings)
-        return decoded, pred_durations
+        return decoded[0], pred_durations[0]
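The second hunk changes inference to return decoded[0] and pred_durations[0]. Assuming the decoder output is batch-first ([B, T, C]) and inference runs on a single utterance, this drops the batch axis so the synthesis script can hand a [T, C] mel straight to the vocoder. A trivial shape sketch:

import numpy as np

decoded = np.zeros((1, 5, 80))  # batch of one utterance, 5 frames, 80 mel bins
mel = decoded[0]                # drop the batch axis
print(mel.shape)                # (5, 80)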

View File

@@ -123,8 +123,6 @@ class Trainer(object):
         update = self.updater.update  # training step
         stop_trigger = self.stop_trigger

-        print(self.updater.state)
-
         # display only one progress bar
         max_iteration = None
         if isinstance(stop_trigger, LimitTrigger):

View File

@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from copy import deepcopy
+

 class IntervalTrigger(object):
     """A Predicate to do something every N cycle."""

@@ -23,9 +25,16 @@ class IntervalTrigger(object):
             raise ValueError("period should be a positive integer.")
         self.period = period
         self.unit = unit
+        self.last_index = None

     def __call__(self, trainer):
-        state = trainer.updater.state
-        index = getattr(state, self.unit)
-        fire = index % self.period == 0
+        if self.last_index is None:
+            last_index = getattr(trainer.updater.state, self.unit)
+            self.last_index = last_index
+        last_index = self.last_index
+        index = getattr(trainer.updater.state, self.unit)
+        fire = index // self.period != last_index // self.period
+        self.last_index = index
         return fire
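The reworked trigger keeps the last seen counter value and fires whenever index // period changes, rather than only when index % period == 0, so an event is not missed if the counter advances past a multiple between two calls. A self-contained sketch of the same idea, with a plain integer in place of the trainer object:

class IntervalTriggerSketch:
    """Fire every `period` units, based on crossing multiples of the period."""

    def __init__(self, period):
        self.period = period
        self.last_index = None

    def __call__(self, index):
        if self.last_index is None:
            self.last_index = index
        fire = index // self.period != self.last_index // self.period
        self.last_index = index
        return fire


trigger = IntervalTriggerSketch(period=3)
print([trigger(i) for i in [1, 2, 4, 5, 6, 7]])
# [False, False, True, False, True, False] -- index 3 was never seen, but crossing it still fires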

View File

@@ -106,8 +106,8 @@ class StandardUpdater(UpdaterBase):
         self.update_core(batch)
         self.state.iteration += 1
-        if self.updaters_per_epoch is not None:
-            if self.state.iteration % self.updaters_per_epoch == 0:
+        if self.updates_per_epoch is not None:
+            if self.state.iteration % self.updates_per_epoch == 0:
                 self.state.epoch += 1

     def update_core(self, batch):

@@ -139,7 +139,7 @@ class StandardUpdater(UpdaterBase):
         self.optimizer.update()

     @property
-    def updaters_per_epoch(self):
+    def updates_per_epoch(self):
         """Number of updates per epoch, determined by the length of the
         dataloader."""
         length_of_dataloader = None
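With the rename, the property reads naturally at its call site: the updater bumps state.epoch every updates_per_epoch iterations, i.e. once per pass over the dataloader. A small sketch of that bookkeeping outside the Parakeet classes (toy numbers, not the library API):

class ToyUpdaterState:
    def __init__(self):
        self.iteration = 0
        self.epoch = 0

def run_updates(num_iters, updates_per_epoch):
    # mirrors StandardUpdater.update: bump iteration on every batch, and bump
    # epoch every `updates_per_epoch` iterations (one pass over the dataloader)
    state = ToyUpdaterState()
    for _ in range(num_iters):
        state.iteration += 1
        if updates_per_epoch is not None and state.iteration % updates_per_epoch == 0:
            state.epoch += 1
    return state.iteration, state.epoch

print(run_updates(num_iters=10, updates_per_epoch=4))  # (10, 2)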