update save & load for Deep Voice 3, WaveNet and ClariNet; remove the concept of epoch in training

chenfeiyu 2020-03-24 08:53:40 +00:00
parent 64790853e5
commit 776743530a
11 changed files with 272 additions and 233 deletions

View File

@@ -28,7 +28,7 @@ Train the model using train.py, follow the usage displayed by `python train.py -
```text
usage: train.py [-h] [--config CONFIG] [--device DEVICE] [--output OUTPUT]
[--data DATA] [--resume RESUME] [--wavenet WAVENET]
[--data DATA] [--checkpoint CHECKPOINT] [--wavenet WAVENET]
train a ClariNet model with LJspeech and a trained WaveNet model.
@@ -38,13 +38,13 @@ optional arguments:
--device DEVICE device to use.
--output OUTPUT path to save student.
--data DATA path of LJspeech dataset.
--resume RESUME checkpoint to load from.
--checkpoint CHECKPOINT checkpoint to load from.
--wavenet WAVENET wavenet checkpoint to use.
```
- `--config` is the configuration file to use. The provided configurations can be used directly. You can also change some values in the configuration file and train the model with a different config.
- `--data` is the path of the LJSpeech dataset, the extracted folder from the downloaded archive (the folder which contains metadata.txt).
- `--resume` is the path of the checkpoint. If it is provided, the model will load the checkpoint before training.
- `--checkpoint` is the path of the checkpoint. If it is provided, the model will load the checkpoint before training.
- `--output` is the directory to save results; all results are saved in this directory. The structure of the output directory is shown below.
```text
@@ -53,6 +53,8 @@ optional arguments:
└── log # tensorboard log
```
If `checkpoints` is not empty and the argument `--checkpoint` is not specified, the model will be resumed from the latest checkpoint at the beginning of training (see the sketch below).
- `--device` is the device (gpu id) to use for training. `-1` means CPU.
- `--wavenet` is the path of the WaveNet checkpoint to load. If you do not specify `--checkpoint`, then this must be provided.
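A minimal sketch of the resume logic described above, reconstructed from this commit's train.py; the `step-<n>` checkpoint naming is an assumption inferred from how the script parses the checkpoint basename:
```python
import os

from parakeet.utils import io  # save/load helpers introduced in this commit


def resolve_start_iteration(checkpoint_path, checkpoint_dir):
    """Sketch of train.py's resume order: an explicit --checkpoint wins;
    otherwise fall back to the latest checkpoint recorded in
    checkpoint_dir (io.load_latest_checkpoint returns 0 when the
    directory is empty, in which case --wavenet must be provided)."""
    if checkpoint_path is not None:
        # checkpoints are assumed to be named like "step-500000"; the
        # trailing integer is the number of iterations already done
        return int(os.path.basename(checkpoint_path).split("-")[-1])
    return io.load_latest_checkpoint(checkpoint_dir)
```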

View File

@@ -31,7 +31,7 @@ from parakeet.models.clarinet import STFT, Clarinet, ParallelWaveNet
from parakeet.data import TransformDataset, SliceDataset, RandomSampler, SequentialSampler, DataCargo
from parakeet.utils.layer_tools import summary, freeze
from utils import valid_model, eval_model, save_checkpoint, load_checkpoint, load_model
from utils import valid_model, eval_model, load_model
sys.path.append("../wavenet")
from data import LJSpeechMetaData, Transform, DataCollector

View File

@@ -30,14 +30,15 @@ from parakeet.models.wavenet import WaveNet, UpsampleNet
from parakeet.models.clarinet import STFT, Clarinet, ParallelWaveNet
from parakeet.data import TransformDataset, SliceDataset, RandomSampler, SequentialSampler, DataCargo
from parakeet.utils.layer_tools import summary, freeze
from parakeet.utils import io
from utils import make_output_tree, valid_model, save_checkpoint, load_checkpoint, load_wavenet
from utils import make_output_tree, valid_model, load_wavenet
sys.path.append("../wavenet")
from data import LJSpeechMetaData, Transform, DataCollector
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="train a clarinet model with LJspeech and a trained wavenet model."
description="train a ClariNet model with LJspeech and a trained WaveNet model."
)
parser.add_argument("--config", type=str, help="path of the config file.")
parser.add_argument(
@@ -48,13 +49,18 @@ if __name__ == "__main__":
default="experiment",
help="path to save student.")
parser.add_argument("--data", type=str, help="path of LJspeech dataset.")
parser.add_argument("--resume", type=str, help="checkpoint to load from.")
parser.add_argument(
"--checkpoint", type=str, help="checkpoint to load from.")
parser.add_argument(
"--wavenet", type=str, help="wavenet checkpoint to use.")
args = parser.parse_args()
with open(args.config, 'rt') as f:
config = ruamel.yaml.safe_load(f)
print("Command Line args: ")
for k, v in vars(args).items():
print("{}: {}".format(k, v))
ljspeech_meta = LJSpeechMetaData(args.data)
data_config = config["data"]
@@ -154,12 +160,38 @@ if __name__ == "__main__":
clipper = fluid.dygraph_grad_clip.GradClipByGlobalNorm(
gradiant_max_norm)
assert args.wavenet or args.resume, "you should load from a trained wavenet or resume training; training without a trained wavenet is not recommended."
if args.wavenet:
# train
max_iterations = train_config["max_iterations"]
checkpoint_interval = train_config["checkpoint_interval"]
eval_interval = train_config["eval_interval"]
checkpoint_dir = os.path.join(args.output, "checkpoints")
state_dir = os.path.join(args.output, "states")
log_dir = os.path.join(args.output, "log")
writer = SummaryWriter(log_dir)
# load wavenet/checkpoint, determine iterations done
if args.checkpoint is not None:
iteration = int(os.path.basename(args.checkpoint).split('-')[-1])
else:
iteration = io.load_latest_checkpoint(checkpoint_dir)
if iteration == 0 and args.wavenet is None:
raise Exception(
"you should load from a trained wavenet or resume training; training without a trained wavenet is not recommended."
)
if args.wavenet is not None and iteration > 0:
if args.checkpoint is None:
print("Resume training, --wavenet ignored")
else:
print("--checkpoint provided, --wavenet ignored")
if args.wavenet is not None and iteration == 0:
load_wavenet(model, args.wavenet)
if args.resume:
load_checkpoint(model, optim, args.resume)
# this may overwrite the wavenet weights loaded above
io.load_parameters(
checkpoint_dir, 0, model, optim, file_path=args.checkpoint)
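From its call sites in this commit, `io.load_parameters(checkpoint_dir, rank, model, optim, file_path=...)` appears to restore the model's (and, when available, the optimizer's) state either from an explicit file or from the latest record in `checkpoint_dir`. A rough sketch of that assumed contract, modeled on the manual loading code this commit removes from the WaveNet trainer:
```python
import os

import paddle.fluid.dygraph as dg
from parakeet.utils import io


def load_parameters_sketch(checkpoint_dir, rank, model, optim=None,
                           file_path=None):
    """Assumed contract of io.load_parameters; `rank` is presumably the
    process rank for distributed training and is unused in this sketch."""
    if file_path is None:
        iteration = io.load_latest_checkpoint(checkpoint_dir)
        if iteration == 0:
            return  # nothing saved yet: keep the freshly initialized weights
        file_path = os.path.join(checkpoint_dir, "step-{}".format(iteration))
    # load_dygraph reads <file_path>.pdparams and, if present, .pdopt
    model_dict, optim_dict = dg.load_dygraph(file_path)
    model.set_dict(model_dict)
    if optim is not None and optim_dict is not None:
        optim.set_dict(optim_dict)
```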
# loader
train_loader = fluid.io.DataLoader.from_generator(
@@ -170,21 +202,16 @@ if __name__ == "__main__":
capacity=10, return_list=True)
valid_loader.set_batch_generator(valid_cargo, place)
# train
max_iterations = train_config["max_iterations"]
checkpoint_interval = train_config["checkpoint_interval"]
eval_interval = train_config["eval_interval"]
checkpoint_dir = os.path.join(args.output, "checkpoints")
state_dir = os.path.join(args.output, "states")
log_dir = os.path.join(args.output, "log")
writer = SummaryWriter(log_dir)
# training loop
global_step = 1
global_epoch = 1
global_step = iteration + 1
iterator = iter(tqdm(train_loader))
while global_step < max_iterations:
epoch_loss = 0.
for j, batch in tqdm(enumerate(train_loader), desc="[train]"):
try:
batch = next(iterator)
except StopIteration as e:
iterator = iter(tqdm(train_loader))
batch = next(iterator)
audios, mels, audio_starts = batch
model.train()
loss_dict = model(
@@ -200,7 +227,6 @@ if __name__ == "__main__":
l = loss_dict["loss"]
step_loss = l.numpy()[0]
print("[train] loss: {:<8.6f}".format(step_loss))
epoch_loss += step_loss
l.backward()
optim.minimize(l, grad_clip=clipper)
@@ -211,11 +237,8 @@ if __name__ == "__main__":
valid_model(model, valid_loader, state_dir, global_step,
sample_rate)
if global_step % checkpoint_interval == 0:
save_checkpoint(model, optim, checkpoint_dir, global_step)
io.save_latest_parameters(checkpoint_dir, global_step, model,
optim)
io.save_latest_checkpoint(checkpoint_dir, global_step)
global_step += 1
# epoch loss
average_loss = epoch_loss / j
writer.add_scalar("average_loss", average_loss, global_epoch)
global_epoch += 1
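With epochs gone, the training loop above keeps a single iterator alive and rebuilds it on `StopIteration`. The same pattern can be factored into a small generator; a sketch (not part of this commit):
```python
def cycle(loader):
    """Yield batches forever, restarting the loader whenever it is
    exhausted, so the training loop only has to count iterations."""
    while True:
        for batch in loader:
            yield batch


# usage sketch:
#   batches = cycle(train_loader)
#   while global_step < max_iterations:
#       audios, mels, audio_starts = next(batches)
#       ...
```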

View File

@@ -35,26 +35,23 @@ The model consists of an encoder, a decoder and a converter (and a speaker embed
Train the model using train.py; follow the usage displayed by `python train.py --help`.
```text
usage: train.py [-h] [-c CONFIG] [-s DATA] [-r RESUME] [-o OUTPUT] [-g DEVICE]
usage: train.py [-h] [-c CONFIG] [-s DATA] [--checkpoint CHECKPOINT]
[-o OUTPUT] [-g DEVICE]
Train a Deep Voice 3 model with LJSpeech dataset.
optional arguments:
-h, --help show this help message and exit
-c CONFIG, --config CONFIG
experiment config
-c CONFIG, --config CONFIG experiment config
-s DATA, --data DATA The path of the LJSpeech dataset.
-r RESUME, --resume RESUME
checkpoint to load
-o OUTPUT, --output OUTPUT
The directory to save result.
-g DEVICE, --device DEVICE
device to use
--checkpoint CHECKPOINT checkpoint to load
-o OUTPUT, --output OUTPUT The directory to save result.
-g DEVICE, --device DEVICE device to use
```
- `--config` is the configuration file to use. The provided `ljspeech.yaml` can be used directly. You can also change some values in the configuration file and train the model with a different config.
- `--data` is the path of the LJSpeech dataset, the extracted folder from the downloaded archive (the folder which contains metadata.txt).
- `--resume` is the path of the checkpoint. If it is provided, the model will load the checkpoint before training.
- `--checkpoint` is the path of the checkpoint. If it is provided, the model will load the checkpoint before training.
- `--output` is the directory to save results, all results are saved in this directory. The structure of the output directory is shown below.
```text
@@ -67,6 +64,8 @@ optional arguments:
└── waveform # waveform (.wav files)
```
If `checkpoints` is not empty and the argument `--checkpoint` is not specified, the model will be resumed from the latest checkpoint at the beginning of training.
- `--device` is the device (gpu id) to use for training. `-1` means CPU.
Example script:

View File

@@ -83,7 +83,7 @@ lr_scheduler:
train:
batch_size: 16
epochs: 2000
max_iteration: 2000000
snap_interval: 1000
eval_interval: 10000

View File

@@ -25,8 +25,9 @@ import paddle.fluid.dygraph as dg
from tensorboardX import SummaryWriter
from parakeet.g2p import en
from parakeet.utils.layer_tools import summary
from parakeet.modules.weight_norm import WeightNormWrapper
from parakeet.utils.layer_tools import summary
from parakeet.utils.io import load_parameters
from utils import make_model, eval_model, plot_alignment
@@ -44,6 +45,10 @@ if __name__ == "__main__":
with open(args.config, 'rt') as f:
config = ruamel.yaml.safe_load(f)
print("Command Line Args: ")
for k, v in vars(args).items():
print("{}: {}".format(k, v))
if args.device == -1:
place = fluid.CPUPlace()
else:

View File

@@ -17,6 +17,8 @@ import os
import argparse
import ruamel.yaml
import numpy as np
import matplotlib
matplotlib.use("agg")
from matplotlib import cm
import matplotlib.pyplot as plt
import tqdm
@@ -35,13 +37,14 @@ from parakeet.data import DataCargo, PartialyRandomizedSimilarTimeLengthSampler,
from parakeet.models.deepvoice3 import Encoder, Decoder, Converter, DeepVoice3, ConvSpec
from parakeet.models.deepvoice3.loss import TTSLoss
from parakeet.utils.layer_tools import summary
from parakeet.utils import io
from data import LJSpeechMetaData, DataCollector, Transform
from utils import make_model, eval_model, save_state, make_output_tree, plot_alignment
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Train a deepvoice 3 model with LJSpeech dataset.")
description="Train a Deep Voice 3 model with LJSpeech dataset.")
parser.add_argument("-c", "--config", type=str, help="experimrnt config")
parser.add_argument(
"-s",
@@ -49,7 +52,7 @@ if __name__ == "__main__":
type=str,
default="/workspace/datasets/LJSpeech-1.1/",
help="The path of the LJSpeech dataset.")
parser.add_argument("-r", "--resume", type=str, help="checkpoint to load")
parser.add_argument("--checkpoint", type=str, help="checkpoint to load")
parser.add_argument(
"-o",
"--output",
@@ -62,6 +65,10 @@ if __name__ == "__main__":
with open(args.config, 'rt') as f:
config = ruamel.yaml.safe_load(f)
print("Command Line Args: ")
for k, v in vars(args).items():
print("{}: {}".format(k, v))
# =========================dataset=========================
# construct meta data
data_root = args.data
@@ -151,6 +158,7 @@ if __name__ == "__main__":
query_position_rate, key_position_rate, window_backward,
window_ahead, key_projection, value_projection, downsample_factor,
linear_dim, use_decoder_states, converter_channels, dropout)
summary(dv3)
# =========================loss=========================
loss_config = config["loss"]
@@ -195,7 +203,6 @@ if __name__ == "__main__":
n_iter = synthesis_config["n_iter"]
# =========================link(dataloader, paddle)=========================
# CAUTION: it does not return a DataLoader
loader = fluid.io.DataLoader.from_generator(
capacity=10, return_list=True)
loader.set_batch_generator(ljspeech_loader, places=place)
@@ -208,24 +215,29 @@ if __name__ == "__main__":
make_output_tree(output_dir)
writer = SummaryWriter(logdir=log_dir)
# load model parameters
resume_path = args.resume
if resume_path is not None:
state, _ = dg.load_dygraph(args.resume)
dv3.set_dict(state)
# load parameters and optimizer, and update the number of iterations done so far
io.load_parameters(ckpt_dir, 0, dv3, optim, file_path=args.checkpoint)
if args.checkpoint is not None:
iteration = int(os.path.basename(args.checkpoint).split("-")[-1])
else:
iteration = io.load_latest_checkpoint(ckpt_dir)
# =========================train=========================
epoch = train_config["epochs"]
max_iter = train_config["max_iteration"]
snap_interval = train_config["snap_interval"]
save_interval = train_config["save_interval"]
eval_interval = train_config["eval_interval"]
global_step = 1
global_step = iteration + 1
iterator = iter(tqdm.tqdm(loader))
while global_step <= max_iter:
try:
batch = next(iterator)
except StopIteration as e:
iterator = iter(tqdm.tqdm(loader))
batch = next(iterator)
for j in range(1, 1 + epoch):
epoch_loss = 0.
for i, batch in tqdm.tqdm(enumerate(loader, 1)):
dv3.train() # CAUTION: don't forget to switch to train
dv3.train()
(text_sequences, text_lengths, text_positions, mel_specs,
lin_specs, frames, decoder_positions, done_flags) = batch
downsampled_mel_specs = F.strided_slice(
@@ -238,26 +250,25 @@ if __name__ == "__main__":
text_sequences, text_positions, text_lengths, None,
downsampled_mel_specs, decoder_positions)
losses = criterion(mel_outputs, linear_outputs, done,
alignments, downsampled_mel_specs,
lin_specs, done_flags, text_lengths, frames)
losses = criterion(mel_outputs, linear_outputs, done, alignments,
downsampled_mel_specs, lin_specs, done_flags,
text_lengths, frames)
l = losses["loss"]
l.backward()
# record learning rate before updating
writer.add_scalar("learning_rate",
optim._learning_rate.step().numpy(),
global_step)
optim._learning_rate.step().numpy(), global_step)
optim.minimize(l, grad_clip=gradient_clipper)
optim.clear_gradients()
# ==================all kinds of tedious things=================
# record step loss into tensorboard
epoch_loss += l.numpy()[0]
step_loss = {k: v.numpy()[0] for k, v in losses.items()}
tqdm.tqdm.write("global_step: {}\tloss: {}".format(
global_step, step_loss["loss"]))
for k, v in step_loss.items():
writer.add_scalar(k, v, global_step)
# TODO: clean code
# save training state for the first sentence in the batch
if global_step % snap_interval == 0:
save_state(
@@ -290,9 +301,9 @@ if __name__ == "__main__":
]
for idx, sent in enumerate(sentences):
wav, attn = eval_model(
dv3, sent, replace_pronounciation_prob,
min_level_db, ref_level_db, power, n_iter,
win_length, hop_length, preemphasis)
dv3, sent, replace_pronounciation_prob, min_level_db,
ref_level_db, power, n_iter, win_length, hop_length,
preemphasis)
wav_path = os.path.join(
state_dir, "waveform",
"eval_sample_{:09d}.wav".format(global_step))
@@ -314,16 +325,7 @@ if __name__ == "__main__":
# save checkpoint
if global_step % save_interval == 0:
dg.save_dygraph(
dv3.state_dict(),
os.path.join(ckpt_dir,
"model_step_{}".format(global_step)))
dg.save_dygraph(
optim.state_dict(),
os.path.join(ckpt_dir,
"model_step_{}".format(global_step)))
io.save_latest_parameters(ckpt_dir, global_step, dv3, optim)
io.save_latest_checkpoint(ckpt_dir, global_step)
global_step += 1
# epoch report
writer.add_scalar("epoch_average_loss", epoch_loss / i, j)
epoch_loss = 0.

View File

@@ -28,7 +28,7 @@ Train the model using train.py. For help on usage, try `python train.py --help`.
```text
usage: train.py [-h] [--data DATA] [--config CONFIG] [--output OUTPUT]
[--device DEVICE] [--resume RESUME]
[--device DEVICE] [--checkpoint CHECKPOINT]
Train a WaveNet model with LJSpeech.
@@ -38,12 +38,12 @@ optional arguments:
--config CONFIG path of the config file.
--output OUTPUT path to save results.
--device DEVICE device to use.
--resume RESUME checkpoint to resume from.
--checkpoint CHECKPOINT checkpoint to resume from.
```
- `--config` is the configuration file to use. The provided configurations can be used directly. You can also change some values in the configuration file and train the model with a different config.
- `--data` is the path of the LJSpeech dataset, the extracted folder from the downloaded archive (the folder which contains metadata.txt).
- `--resume` is the path of the checkpoint. If it is provided, the model will load the checkpoint before training.
- `--checkpoint` is the path of the checkpoint. If it is provided, the model will load the checkpoint before training.
- `--output` is the directory to save results; all results are saved in this directory. The structure of the output directory is shown below.
```text
@@ -51,6 +51,8 @@ optional arguments:
└── log # tensorboard log
```
If `checkpoints` is not empty and the argument `--checkpoint` is not specified, the model will be resumed from the latest checkpoint at the beginning of training.
- `--device` is the device (gpu id) to use for training. `-1` means CPU.
Example script:

View File

@@ -27,7 +27,7 @@ from parakeet.models.wavenet import UpsampleNet, WaveNet, ConditionalWavenet
from parakeet.utils.layer_tools import summary
from data import LJSpeechMetaData, Transform, DataCollector
from utils import make_output_tree, valid_model, eval_model, save_checkpoint
from utils import make_output_tree, valid_model, eval_model
if __name__ == "__main__":
parser = argparse.ArgumentParser(
@@ -87,7 +87,8 @@ if __name__ == "__main__":
batch_size=1,
sampler=SequentialSampler(ljspeech_valid))
make_output_tree(args.output)
if not os.path.exists(args.output):
os.makedirs(args.output)
if args.device == -1:
place = fluid.CPUPlace()

View File

@@ -16,7 +16,7 @@ from __future__ import division
import os
import ruamel.yaml
import argparse
from tqdm import tqdm
import tqdm
from tensorboardX import SummaryWriter
from paddle import fluid
import paddle.fluid.dygraph as dg
@@ -24,13 +24,14 @@ import paddle.fluid.dygraph as dg
from parakeet.data import SliceDataset, TransformDataset, DataCargo, SequentialSampler, RandomSampler
from parakeet.models.wavenet import UpsampleNet, WaveNet, ConditionalWavenet
from parakeet.utils.layer_tools import summary
from parakeet.utils import io
from data import LJSpeechMetaData, Transform, DataCollector
from utils import make_output_tree, valid_model, save_checkpoint
from utils import make_output_tree, valid_model
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Train a wavenet model with LJSpeech.")
description="Train a WaveNet model with LJSpeech.")
parser.add_argument(
"--data", type=str, help="path of the LJspeech dataset.")
parser.add_argument("--config", type=str, help="path of the config file.")
@@ -42,12 +43,16 @@ if __name__ == "__main__":
parser.add_argument(
"--device", type=int, default=-1, help="device to use.")
parser.add_argument(
"--resume", type=str, help="checkpoint to resume from.")
"--checkpoint", type=str, help="checkpoint to resume from.")
args = parser.parse_args()
with open(args.config, 'rt') as f:
config = ruamel.yaml.safe_load(f)
print("Command Line Args: ")
for k, v in vars(args).items():
print("{}: {}".format(k, v))
ljspeech_meta = LJSpeechMetaData(args.data)
data_config = config["data"]
@@ -126,14 +131,6 @@ if __name__ == "__main__":
clipper = fluid.dygraph_grad_clip.GradClipByGlobalNorm(
gradiant_max_norm)
if args.resume:
model_dict, optim_dict = dg.load_dygraph(args.resume)
print("Loading from {}.pdparams".format(args.resume))
model.set_dict(model_dict)
if optim_dict:
optim.set_dict(optim_dict)
print("Loading from {}.pdopt".format(args.resume))
train_loader = fluid.io.DataLoader.from_generator(
capacity=10, return_list=True)
train_loader.set_batch_generator(train_cargo, place)
@@ -150,10 +147,24 @@ if __name__ == "__main__":
log_dir = os.path.join(args.output, "log")
writer = SummaryWriter(log_dir)
global_step = 1
# load parameters and optimizer, and update the number of iterations done so far
io.load_parameters(
checkpoint_dir, 0, model, optim, file_path=args.checkpoint)
if args.checkpoint is not None:
iteration = int(os.path.basename(args.checkpoint).split("-")[-1])
else:
iteration = io.load_latest_checkpoint(checkpoint_dir)
global_step = iteration + 1
iterator = iter(tqdm.tqdm(train_loader))
while global_step <= max_iterations:
epoch_loss = 0.
for i, batch in tqdm(enumerate(train_loader)):
print(global_step)
try:
batch = next(iterator)
except StopIteration as e:
iterator = iter(tqdm.tqdm(train_loader))
batch = next(iterator)
audio_clips, mel_specs, audio_starts = batch
model.train()
@@ -162,21 +173,22 @@ if __name__ == "__main__":
loss_var.backward()
loss_np = loss_var.numpy()
epoch_loss += loss_np[0]
writer.add_scalar("loss", loss_np[0], global_step)
writer.add_scalar("learning_rate",
optim._learning_rate.step().numpy()[0],
global_step)
optim.minimize(loss_var, grad_clip=clipper)
optim.clear_gradients()
print("loss: {:<8.6f}".format(loss_np[0]))
print("global_step: {}\tloss: {:<8.6f}".format(global_step,
loss_np[0]))
if global_step % snap_interval == 0:
valid_model(model, valid_loader, writer, global_step,
sample_rate)
if global_step % checkpoint_interval == 0:
save_checkpoint(model, optim, checkpoint_dir, global_step)
io.save_latest_parameters(checkpoint_dir, global_step, model,
optim)
io.save_latest_checkpoint(checkpoint_dir, global_step)
global_step += 1

View File

@@ -59,10 +59,3 @@ def eval_model(model, valid_loader, output_dir, sample_rate):
wav_np = wav_var.numpy()[0]
sf.write(path, wav_np, samplerate=sample_rate)
print("generated {}".format(path))
def save_checkpoint(model, optim, checkpoint_dir, global_step):
checkpoint_path = os.path.join(checkpoint_dir,
"step_{:09d}".format(global_step))
dg.save_dygraph(model.state_dict(), checkpoint_path)
dg.save_dygraph(optim.state_dict(), checkpoint_path)
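The removed helper above saved the model and optimizer to one base path, which is roughly what the new `io.save_latest_parameters` presumably wraps; `io.save_latest_checkpoint` additionally records the step so a later run can resume without `--checkpoint`. A sketch of that assumed behavior (the `step-<n>` naming and the plain-text record file are assumptions, not the library's documented layout):
```python
import os

import paddle.fluid.dygraph as dg


def save_parameters_sketch(checkpoint_dir, iteration, model, optim=None):
    """Assumed behavior of io.save_latest_parameters: Paddle appends
    .pdparams / .pdopt to the same base path."""
    path = os.path.join(checkpoint_dir, "step-{}".format(iteration))
    dg.save_dygraph(model.state_dict(), path)       # step-<n>.pdparams
    if optim is not None:
        dg.save_dygraph(optim.state_dict(), path)   # step-<n>.pdopt


def save_latest_checkpoint_sketch(checkpoint_dir, iteration):
    """Assumed behavior of io.save_latest_checkpoint: keep a plain-text
    record of the most recent step for io.load_latest_checkpoint to read."""
    with open(os.path.join(checkpoint_dir, "checkpoint"), "wt") as f:
        f.write("step-{}".format(iteration))
```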