update save & load for deep voicde 3, wavenet and clarinet, remove the concept of epoch in training
This commit is contained in:
parent
64790853e5
commit
776743530a
|
@ -28,7 +28,7 @@ Train the model using train.py, follow the usage displayed by `python train.py -
|
|||
|
||||
```text
|
||||
usage: train.py [-h] [--config CONFIG] [--device DEVICE] [--output OUTPUT]
|
||||
[--data DATA] [--resume RESUME] [--wavenet WAVENET]
|
||||
[--data DATA] [--checkpoint CHECKPOINT] [--wavenet WAVENET]
|
||||
|
||||
train a ClariNet model with LJspeech and a trained WaveNet model.
|
||||
|
||||
|
@ -38,13 +38,13 @@ optional arguments:
|
|||
--device DEVICE device to use.
|
||||
--output OUTPUT path to save student.
|
||||
--data DATA path of LJspeech dataset.
|
||||
--resume RESUME checkpoint to load from.
|
||||
--checkpoint CHECKPOINT checkpoint to load from.
|
||||
--wavenet WAVENET wavenet checkpoint to use.
|
||||
```
|
||||
|
||||
- `--config` is the configuration file to use. The provided configurations can be used directly. And you can change some values in the configuration file and train the model with a different config.
|
||||
- `--data` is the path of the LJSpeech dataset, the extracted folder from the downloaded archive (the folder which contains metadata.txt).
|
||||
- `--resume` is the path of the checkpoint. If it is provided, the model would load the checkpoint before trainig.
|
||||
- `--checkpoint` is the path of the checkpoint. If it is provided, the model would load the checkpoint before trainig.
|
||||
- `--output` is the directory to save results, all result are saved in this directory. The structure of the output directory is shown below.
|
||||
|
||||
```text
|
||||
|
@ -53,6 +53,8 @@ optional arguments:
|
|||
└── log # tensorboard log
|
||||
```
|
||||
|
||||
If `checkpoints` is not empty and argument `--checkpoint` is not specified, the model will be resumed from the latest checkpoint at the beginning of training.
|
||||
|
||||
- `--device` is the device (gpu id) to use for training. `-1` means CPU.
|
||||
- `--wavenet` is the path of the wavenet checkpoint to load. If you do not specify `--resume`, then this must be provided.
|
||||
|
||||
|
|
|
@ -31,7 +31,7 @@ from parakeet.models.clarinet import STFT, Clarinet, ParallelWaveNet
|
|||
from parakeet.data import TransformDataset, SliceDataset, RandomSampler, SequentialSampler, DataCargo
|
||||
from parakeet.utils.layer_tools import summary, freeze
|
||||
|
||||
from utils import valid_model, eval_model, save_checkpoint, load_checkpoint, load_model
|
||||
from utils import valid_model, eval_model, load_model
|
||||
sys.path.append("../wavenet")
|
||||
from data import LJSpeechMetaData, Transform, DataCollector
|
||||
|
||||
|
|
|
@ -30,14 +30,15 @@ from parakeet.models.wavenet import WaveNet, UpsampleNet
|
|||
from parakeet.models.clarinet import STFT, Clarinet, ParallelWaveNet
|
||||
from parakeet.data import TransformDataset, SliceDataset, RandomSampler, SequentialSampler, DataCargo
|
||||
from parakeet.utils.layer_tools import summary, freeze
|
||||
from parakeet.utils import io
|
||||
|
||||
from utils import make_output_tree, valid_model, save_checkpoint, load_checkpoint, load_wavenet
|
||||
from utils import make_output_tree, valid_model, load_wavenet
|
||||
sys.path.append("../wavenet")
|
||||
from data import LJSpeechMetaData, Transform, DataCollector
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(
|
||||
description="train a clarinet model with LJspeech and a trained wavenet model."
|
||||
description="train a ClariNet model with LJspeech and a trained WaveNet model."
|
||||
)
|
||||
parser.add_argument("--config", type=str, help="path of the config file.")
|
||||
parser.add_argument(
|
||||
|
@ -48,13 +49,18 @@ if __name__ == "__main__":
|
|||
default="experiment",
|
||||
help="path to save student.")
|
||||
parser.add_argument("--data", type=str, help="path of LJspeech dataset.")
|
||||
parser.add_argument("--resume", type=str, help="checkpoint to load from.")
|
||||
parser.add_argument(
|
||||
"--checkpoint", type=str, help="checkpoint to load from.")
|
||||
parser.add_argument(
|
||||
"--wavenet", type=str, help="wavenet checkpoint to use.")
|
||||
args = parser.parse_args()
|
||||
with open(args.config, 'rt') as f:
|
||||
config = ruamel.yaml.safe_load(f)
|
||||
|
||||
print("Command Line args: ")
|
||||
for k, v in vars(args).items():
|
||||
print("{}: {}".format(k, v))
|
||||
|
||||
ljspeech_meta = LJSpeechMetaData(args.data)
|
||||
|
||||
data_config = config["data"]
|
||||
|
@ -154,12 +160,38 @@ if __name__ == "__main__":
|
|||
clipper = fluid.dygraph_grad_clip.GradClipByGlobalNorm(
|
||||
gradiant_max_norm)
|
||||
|
||||
assert args.wavenet or args.resume, "you should load from a trained wavenet or resume training; training without a trained wavenet is not recommended."
|
||||
if args.wavenet:
|
||||
# train
|
||||
max_iterations = train_config["max_iterations"]
|
||||
checkpoint_interval = train_config["checkpoint_interval"]
|
||||
eval_interval = train_config["eval_interval"]
|
||||
checkpoint_dir = os.path.join(args.output, "checkpoints")
|
||||
state_dir = os.path.join(args.output, "states")
|
||||
log_dir = os.path.join(args.output, "log")
|
||||
writer = SummaryWriter(log_dir)
|
||||
|
||||
# load wavenet/checkpoint, determine iterations done
|
||||
if args.checkpoint is not None:
|
||||
iteration = int(os.path.basename(args.checkpoint).split('-')[-1])
|
||||
else:
|
||||
iteration = io.load_latest_checkpoint(checkpoint_dir)
|
||||
|
||||
if iteration == 0 and args.wavenet is None:
|
||||
raise Exception(
|
||||
"you should load from a trained wavenet or resume training; training without a trained wavenet is not recommended."
|
||||
)
|
||||
|
||||
if args.wavenet is not None and iteration > 0:
|
||||
if args.checkpoint is None:
|
||||
print("Resume training, --wavenet ignored")
|
||||
else:
|
||||
print("--checkpoint provided, --wavenet ignored")
|
||||
|
||||
if args.wavenet is not None and iteration == 0:
|
||||
load_wavenet(model, args.wavenet)
|
||||
|
||||
if args.resume:
|
||||
load_checkpoint(model, optim, args.resume)
|
||||
# it may overwrite the wavenet loaded
|
||||
io.load_parameters(
|
||||
checkpoint_dir, 0, model, optim, file_path=args.checkpoint)
|
||||
|
||||
# loader
|
||||
train_loader = fluid.io.DataLoader.from_generator(
|
||||
|
@ -170,21 +202,16 @@ if __name__ == "__main__":
|
|||
capacity=10, return_list=True)
|
||||
valid_loader.set_batch_generator(valid_cargo, place)
|
||||
|
||||
# train
|
||||
max_iterations = train_config["max_iterations"]
|
||||
checkpoint_interval = train_config["checkpoint_interval"]
|
||||
eval_interval = train_config["eval_interval"]
|
||||
checkpoint_dir = os.path.join(args.output, "checkpoints")
|
||||
state_dir = os.path.join(args.output, "states")
|
||||
log_dir = os.path.join(args.output, "log")
|
||||
writer = SummaryWriter(log_dir)
|
||||
|
||||
# training loop
|
||||
global_step = 1
|
||||
global_epoch = 1
|
||||
global_step = iteration + 1
|
||||
iterator = iter(tqdm(train_loader))
|
||||
while global_step < max_iterations:
|
||||
epoch_loss = 0.
|
||||
for j, batch in tqdm(enumerate(train_loader), desc="[train]"):
|
||||
try:
|
||||
batch = next(iterator)
|
||||
except StopIteration as e:
|
||||
iterator = iter(tqdm(train_loader))
|
||||
batch = next(iterator)
|
||||
|
||||
audios, mels, audio_starts = batch
|
||||
model.train()
|
||||
loss_dict = model(
|
||||
|
@ -200,7 +227,6 @@ if __name__ == "__main__":
|
|||
l = loss_dict["loss"]
|
||||
step_loss = l.numpy()[0]
|
||||
print("[train] loss: {:<8.6f}".format(step_loss))
|
||||
epoch_loss += step_loss
|
||||
|
||||
l.backward()
|
||||
optim.minimize(l, grad_clip=clipper)
|
||||
|
@ -211,11 +237,8 @@ if __name__ == "__main__":
|
|||
valid_model(model, valid_loader, state_dir, global_step,
|
||||
sample_rate)
|
||||
if global_step % checkpoint_interval == 0:
|
||||
save_checkpoint(model, optim, checkpoint_dir, global_step)
|
||||
io.save_latest_parameters(checkpoint_dir, global_step, model,
|
||||
optim)
|
||||
io.save_latest_checkpoint(checkpoint_dir, global_step)
|
||||
|
||||
global_step += 1
|
||||
|
||||
# epoch loss
|
||||
average_loss = epoch_loss / j
|
||||
writer.add_scalar("average_loss", average_loss, global_epoch)
|
||||
global_epoch += 1
|
||||
|
|
|
@ -35,26 +35,23 @@ The model consists of an encoder, a decoder and a converter (and a speaker embed
|
|||
Train the model using train.py, follow the usage displayed by `python train.py --help`.
|
||||
|
||||
```text
|
||||
usage: train.py [-h] [-c CONFIG] [-s DATA] [-r RESUME] [-o OUTPUT] [-g DEVICE]
|
||||
usage: train.py [-h] [-c CONFIG] [-s DATA] [--checkpoint CHECKPOINT]
|
||||
[-o OUTPUT] [-g DEVICE]
|
||||
|
||||
Train a Deep Voice 3 model with LJSpeech dataset.
|
||||
|
||||
optional arguments:
|
||||
-h, --help show this help message and exit
|
||||
-c CONFIG, --config CONFIG
|
||||
experimrnt config
|
||||
-c CONFIG, --config CONFIG experimrnt config
|
||||
-s DATA, --data DATA The path of the LJSpeech dataset.
|
||||
-r RESUME, --resume RESUME
|
||||
checkpoint to load
|
||||
-o OUTPUT, --output OUTPUT
|
||||
The directory to save result.
|
||||
-g DEVICE, --device DEVICE
|
||||
device to use
|
||||
--checkpoint CHECKPOINT checkpoint to load
|
||||
-o OUTPUT, --output OUTPUT The directory to save result.
|
||||
-g DEVICE, --device DEVICE device to use
|
||||
```
|
||||
|
||||
- `--config` is the configuration file to use. The provided `ljspeech.yaml` can be used directly. And you can change some values in the configuration file and train the model with a different config.
|
||||
- `--data` is the path of the LJSpeech dataset, the extracted folder from the downloaded archive (the folder which contains metadata.txt).
|
||||
- `--resume` is the path of the checkpoint. If it is provided, the model would load the checkpoint before trainig.
|
||||
- `--checkpoint` is the path of the checkpoint. If it is provided, the model would load the checkpoint before trainig.
|
||||
- `--output` is the directory to save results, all results are saved in this directory. The structure of the output directory is shown below.
|
||||
|
||||
```text
|
||||
|
@ -67,6 +64,8 @@ optional arguments:
|
|||
└── waveform # waveform (.wav files)
|
||||
```
|
||||
|
||||
If `checkpoints` is not empty and argument `--checkpoint` is not specified, the model will be resumed from the latest checkpoint at the beginning of training.
|
||||
|
||||
- `--device` is the device (gpu id) to use for training. `-1` means CPU.
|
||||
|
||||
Example script:
|
||||
|
|
|
@ -83,7 +83,7 @@ lr_scheduler:
|
|||
|
||||
train:
|
||||
batch_size: 16
|
||||
epochs: 2000
|
||||
max_iteration: 2000000
|
||||
|
||||
snap_interval: 1000
|
||||
eval_interval: 10000
|
||||
|
|
|
@ -25,8 +25,9 @@ import paddle.fluid.dygraph as dg
|
|||
from tensorboardX import SummaryWriter
|
||||
|
||||
from parakeet.g2p import en
|
||||
from parakeet.utils.layer_tools import summary
|
||||
from parakeet.modules.weight_norm import WeightNormWrapper
|
||||
from parakeet.utils.layer_tools import summary
|
||||
from parakeet.utils.io import load_parameters
|
||||
|
||||
from utils import make_model, eval_model, plot_alignment
|
||||
|
||||
|
@ -44,6 +45,10 @@ if __name__ == "__main__":
|
|||
with open(args.config, 'rt') as f:
|
||||
config = ruamel.yaml.safe_load(f)
|
||||
|
||||
print("Command Line Args: ")
|
||||
for k, v in vars(args).items():
|
||||
print("{}: {}".format(k, v))
|
||||
|
||||
if args.device == -1:
|
||||
place = fluid.CPUPlace()
|
||||
else:
|
||||
|
|
|
@ -17,6 +17,8 @@ import os
|
|||
import argparse
|
||||
import ruamel.yaml
|
||||
import numpy as np
|
||||
import matplotlib
|
||||
matplotlib.use("agg")
|
||||
from matplotlib import cm
|
||||
import matplotlib.pyplot as plt
|
||||
import tqdm
|
||||
|
@ -35,13 +37,14 @@ from parakeet.data import DataCargo, PartialyRandomizedSimilarTimeLengthSampler,
|
|||
from parakeet.models.deepvoice3 import Encoder, Decoder, Converter, DeepVoice3, ConvSpec
|
||||
from parakeet.models.deepvoice3.loss import TTSLoss
|
||||
from parakeet.utils.layer_tools import summary
|
||||
from parakeet.utils import io
|
||||
|
||||
from data import LJSpeechMetaData, DataCollector, Transform
|
||||
from utils import make_model, eval_model, save_state, make_output_tree, plot_alignment
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Train a deepvoice 3 model with LJSpeech dataset.")
|
||||
description="Train a Deep Voice 3 model with LJSpeech dataset.")
|
||||
parser.add_argument("-c", "--config", type=str, help="experimrnt config")
|
||||
parser.add_argument(
|
||||
"-s",
|
||||
|
@ -49,7 +52,7 @@ if __name__ == "__main__":
|
|||
type=str,
|
||||
default="/workspace/datasets/LJSpeech-1.1/",
|
||||
help="The path of the LJSpeech dataset.")
|
||||
parser.add_argument("-r", "--resume", type=str, help="checkpoint to load")
|
||||
parser.add_argument("--checkpoint", type=str, help="checkpoint to load")
|
||||
parser.add_argument(
|
||||
"-o",
|
||||
"--output",
|
||||
|
@ -62,6 +65,10 @@ if __name__ == "__main__":
|
|||
with open(args.config, 'rt') as f:
|
||||
config = ruamel.yaml.safe_load(f)
|
||||
|
||||
print("Command Line Args: ")
|
||||
for k, v in vars(args).items():
|
||||
print("{}: {}".format(k, v))
|
||||
|
||||
# =========================dataset=========================
|
||||
# construct meta data
|
||||
data_root = args.data
|
||||
|
@ -151,6 +158,7 @@ if __name__ == "__main__":
|
|||
query_position_rate, key_position_rate, window_backward,
|
||||
window_ahead, key_projection, value_projection, downsample_factor,
|
||||
linear_dim, use_decoder_states, converter_channels, dropout)
|
||||
summary(dv3)
|
||||
|
||||
# =========================loss=========================
|
||||
loss_config = config["loss"]
|
||||
|
@ -195,7 +203,6 @@ if __name__ == "__main__":
|
|||
n_iter = synthesis_config["n_iter"]
|
||||
|
||||
# =========================link(dataloader, paddle)=========================
|
||||
# CAUTION: it does not return a DataLoader
|
||||
loader = fluid.io.DataLoader.from_generator(
|
||||
capacity=10, return_list=True)
|
||||
loader.set_batch_generator(ljspeech_loader, places=place)
|
||||
|
@ -208,24 +215,29 @@ if __name__ == "__main__":
|
|||
make_output_tree(output_dir)
|
||||
writer = SummaryWriter(logdir=log_dir)
|
||||
|
||||
# load model parameters
|
||||
resume_path = args.resume
|
||||
if resume_path is not None:
|
||||
state, _ = dg.load_dygraph(args.resume)
|
||||
dv3.set_dict(state)
|
||||
# load parameters and optimizer, and opdate iterations done sofar
|
||||
io.load_parameters(ckpt_dir, 0, dv3, optim, file_path=args.checkpoint)
|
||||
if args.checkpoint is not None:
|
||||
iteration = int(os.path.basename(args.checkpoint).split("-")[-1])
|
||||
else:
|
||||
iteration = io.load_latest_checkpoint(ckpt_dir)
|
||||
|
||||
# =========================train=========================
|
||||
epoch = train_config["epochs"]
|
||||
max_iter = train_config["max_iteration"]
|
||||
snap_interval = train_config["snap_interval"]
|
||||
save_interval = train_config["save_interval"]
|
||||
eval_interval = train_config["eval_interval"]
|
||||
|
||||
global_step = 1
|
||||
global_step = iteration + 1
|
||||
iterator = iter(tqdm.tqdm(loader))
|
||||
while global_step <= max_iter:
|
||||
try:
|
||||
batch = next(iterator)
|
||||
except StopIteration as e:
|
||||
iterator = iter(tqdm.tqdm(loader))
|
||||
batch = next(iterator)
|
||||
|
||||
for j in range(1, 1 + epoch):
|
||||
epoch_loss = 0.
|
||||
for i, batch in tqdm.tqdm(enumerate(loader, 1)):
|
||||
dv3.train() # CAUTION: don't forget to switch to train
|
||||
dv3.train()
|
||||
(text_sequences, text_lengths, text_positions, mel_specs,
|
||||
lin_specs, frames, decoder_positions, done_flags) = batch
|
||||
downsampled_mel_specs = F.strided_slice(
|
||||
|
@ -238,26 +250,25 @@ if __name__ == "__main__":
|
|||
text_sequences, text_positions, text_lengths, None,
|
||||
downsampled_mel_specs, decoder_positions)
|
||||
|
||||
losses = criterion(mel_outputs, linear_outputs, done,
|
||||
alignments, downsampled_mel_specs,
|
||||
lin_specs, done_flags, text_lengths, frames)
|
||||
losses = criterion(mel_outputs, linear_outputs, done, alignments,
|
||||
downsampled_mel_specs, lin_specs, done_flags,
|
||||
text_lengths, frames)
|
||||
l = losses["loss"]
|
||||
l.backward()
|
||||
# record learning rate before updating
|
||||
writer.add_scalar("learning_rate",
|
||||
optim._learning_rate.step().numpy(),
|
||||
global_step)
|
||||
optim._learning_rate.step().numpy(), global_step)
|
||||
optim.minimize(l, grad_clip=gradient_clipper)
|
||||
optim.clear_gradients()
|
||||
|
||||
# ==================all kinds of tedious things=================
|
||||
# record step loss into tensorboard
|
||||
epoch_loss += l.numpy()[0]
|
||||
step_loss = {k: v.numpy()[0] for k, v in losses.items()}
|
||||
tqdm.tqdm.write("global_step: {}\tloss: {}".format(
|
||||
global_step, step_loss["loss"]))
|
||||
for k, v in step_loss.items():
|
||||
writer.add_scalar(k, v, global_step)
|
||||
|
||||
# TODO: clean code
|
||||
# train state saving, the first sentence in the batch
|
||||
if global_step % snap_interval == 0:
|
||||
save_state(
|
||||
|
@ -290,9 +301,9 @@ if __name__ == "__main__":
|
|||
]
|
||||
for idx, sent in enumerate(sentences):
|
||||
wav, attn = eval_model(
|
||||
dv3, sent, replace_pronounciation_prob,
|
||||
min_level_db, ref_level_db, power, n_iter,
|
||||
win_length, hop_length, preemphasis)
|
||||
dv3, sent, replace_pronounciation_prob, min_level_db,
|
||||
ref_level_db, power, n_iter, win_length, hop_length,
|
||||
preemphasis)
|
||||
wav_path = os.path.join(
|
||||
state_dir, "waveform",
|
||||
"eval_sample_{:09d}.wav".format(global_step))
|
||||
|
@ -314,16 +325,7 @@ if __name__ == "__main__":
|
|||
|
||||
# save checkpoint
|
||||
if global_step % save_interval == 0:
|
||||
dg.save_dygraph(
|
||||
dv3.state_dict(),
|
||||
os.path.join(ckpt_dir,
|
||||
"model_step_{}".format(global_step)))
|
||||
dg.save_dygraph(
|
||||
optim.state_dict(),
|
||||
os.path.join(ckpt_dir,
|
||||
"model_step_{}".format(global_step)))
|
||||
io.save_latest_parameters(ckpt_dir, global_step, dv3, optim)
|
||||
io.save_latest_checkpoint(ckpt_dir, global_step)
|
||||
|
||||
global_step += 1
|
||||
# epoch report
|
||||
writer.add_scalar("epoch_average_loss", epoch_loss / i, j)
|
||||
epoch_loss = 0.
|
||||
|
|
|
@ -28,7 +28,7 @@ Train the model using train.py. For help on usage, try `python train.py --help`.
|
|||
|
||||
```text
|
||||
usage: train.py [-h] [--data DATA] [--config CONFIG] [--output OUTPUT]
|
||||
[--device DEVICE] [--resume RESUME]
|
||||
[--device DEVICE] [--checkpoint CHECKPOINT]
|
||||
|
||||
Train a WaveNet model with LJSpeech.
|
||||
|
||||
|
@ -38,12 +38,12 @@ optional arguments:
|
|||
--config CONFIG path of the config file.
|
||||
--output OUTPUT path to save results.
|
||||
--device DEVICE device to use.
|
||||
--resume RESUME checkpoint to resume from.
|
||||
--checkpoint CHECKPOINT checkpoint to resume from.
|
||||
```
|
||||
|
||||
- `--config` is the configuration file to use. The provided configurations can be used directly. And you can change some values in the configuration file and train the model with a different config.
|
||||
- `--data` is the path of the LJSpeech dataset, the extracted folder from the downloaded archive (the folder which contains metadata.txt).
|
||||
- `--resume` is the path of the checkpoint. If it is provided, the model would load the checkpoint before training.
|
||||
- `--checkpoint` is the path of the checkpoint. If it is provided, the model would load the checkpoint before training.
|
||||
- `--output` is the directory to save results, all result are saved in this directory. The structure of the output directory is shown below.
|
||||
|
||||
```text
|
||||
|
@ -51,6 +51,8 @@ optional arguments:
|
|||
└── log # tensorboard log
|
||||
```
|
||||
|
||||
If `checkpoints` is not empty and argument `--checkpoint` is not specified, the model will be resumed from the latest checkpoint at the beginning of training.
|
||||
|
||||
- `--device` is the device (gpu id) to use for training. `-1` means CPU.
|
||||
|
||||
Example script:
|
||||
|
|
|
@ -27,7 +27,7 @@ from parakeet.models.wavenet import UpsampleNet, WaveNet, ConditionalWavenet
|
|||
from parakeet.utils.layer_tools import summary
|
||||
|
||||
from data import LJSpeechMetaData, Transform, DataCollector
|
||||
from utils import make_output_tree, valid_model, eval_model, save_checkpoint
|
||||
from utils import make_output_tree, valid_model, eval_model
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(
|
||||
|
@ -87,7 +87,8 @@ if __name__ == "__main__":
|
|||
batch_size=1,
|
||||
sampler=SequentialSampler(ljspeech_valid))
|
||||
|
||||
make_output_tree(args.output)
|
||||
if not os.path.exists(args.output):
|
||||
os.makedirs(args.output)
|
||||
|
||||
if args.device == -1:
|
||||
place = fluid.CPUPlace()
|
||||
|
|
|
@ -16,7 +16,7 @@ from __future__ import division
|
|||
import os
|
||||
import ruamel.yaml
|
||||
import argparse
|
||||
from tqdm import tqdm
|
||||
import tqdm
|
||||
from tensorboardX import SummaryWriter
|
||||
from paddle import fluid
|
||||
import paddle.fluid.dygraph as dg
|
||||
|
@ -24,13 +24,14 @@ import paddle.fluid.dygraph as dg
|
|||
from parakeet.data import SliceDataset, TransformDataset, DataCargo, SequentialSampler, RandomSampler
|
||||
from parakeet.models.wavenet import UpsampleNet, WaveNet, ConditionalWavenet
|
||||
from parakeet.utils.layer_tools import summary
|
||||
from parakeet.utils import io
|
||||
|
||||
from data import LJSpeechMetaData, Transform, DataCollector
|
||||
from utils import make_output_tree, valid_model, save_checkpoint
|
||||
from utils import make_output_tree, valid_model
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Train a wavenet model with LJSpeech.")
|
||||
description="Train a WaveNet model with LJSpeech.")
|
||||
parser.add_argument(
|
||||
"--data", type=str, help="path of the LJspeech dataset.")
|
||||
parser.add_argument("--config", type=str, help="path of the config file.")
|
||||
|
@ -42,12 +43,16 @@ if __name__ == "__main__":
|
|||
parser.add_argument(
|
||||
"--device", type=int, default=-1, help="device to use.")
|
||||
parser.add_argument(
|
||||
"--resume", type=str, help="checkpoint to resume from.")
|
||||
"--checkpoint", type=str, help="checkpoint to resume from.")
|
||||
|
||||
args = parser.parse_args()
|
||||
with open(args.config, 'rt') as f:
|
||||
config = ruamel.yaml.safe_load(f)
|
||||
|
||||
print("Command Line Args: ")
|
||||
for k, v in vars(args).items():
|
||||
print("{}: {}".format(k, v))
|
||||
|
||||
ljspeech_meta = LJSpeechMetaData(args.data)
|
||||
|
||||
data_config = config["data"]
|
||||
|
@ -126,14 +131,6 @@ if __name__ == "__main__":
|
|||
clipper = fluid.dygraph_grad_clip.GradClipByGlobalNorm(
|
||||
gradiant_max_norm)
|
||||
|
||||
if args.resume:
|
||||
model_dict, optim_dict = dg.load_dygraph(args.resume)
|
||||
print("Loading from {}.pdparams".format(args.resume))
|
||||
model.set_dict(model_dict)
|
||||
if optim_dict:
|
||||
optim.set_dict(optim_dict)
|
||||
print("Loading from {}.pdopt".format(args.resume))
|
||||
|
||||
train_loader = fluid.io.DataLoader.from_generator(
|
||||
capacity=10, return_list=True)
|
||||
train_loader.set_batch_generator(train_cargo, place)
|
||||
|
@ -150,10 +147,24 @@ if __name__ == "__main__":
|
|||
log_dir = os.path.join(args.output, "log")
|
||||
writer = SummaryWriter(log_dir)
|
||||
|
||||
global_step = 1
|
||||
# load parameters and optimizer, and opdate iterations done sofar
|
||||
io.load_parameters(
|
||||
checkpoint_dir, 0, model, optim, file_path=args.checkpoint)
|
||||
if args.checkpoint is not None:
|
||||
iteration = int(os.path.basename(args.checkpoint).split("-")[-1])
|
||||
else:
|
||||
iteration = io.load_latest_checkpoint(checkpoint_dir)
|
||||
|
||||
global_step = iteration + 1
|
||||
iterator = iter(tqdm.tqdm(train_loader))
|
||||
while global_step <= max_iterations:
|
||||
epoch_loss = 0.
|
||||
for i, batch in tqdm(enumerate(train_loader)):
|
||||
print(global_step)
|
||||
try:
|
||||
batch = next(iterator)
|
||||
except StopIteration as e:
|
||||
iterator = iter(tqdm.tqdm(train_loader))
|
||||
batch = next(iterator)
|
||||
|
||||
audio_clips, mel_specs, audio_starts = batch
|
||||
|
||||
model.train()
|
||||
|
@ -162,21 +173,22 @@ if __name__ == "__main__":
|
|||
loss_var.backward()
|
||||
loss_np = loss_var.numpy()
|
||||
|
||||
epoch_loss += loss_np[0]
|
||||
|
||||
writer.add_scalar("loss", loss_np[0], global_step)
|
||||
writer.add_scalar("learning_rate",
|
||||
optim._learning_rate.step().numpy()[0],
|
||||
global_step)
|
||||
optim.minimize(loss_var, grad_clip=clipper)
|
||||
optim.clear_gradients()
|
||||
print("loss: {:<8.6f}".format(loss_np[0]))
|
||||
print("global_step: {}\tloss: {:<8.6f}".format(global_step,
|
||||
loss_np[0]))
|
||||
|
||||
if global_step % snap_interval == 0:
|
||||
valid_model(model, valid_loader, writer, global_step,
|
||||
sample_rate)
|
||||
|
||||
if global_step % checkpoint_interval == 0:
|
||||
save_checkpoint(model, optim, checkpoint_dir, global_step)
|
||||
io.save_latest_parameters(checkpoint_dir, global_step, model,
|
||||
optim)
|
||||
io.save_latest_checkpoint(checkpoint_dir, global_step)
|
||||
|
||||
global_step += 1
|
||||
|
|
|
@ -59,10 +59,3 @@ def eval_model(model, valid_loader, output_dir, sample_rate):
|
|||
wav_np = wav_var.numpy()[0]
|
||||
sf.write(path, wav_np, samplerate=sample_rate)
|
||||
print("generated {}".format(path))
|
||||
|
||||
|
||||
def save_checkpoint(model, optim, checkpoint_dir, global_step):
|
||||
checkpoint_path = os.path.join(checkpoint_dir,
|
||||
"step_{:09d}".format(global_step))
|
||||
dg.save_dygraph(model.state_dict(), checkpoint_path)
|
||||
dg.save_dygraph(optim.state_dict(), checkpoint_path)
|
||||
|
|
Loading…
Reference in New Issue