complete the example for speedyspeech; fix several bugs in the training module
This commit is contained in:
parent 6c21d80025
commit 4a7888b8c6
@@ -17,21 +17,27 @@ from parakeet.data.batch import batch_sequences


 def collate_baker_examples(examples):
-    # fields = ["phones", "tones", "num_phones", "num_frames", "feats"]
+    # fields = ["phones", "tones", "num_phones", "num_frames", "feats", "durations"]
     phones = [np.array(item["phones"], dtype=np.int64) for item in examples]
     tones = [np.array(item["tones"], dtype=np.int64) for item in examples]
     feats = [np.array(item["feats"], dtype=np.float32) for item in examples]
+    durations = [
+        np.array(
+            item["durations"], dtype=np.int64) for item in examples
+    ]
     num_phones = np.array([item["num_phones"] for item in examples])
     num_frames = np.array([item["num_frames"] for item in examples])

     phones = batch_sequences(phones)
     tones = batch_sequences(tones)
     feats = batch_sequences(feats)
+    durations = batch_sequences(durations)
     batch = {
         "phones": phones,
         "tones": tones,
         "num_phones": num_phones,
         "num_frames": num_frames,
         "feats": feats,
+        "durations": durations,
     }
     return batch

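Note on the collate function above: it leans on batch_sequences to pad the
variable-length phone, tone, duration, and feature arrays to a common length
before stacking. A minimal sketch of that behavior, assuming zero-padding along
the first axis (my reading of parakeet.data.batch, not a copy of it):

    import numpy as np

    def batch_sequences(seqs, pad_value=0):
        # Pad each array along axis 0 to the longest item, then stack.
        max_len = max(len(s) for s in seqs)
        padded = [
            np.pad(s, [(0, max_len - len(s))] + [(0, 0)] * (s.ndim - 1),
                   constant_values=pad_value) for s in seqs
        ]
        return np.stack(padded)
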
@@ -20,8 +20,8 @@ trim_hop_length: 512 # Hop size in trimming.(in samples)
 ###########################################################
 #                       DATA SETTING                      #
 ###########################################################
-batch_size: 16
-num_workers: 0
+batch_size: 32
+num_workers: 4



@@ -30,8 +30,8 @@ num_workers: 0
 #                       MODEL SETTING                     #
 ###########################################################
 model:
-    vocab_size: 68
-    tone_size: 6
+    vocab_size: 101 # 99 + 2
+    tone_size: 8 # 6 + 2
     encoder_hidden_size: 128
     encoder_kernel_size: 3
     encoder_dilations: [1, 3, 9, 27, 1, 3, 9, 27, 1, 1]

@@ -39,7 +39,7 @@ model:
     decoder_hidden_size: 128
     decoder_output_size: 80
     decoder_kernel_size: 3
-    decoder_dilations: [1, 3, 9, 27, 1, 3, 9, 27, 1, 1]
+    decoder_dilations: [1, 3, 9, 27, 1, 3, 9, 27, 1, 3, 9, 27, 1, 3, 9, 27, 1, 1]


 ###########################################################

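Why the decoder_dilations list doubled: stacking more dilated convolutions
widens the decoder's receptive field over mel frames. A quick back-of-the-
envelope check (my arithmetic, not from the commit):

    # receptive field of stacked dilated convs: 1 + (kernel - 1) * sum(dilations)
    kernel_size = 3
    old = [1, 3, 9, 27, 1, 3, 9, 27, 1, 1]
    new = [1, 3, 9, 27, 1, 3, 9, 27, 1, 3, 9, 27, 1, 3, 9, 27, 1, 1]
    print(1 + (kernel_size - 1) * sum(old))  # 165 frames
    print(1 + (kernel_size - 1) * sum(new))  # 325 frames
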
@@ -47,6 +47,12 @@ model:
 ###########################################################


+###########################################################
+#                     TRAINING SETTING                    #
+###########################################################
+max_epoch: 300
+num_snapshots: 5
+

 ###########################################################
 #                       OTHER SETTING                     #

@@ -17,7 +17,7 @@ from paddle.nn import functional as F
 from paddle.fluid.layers import huber_loss

 from parakeet.modules.ssim import ssim
-from parakeet.modules.modules.losses import masked_l1_loss, weighted_mean
+from parakeet.modules.losses import masked_l1_loss, weighted_mean
 from parakeet.training.reporter import report
 from parakeet.training.updaters.standard_updater import StandardUpdater
 from parakeet.training.extensions.evaluator import StandardEvaluator

@@ -27,23 +27,25 @@ from parakeet.models.speedyspeech import SpeedySpeech
 class SpeedySpeechUpdater(StandardUpdater):
     def update_core(self, batch):
         decoded, predicted_durations = self.model(
-            text=batch["phonemes"],
+            text=batch["phones"],
             tones=batch["tones"],
-            plens=batch["phoneme_lenghts"],
-            durations=batch["phoneme_durations"])
+            plens=batch["num_phones"],
+            durations=batch["durations"])

-        target_mel = batch["mel"]
+        target_mel = batch["feats"]
         spec_mask = F.sequence_mask(
             batch["num_frames"], dtype=target_mel.dtype).unsqueeze(-1)
         text_mask = F.sequence_mask(
-            batch["phoneme_lenghts"], dtype=predicted_durations.dtype)
+            batch["num_phones"], dtype=predicted_durations.dtype)

         # spec loss
         l1_loss = masked_l1_loss(decoded, target_mel, spec_mask)

         # duration loss
-        target_durations = batch["phoneme_durations"]
-        target_durations = paddle.clip(target_durations, min=1.0)
+        target_durations = batch["durations"]
+        target_durations = paddle.maximum(
+            target_durations.astype(predicted_durations.dtype),
+            paddle.to_tensor([1.0]))
         duration_loss = weighted_mean(
             huber_loss(
                 predicted_durations, paddle.log(target_durations), delta=1.0),

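The duration-loss change above fixes two problems at once: batch["durations"]
holds int64 frame counts, so clipping them against a float minimum presumably
fought the dtype, and any zero-length phone would have sent paddle.log to -inf.
Casting to the prediction dtype and flooring at 1 frame before the log avoids
both. A numpy illustration of the same idea (illustrative values, not repo
code):

    import numpy as np

    durations = np.array([0, 2, 5], dtype=np.int64)  # frames per phone
    floored = np.maximum(durations.astype(np.float32), 1.0)
    log_targets = np.log(floored)  # [0.0, 0.693, 1.609]; no -inf from the zero
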
@@ -53,13 +55,57 @@ class SpeedySpeechUpdater(StandardUpdater):
         ssim_loss = 1.0 - ssim((decoded * spec_mask).unsqueeze(1),
                                (target_mel * spec_mask).unsqueeze(1))

-        loss = l1_loss + duration_loss + ssim_loss
+        loss = l1_loss + ssim_loss + duration_loss

         optimizer = self.optimizer
         optimizer.clear_grad()
         loss.backward()
         optimizer.step()

+        # import pdb; pdb.set_trace()
         report("train/loss", float(loss))
         report("train/l1_loss", float(l1_loss))
         report("train/duration_loss", float(duration_loss))
         report("train/ssim_loss", float(ssim_loss))
+
+
+class SpeedySpeechEvaluator(StandardEvaluator):
+    def evaluate_core(self, batch):
+        print("fire")
+        decoded, predicted_durations = self.model(
+            text=batch["phones"],
+            tones=batch["tones"],
+            plens=batch["num_phones"],
+            durations=batch["durations"])
+
+        target_mel = batch["feats"]
+        spec_mask = F.sequence_mask(
+            batch["num_frames"], dtype=target_mel.dtype).unsqueeze(-1)
+        text_mask = F.sequence_mask(
+            batch["num_phones"], dtype=predicted_durations.dtype)
+
+        # spec loss
+        l1_loss = masked_l1_loss(decoded, target_mel, spec_mask)
+
+        # duration loss
+        target_durations = batch["durations"]
+        target_durations = paddle.maximum(
+            target_durations.astype(predicted_durations.dtype),
+            paddle.to_tensor([1.0]))
+        duration_loss = weighted_mean(
+            huber_loss(
+                predicted_durations, paddle.log(target_durations), delta=1.0),
+            text_mask, )
+
+        # ssim loss
+        ssim_loss = 1.0 - ssim((decoded * spec_mask).unsqueeze(1),
+                               (target_mel * spec_mask).unsqueeze(1))
+
+        loss = l1_loss + ssim_loss + duration_loss
+
+        # import pdb; pdb.set_trace()
+
+        report("eval/loss", float(loss))
+        report("eval/l1_loss", float(l1_loss))
+        report("eval/duration_loss", float(duration_loss))
+        report("eval/ssim_loss", float(ssim_loss))

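A shape note on the ssim term used in both update_core and evaluate_core: the
masked mels are (batch, frames, mel_bins), and unsqueeze(1) turns them into
(batch, 1, frames, mel_bins) so the SSIM module can score each spectrogram as a
one-channel image. A quick sketch under that assumption:

    import paddle

    decoded = paddle.rand([16, 370, 80])     # (batch, frames, mel_bins)
    spec_mask = paddle.ones([16, 370, 1])    # broadcasts across mel bins
    img = (decoded * spec_mask).unsqueeze(1)
    print(img.shape)                         # [16, 1, 370, 80]
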
@@ -0,0 +1,131 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import sys
+import logging
+import argparse
+import dataclasses
+from pathlib import Path
+
+import yaml
+import jsonlines
+import paddle
+import numpy as np
+import soundfile as sf
+from paddle import nn
+from paddle.nn import functional as F
+from paddle import distributed as dist
+from paddle.io import DataLoader, DistributedBatchSampler
+from paddle.optimizer import Adam  # no RAdam in paddle
+from paddle.optimizer.lr import StepDecay
+from paddle import DataParallel
+from visualdl import LogWriter
+
+from parakeet.datasets.data_table import DataTable
+from parakeet.models.speedyspeech import SpeedySpeech
+from parakeet.models.parallel_wavegan import PWGGenerator
+
+from parakeet.training.updater import UpdaterBase
+from parakeet.training.trainer import Trainer
+from parakeet.training.reporter import report
+from parakeet.training import extension
+from parakeet.training.extensions.snapshot import Snapshot
+from parakeet.training.extensions.visualizer import VisualDL
+from parakeet.training.seeding import seed_everything
+
+from batch_fn import collate_baker_examples
+from speedyspeech_updater import SpeedySpeechUpdater, SpeedySpeechEvaluator
+from config import get_cfg_default
+
+
+def evaluate(args, config):
+    # the dataloader is too verbose
+    logging.getLogger("DataLoader").disabled = True
+
+    # construct dataset for evaluation
+    with jsonlines.open(args.test_metadata, 'r') as reader:
+        test_metadata = list(reader)
+    test_dataset = DataTable(
+        data=test_metadata, fields=["utt_id", "phones", "tones"])
+
+    model = SpeedySpeech(**config["model"])
+    model.set_state_dict(paddle.load(args.checkpoint)["main_params"])
+    model.eval()
+
+    vocoder_config = yaml.safe_load(
+        open("../../parallelwave_gan/baker/conf/default.yaml"))
+    vocoder = PWGGenerator(**vocoder_config["generator_params"])
+    vocoder.set_state_dict(
+        paddle.load("../../parallelwave_gan/baker/converted.pdparams"))
+    vocoder.remove_weight_norm()
+    vocoder.eval()
+    # print(model)
+    print("model done!")
+
+    stat = np.load("../../speedyspeech/baker/dump/train/stats.npy")
+    mu, std = stat
+    mu = paddle.to_tensor(mu)
+    std = paddle.to_tensor(std)
+
+    stat2 = np.load("../../parallelwave_gan/baker/dump/train/stats.npy")
+    mu2, std2 = stat2
+    mu2 = paddle.to_tensor(mu2)
+    std2 = paddle.to_tensor(std2)
+
+    output_dir = Path(args.output_dir)
+    output_dir.mkdir(parents=True, exist_ok=True)
+
+    for datum in test_dataset:
+        utt_id = datum["utt_id"]
+        phones = paddle.to_tensor(datum["phones"])
+        tones = paddle.to_tensor(datum["tones"])
+
+        mel, _ = model.inference(phones, tones)
+        # de-normalize with the acoustic model's stats,
+        # then re-normalize with the vocoder's stats
+        mel = mel * std + mu
+        mel = (mel - mu2) / std2
+
+        wav = vocoder.inference(mel)
+        sf.write(
+            output_dir / (utt_id + ".wav"), wav.numpy(), samplerate=config.sr)
+        print(f"{utt_id} done!")
+
+
+def main():
+    # parse args and config and redirect to evaluate
+    parser = argparse.ArgumentParser(
+        description="Synthesize with a SpeedySpeech model "
+        "trained on the Baker Mandarin TTS dataset.")
+    parser.add_argument(
+        "--config", type=str, help="config file to overwrite default config")
+    parser.add_argument("--checkpoint", type=str, help="checkpoint to load.")
+    parser.add_argument("--test-metadata", type=str, help="test metadata")
+    parser.add_argument("--output-dir", type=str, help="output dir")
+    parser.add_argument(
+        "--device", type=str, default="gpu", help="device type to use")
+    parser.add_argument("--verbose", type=int, default=1, help="verbose")
+
+    args = parser.parse_args()
+    config = get_cfg_default()
+    if args.config:
+        config.merge_from_file(args.config)
+
+    print("========Args========")
+    print(yaml.safe_dump(vars(args)))
+    print("========Config========")
+    print(config)
+
+    evaluate(args, config)
+
+
+if __name__ == "__main__":
+    main()

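The two pairs of statistics in evaluate() are easy to misread: the acoustic
model emits mels normalized with the SpeedySpeech training stats (mu, std),
while the vocoder expects mels normalized with its own stats (mu2, std2), so
the script de-normalizes with one pair and re-normalizes with the other. The
same transform in plain numpy (illustrative values):

    import numpy as np

    mu, std = 0.5, 2.0    # speedyspeech mel stats (illustrative)
    mu2, std2 = 0.1, 1.5  # parallel wavegan mel stats (illustrative)

    mel_norm = np.array([-0.3, 0.0, 0.7])  # model output, speedyspeech space
    mel = mel_norm * std + mu              # back to raw log-mel
    mel_for_vocoder = (mel - mu2) / std2   # into the vocoder's space
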
@@ -44,6 +44,7 @@ from parakeet.training.extensions.visualizer import VisualDL
 from parakeet.training.seeding import seed_everything

+from batch_fn import collate_baker_examples
 from speedyspeech_updater import SpeedySpeechUpdater, SpeedySpeechEvaluator
 from config import get_cfg_default


@@ -73,26 +74,29 @@ def train_sp(args, config):
         train_metadata = list(reader)
     train_dataset = DataTable(
         data=train_metadata,
-        fields=["phones", "tones", "num_phones", "num_frames", "feats"],
+        fields=[
+            "phones", "tones", "num_phones", "num_frames", "feats", "durations"
+        ],
         converters={"feats": np.load, }, )
     with jsonlines.open(args.dev_metadata, 'r') as reader:
         dev_metadata = list(reader)
     dev_dataset = DataTable(
         data=dev_metadata,
-        fields=["phones", "tones", "num_phones", "num_frames", "feats"],
+        fields=[
+            "phones", "tones", "num_phones", "num_frames", "feats", "durations"
+        ],
         converters={"feats": np.load, }, )

     # collate function and dataloader
     train_sampler = DistributedBatchSampler(
         train_dataset,
         batch_size=config.batch_size,
         shuffle=True,
         drop_last=True)
     dev_sampler = DistributedBatchSampler(
         dev_dataset,
         batch_size=config.batch_size,
         shuffle=False,
-        drop_last=False)
+        drop_last=True)
     # dev_sampler = DistributedBatchSampler(dev_dataset,
     #                                       batch_size=config.batch_size,
     #                                       shuffle=False,
     #                                       drop_last=False)
     print("samplers done!")

     train_dataloader = DataLoader(

@@ -102,16 +106,43 @@ def train_sp(args, config):
         num_workers=config.num_workers)
     dev_dataloader = DataLoader(
         dev_dataset,
-        batch_sampler=dev_sampler,
+        shuffle=False,
+        drop_last=False,
+        batch_size=config.batch_size,
         collate_fn=collate_baker_examples,
         num_workers=config.num_workers)
     print("dataloaders done!")

-    # batch = collate_baker_examples([train_dataset[i] for i in range(10)])
+    # batch = collate_baker_examples([dev_dataset[i] for i in range(10)])
     # # batch = collate_baker_examples([dev_dataset[i] for i in range(10)])
     # import pdb; pdb.set_trace()
     model = SpeedySpeech(**config["model"])
-    print(model)
     if world_size > 1:
         model = DataParallel(model)  # TODO, do not use vocab size from config
+    # print(model)
     print("model done!")
     optimizer = Adam(
         0.001,
         parameters=model.parameters(),
         grad_clip=nn.ClipGradByGlobalNorm(5.0))
     print("optimizer done!")

+    updater = SpeedySpeechUpdater(
+        model=model, optimizer=optimizer, dataloader=train_dataloader)
+
+    output_dir = Path(args.output_dir)
+    trainer = Trainer(updater, (config.max_epoch, 'epoch'), output_dir)
+
+    evaluator = SpeedySpeechEvaluator(model, dev_dataloader)
+
+    if dist.get_rank() == 0:
+        trainer.extend(evaluator, trigger=(1, "epoch"))
+        writer = LogWriter(str(output_dir))
+        trainer.extend(VisualDL(writer), trigger=(1, "iteration"))
+        trainer.extend(
+            Snapshot(max_size=config.num_snapshots), trigger=(1, 'epoch'))
+    print(trainer.extensions)
+    trainer.run()


 def main():

@@ -189,7 +189,7 @@ class SpeedySpeech(nn.Layer):

         # decode
         # remove positional encoding here
-        _, t_dec, feature_size = encodings.shpae
+        _, t_dec, feature_size = encodings.shape
         encodings += sinusoid_position_encoding(t_dec, feature_size)
         decoded = self.decoder(encodings)
         return decoded, pred_durations

@@ -211,4 +211,4 @@ class SpeedySpeech(nn.Layer):
         t_dec, feature_size = shape[1], shape[2]
         encodings += sinusoid_position_encoding(t_dec, feature_size)
         decoded = self.decoder(encodings)
-        return decoded, pred_durations
+        return decoded[0], pred_durations[0]

@@ -123,8 +123,6 @@ class Trainer(object):
         update = self.updater.update  # training step
         stop_trigger = self.stop_trigger

-        print(self.updater.state)
-
         # display only one progress bar
         max_iteration = None
         if isinstance(stop_trigger, LimitTrigger):

@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from copy import deepcopy
+

 class IntervalTrigger(object):
     """A Predicate to do something every N cycle."""

@@ -23,9 +25,16 @@ class IntervalTrigger(object):
             raise ValueError("period should be a positive integer.")
         self.period = period
         self.unit = unit
+        self.last_index = None

     def __call__(self, trainer):
-        state = trainer.updater.state
-        index = getattr(state, self.unit)
-        fire = index % self.period == 0
+        if self.last_index is None:
+            last_index = getattr(trainer.updater.state, self.unit)
+            self.last_index = last_index
+
+        last_index = self.last_index
+        index = getattr(trainer.updater.state, self.unit)
+        fire = index // self.period != last_index // self.period
+
+        self.last_index = index
         return fire

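The IntervalTrigger rewrite above fixes a real bug: the old index % period == 0
test fires repeatedly while the index sits on a multiple, and it can miss the
boundary entirely if the trigger is only polled between multiples. Tracking the
last seen index and firing when index // period crosses into a new block is
robust to both. A standalone demo of the two predicates (simplified, no trainer
object):

    period = 3

    def fires_modulo(indices):
        return [i for i in indices if i % period == 0]

    def fires_crossing(indices):
        fired, last = [], indices[0]
        for i in indices[1:]:
            if i // period != last // period:
                fired.append(i)
            last = i
        return fired

    # polled at indices that skip over a multiple of `period`:
    print(fires_modulo([1, 2, 4, 5, 7]))    # [] -- never fires
    print(fires_crossing([1, 2, 4, 5, 7]))  # [4, 7] -- fires on each crossing
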
@@ -106,8 +106,8 @@ class StandardUpdater(UpdaterBase):
         self.update_core(batch)

         self.state.iteration += 1
-        if self.updaters_per_epoch is not None:
-            if self.state.iteration % self.updaters_per_epoch == 0:
+        if self.updates_per_epoch is not None:
+            if self.state.iteration % self.updates_per_epoch == 0:
                 self.state.epoch += 1

     def update_core(self, batch):

@@ -139,7 +139,7 @@ class StandardUpdater(UpdaterBase):
             self.optimizer.update()

     @property
-    def updaters_per_epoch(self):
+    def updates_per_epoch(self):
         """Number of updates per epoch, determined by the length of the
         dataloader."""
         length_of_dataloader = None