Merge branch 'master' into 'master'

update save & load for deep voicde 3, wavenet and clarinet, remove the concept of epoch in training

See merge request !48
This commit is contained in:
liuyibing01 2020-03-26 18:44:46 +08:00
commit a3683ac297
13 changed files with 602 additions and 404 deletions

View File

@ -22,47 +22,71 @@ tar xjvf LJSpeech-1.1.tar.bz2
└── utils.py utility functions
```
## Saving & Loading
`train.py` and `synthesis.py` have 3 arguments in common, `--checkpooint`, `iteration` and `output`.
1. `output` is the directory for saving results.
During training, checkpoints are saved in `checkpoints/` in `output` and tensorboard log is save in `log/` in `output`. Other possible outputs are saved in `states/` in `outuput`.
During synthesizing, audio files and other possible outputs are save in `synthesis/` in `output`.
So after training and synthesizing with the same output directory, the file structure of the output directory looks like this.
```text
├── checkpoints/ # checkpoint directory (including *.pdparams, *.pdopt and a text file `checkpoint` that records the latest checkpoint)
├── states/ # audio files generated at validation and other possible outputs
├── log/ # tensorboard log
└── synthesis/ # synthesized audio files and other possible outputs
```
2. `--checkpoint` and `--iteration` for loading from existing checkpoint. Loading existing checkpoiont follows the following rule:
If `--checkpoint` is provided, the checkpoint specified by `--checkpoint` is loaded.
If `--checkpoint` is not provided, we try to load the model specified by `--iteration` from the checkpoint directory. If `--iteration` is not provided, we try to load the latested checkpoint from checkpoint directory.
## Train
Train the model using train.py, follow the usage displayed by `python train.py --help`.
```text
usage: train.py [-h] [--config CONFIG] [--device DEVICE] [--output OUTPUT]
[--data DATA] [--resume RESUME] [--wavenet WAVENET]
usage: train.py [-h] [--config CONFIG] [--device DEVICE] [--data DATA]
[--checkpoint CHECKPOINT | --iteration ITERATION]
[--wavenet WAVENET]
output
train a ClariNet model with LJspeech and a trained WaveNet model.
Train a ClariNet model with LJspeech and a trained WaveNet model.
positional arguments:
output path to save experiment results
optional arguments:
-h, --help show this help message and exit
--config CONFIG path of the config file.
--device DEVICE device to use.
--output OUTPUT path to save student.
--data DATA path of LJspeech dataset.
--resume RESUME checkpoint to load from.
--wavenet WAVENET wavenet checkpoint to use.
```
--config CONFIG path of the config file
--device DEVICE device to use
--data DATA path of LJspeech dataset
--checkpoint CHECKPOINT checkpoint to resume from
--iteration ITERATION the iteration of the checkpoint to load from output directory
--wavenet WAVENET wavenet checkpoint to use
- `--config` is the configuration file to use. The provided configurations can be used directly. And you can change some values in the configuration file and train the model with a different config.
- `--data` is the path of the LJSpeech dataset, the extracted folder from the downloaded archive (the folder which contains metadata.txt).
- `--resume` is the path of the checkpoint. If it is provided, the model would load the checkpoint before trainig.
- `--output` is the directory to save results, all result are saved in this directory. The structure of the output directory is shown below.
```text
├── checkpoints # checkpoint
├── states # audio files generated at validation
└── log # tensorboard log
```
- `--device` is the device (gpu id) to use for training. `-1` means CPU.
- `--wavenet` is the path of the wavenet checkpoint to load. If you do not specify `--resume`, then this must be provided.
- `--data` is the path of the LJSpeech dataset, the extracted folder from the downloaded archive (the folder which contains `metadata.txt`).
- `--checkpoint` is the path of the checkpoint.
- `--iteration` is the iteration of the checkpoint to load from output directory.
- `output` is the directory to save results, all result are saved in this directory.
Before you start training a ClariNet model, you should have trained a WaveNet model with single Gaussian output distribution. Make sure the config of the teacher model matches that of the trained model.
See [Saving-&-Loading](#Saving-&-Loading) for details of checkpoint loading.
- `--wavenet` is the path of the wavenet checkpoint to load.
When you start training a ClariNet model without loading form a ClariNet checkpoint, you should have trained a WaveNet model with single Gaussian output distribution. Make sure the config of the teacher model matches that of the trained wavenet model.
Example script:
```bash
python train.py --config=./configs/clarinet_ljspeech.yaml --data=./LJSpeech-1.1/ --output=experiment --device=0 --conditioner=wavenet_checkpoint/conditioner --conditioner=wavenet_checkpoint/teacher
python train.py
--config=./configs/clarinet_ljspeech.yaml
--data=./LJSpeech-1.1/
--device=0
--wavenet="wavenet-step-2000000"
experiment
```
You can monitor training log via tensorboard, using the script below.
@ -75,29 +99,50 @@ tensorboard --logdir=.
## Synthesis
```text
usage: synthesis.py [-h] [--config CONFIG] [--device DEVICE] [--data DATA]
checkpoint output
[--checkpoint CHECKPOINT | --iteration ITERATION]
output
train a ClariNet model with LJspeech and a trained WaveNet model.
Synthesize audio files from mel spectrogram in the validation set.
positional arguments:
checkpoint checkpoint to load from.
output path to save student.
output path to save the synthesized audio
optional arguments:
-h, --help show this help message and exit
--config CONFIG path of the config file.
--config CONFIG path of the config file
--device DEVICE device to use.
--data DATA path of LJspeech dataset.
--data DATA path of LJspeech dataset
--checkpoint CHECKPOINT checkpoint to resume from
--iteration ITERATION the iteration of the checkpoint to load from output directory
```
- `--config` is the configuration file to use. You should use the same configuration with which you train you model.
- `--data` is the path of the LJspeech dataset. A dataset is not needed for synthesis, but since the input is mel spectrogram, we need to get mel spectrogram from audio files.
- `checkpoint` is the checkpoint to load.
- `output_path` is the directory to save results. The output path contains the generated audio files (`*.wav`).
- `--device` is the device (gpu id) to use for training. `-1` means CPU.
- `--data` is the path of the LJspeech dataset. In principle, a dataset is not needed for synthesis, but since the input is mel spectrogram, we need to get mel spectrogram from audio files.
- `--checkpoint` is the checkpoint to load.
- `--iteration` is the iteration of the checkpoint to load from output directory.
- `output` is the directory to save synthesized audio. Audio file is saved in `synthesis/` in `output` directory.
See [Saving-&-Loading](#Saving-&-Loading) for details of checkpoint loading.
Example script:
```bash
python synthesis.py --config=./configs/wavenet_single_gaussian.yaml --data=./LJSpeech-1.1/ --device=0 experiment/checkpoints/step_500000 generated
python synthesis.py \
--config=./configs/wavenet_single_gaussian.yaml \
--data=./LJSpeech-1.1/ \
--device=0 \
--iteration=500000 \
experiment
```
or
```bash
python synthesis.py \
--config=./configs/wavenet_single_gaussian.yaml \
--data=./LJSpeech-1.1/ \
--device=0 \
--checkpoint="experiment/checkpoints/step-500000" \
experiment
```

View File

@ -26,29 +26,41 @@ from tensorboardX import SummaryWriter
import paddle.fluid.dygraph as dg
from paddle import fluid
from parakeet.modules.weight_norm import WeightNormWrapper
from parakeet.models.wavenet import WaveNet, UpsampleNet
from parakeet.models.clarinet import STFT, Clarinet, ParallelWaveNet
from parakeet.data import TransformDataset, SliceDataset, RandomSampler, SequentialSampler, DataCargo
from parakeet.utils.layer_tools import summary, freeze
from parakeet.utils import io
from utils import valid_model, eval_model, save_checkpoint, load_checkpoint, load_model
from utils import eval_model
sys.path.append("../wavenet")
from data import LJSpeechMetaData, Transform, DataCollector
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="synthesize audio files from mel spectrogram in the validation set."
description="Synthesize audio files from mel spectrogram in the validation set."
)
parser.add_argument("--config", type=str, help="path of the config file.")
parser.add_argument("--config", type=str, help="path of the config file")
parser.add_argument(
"--device", type=int, default=-1, help="device to use.")
parser.add_argument("--data", type=str, help="path of LJspeech dataset.")
parser.add_argument("--data", type=str, help="path of LJspeech dataset")
g = parser.add_mutually_exclusive_group()
g.add_argument("--checkpoint", type=str, help="checkpoint to resume from")
g.add_argument(
"--iteration",
type=int,
help="the iteration of the checkpoint to load from output directory")
parser.add_argument(
"checkpoint", type=str, help="checkpoint to load from.")
parser.add_argument(
"output", type=str, default="experiment", help="path to save student.")
"output",
type=str,
default="experiment",
help="path to save the synthesized audio")
args = parser.parse_args()
with open(args.config, 'rt') as f:
config = ruamel.yaml.safe_load(f)
@ -136,17 +148,32 @@ if __name__ == "__main__":
model = Clarinet(upsample_net, teacher, student, stft,
student_log_scale_min, lmd)
summary(model)
load_model(model, args.checkpoint)
# loader
train_loader = fluid.io.DataLoader.from_generator(
capacity=10, return_list=True)
train_loader.set_batch_generator(train_cargo, place)
# load parameters
if args.checkpoint is not None:
# load from args.checkpoint
iteration = io.load_parameters(
model, checkpoint_path=args.checkpoint)
else:
# load from "args.output/checkpoints"
checkpoint_dir = os.path.join(args.output, "checkpoints")
iteration = io.load_parameters(
model, checkpoint_dir=checkpoint_dir, iteration=args.iteration)
assert iteration > 0, "A trained checkpoint is needed."
# make generation fast
for sublayer in model.sublayers():
if isinstance(sublayer, WeightNormWrapper):
sublayer.remove_weight_norm()
# data loader
valid_loader = fluid.io.DataLoader.from_generator(
capacity=10, return_list=True)
valid_loader.set_batch_generator(valid_cargo, place)
if not os.path.exists(args.output):
os.makedirs(args.output)
eval_model(model, valid_loader, args.output, sample_rate)
# the directory to save audio files
synthesis_dir = os.path.join(args.output, "synthesis")
if not os.path.exists(synthesis_dir):
os.makedirs(synthesis_dir)
eval_model(model, valid_loader, synthesis_dir, iteration, sample_rate)

View File

@ -30,31 +30,46 @@ from parakeet.models.wavenet import WaveNet, UpsampleNet
from parakeet.models.clarinet import STFT, Clarinet, ParallelWaveNet
from parakeet.data import TransformDataset, SliceDataset, RandomSampler, SequentialSampler, DataCargo
from parakeet.utils.layer_tools import summary, freeze
from parakeet.utils import io
from utils import make_output_tree, valid_model, save_checkpoint, load_checkpoint, load_wavenet
from utils import make_output_tree, eval_model, load_wavenet
# import dataset from wavenet
sys.path.append("../wavenet")
from data import LJSpeechMetaData, Transform, DataCollector
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="train a clarinet model with LJspeech and a trained wavenet model."
description="Train a ClariNet model with LJspeech and a trained WaveNet model."
)
parser.add_argument("--config", type=str, help="path of the config file.")
parser.add_argument("--config", type=str, help="path of the config file")
parser.add_argument("--device", type=int, default=-1, help="device to use")
parser.add_argument("--data", type=str, help="path of LJspeech dataset")
g = parser.add_mutually_exclusive_group()
g.add_argument("--checkpoint", type=str, help="checkpoint to resume from")
g.add_argument(
"--iteration",
type=int,
help="the iteration of the checkpoint to load from output directory")
parser.add_argument(
"--device", type=int, default=-1, help="device to use.")
"--wavenet", type=str, help="wavenet checkpoint to use")
parser.add_argument(
"--output",
"output",
type=str,
default="experiment",
help="path to save student.")
parser.add_argument("--data", type=str, help="path of LJspeech dataset.")
parser.add_argument("--resume", type=str, help="checkpoint to load from.")
parser.add_argument(
"--wavenet", type=str, help="wavenet checkpoint to use.")
help="path to save experiment results")
args = parser.parse_args()
with open(args.config, 'rt') as f:
config = ruamel.yaml.safe_load(f)
print("Command Line args: ")
for k, v in vars(args).items():
print("{}: {}".format(k, v))
ljspeech_meta = LJSpeechMetaData(args.data)
data_config = config["data"]
@ -154,12 +169,28 @@ if __name__ == "__main__":
clipper = fluid.dygraph_grad_clip.GradClipByGlobalNorm(
gradiant_max_norm)
assert args.wavenet or args.resume, "you should load from a trained wavenet or resume training; training without a trained wavenet is not recommended."
if args.wavenet:
load_wavenet(model, args.wavenet)
# train
max_iterations = train_config["max_iterations"]
checkpoint_interval = train_config["checkpoint_interval"]
eval_interval = train_config["eval_interval"]
checkpoint_dir = os.path.join(args.output, "checkpoints")
state_dir = os.path.join(args.output, "states")
log_dir = os.path.join(args.output, "log")
writer = SummaryWriter(log_dir)
if args.resume:
load_checkpoint(model, optim, args.resume)
if args.checkpoint is not None:
iteration = io.load_parameters(
model, optim, checkpoint_path=args.checkpoint)
else:
iteration = io.load_parameters(
model,
optim,
checkpoint_dir=checkpoint_dir,
iteration=args.iteration)
if iteration == 0:
assert args.wavenet is not None, "When training afresh, a trained wavenet model should be provided."
load_wavenet(model, args.wavenet)
# loader
train_loader = fluid.io.DataLoader.from_generator(
@ -170,21 +201,16 @@ if __name__ == "__main__":
capacity=10, return_list=True)
valid_loader.set_batch_generator(valid_cargo, place)
# train
max_iterations = train_config["max_iterations"]
checkpoint_interval = train_config["checkpoint_interval"]
eval_interval = train_config["eval_interval"]
checkpoint_dir = os.path.join(args.output, "checkpoints")
state_dir = os.path.join(args.output, "states")
log_dir = os.path.join(args.output, "log")
writer = SummaryWriter(log_dir)
# training loop
global_step = 1
global_epoch = 1
while global_step < max_iterations:
epoch_loss = 0.
for j, batch in tqdm(enumerate(train_loader), desc="[train]"):
global_step = iteration + 1
iterator = iter(tqdm(train_loader))
while global_step <= max_iterations:
try:
batch = next(iterator)
except StopIteration as e:
iterator = iter(tqdm(train_loader))
batch = next(iterator)
audios, mels, audio_starts = batch
model.train()
loss_dict = model(
@ -199,8 +225,8 @@ if __name__ == "__main__":
l = loss_dict["loss"]
step_loss = l.numpy()[0]
print("[train] loss: {:<8.6f}".format(step_loss))
epoch_loss += step_loss
print("[train] global_step: {} loss: {:<8.6f}".format(global_step,
step_loss))
l.backward()
optim.minimize(l, grad_clip=clipper)
@ -208,14 +234,9 @@ if __name__ == "__main__":
if global_step % eval_interval == 0:
# evaluate on valid dataset
valid_model(model, valid_loader, state_dir, global_step,
eval_model(model, valid_loader, state_dir, global_step,
sample_rate)
if global_step % checkpoint_interval == 0:
save_checkpoint(model, optim, checkpoint_dir, global_step)
io.save_parameters(checkpoint_dir, global_step, model, optim)
global_step += 1
# epoch loss
average_loss = epoch_loss / j
writer.add_scalar("average_loss", average_loss, global_epoch)
global_epoch += 1

View File

@ -32,12 +32,12 @@ def make_output_tree(output_dir):
os.makedirs(state_dir)
def valid_model(model, valid_loader, output_dir, global_step, sample_rate):
def eval_model(model, valid_loader, output_dir, iteration, sample_rate):
model.eval()
for i, batch in enumerate(valid_loader):
# print("sentence {}".format(i))
path = os.path.join(output_dir,
"step_{}_sentence_{}.wav".format(global_step, i))
"sentence_{}_step_{}.wav".format(i, iteration))
audio_clips, mel_specs, audio_starts = batch
wav_var = model.synthesis(mel_specs)
wav_np = wav_var.numpy()[0]
@ -45,42 +45,6 @@ def valid_model(model, valid_loader, output_dir, global_step, sample_rate):
print("generated {}".format(path))
def eval_model(model, valid_loader, output_dir, sample_rate):
model.eval()
for i, batch in enumerate(valid_loader):
# print("sentence {}".format(i))
path = os.path.join(output_dir, "sentence_{}.wav".format(i))
audio_clips, mel_specs, audio_starts = batch
wav_var = model.synthesis(mel_specs)
wav_np = wav_var.numpy()[0]
sf.write(path, wav_np, samplerate=sample_rate)
print("generated {}".format(path))
def save_checkpoint(model, optim, checkpoint_dir, global_step):
path = os.path.join(checkpoint_dir, "step_{}".format(global_step))
dg.save_dygraph(model.state_dict(), path)
print("saving model to {}".format(path + ".pdparams"))
if optim:
dg.save_dygraph(optim.state_dict(), path)
print("saving optimizer to {}".format(path + ".pdopt"))
def load_model(model, path):
model_dict, _ = dg.load_dygraph(path)
model.set_dict(model_dict)
print("loaded model from {}.pdparams".format(path))
def load_checkpoint(model, optim, path):
model_dict, optim_dict = dg.load_dygraph(path)
model.set_dict(model_dict)
print("loaded model from {}.pdparams".format(path))
if optim_dict:
optim.set_dict(optim_dict)
print("loaded optimizer from {}.pdparams".format(path))
def load_wavenet(model, path):
wavenet_dict, _ = dg.load_dygraph(path)
encoder_dict = OrderedDict()

View File

@ -30,32 +30,55 @@ The model consists of an encoder, a decoder and a converter (and a speaker embed
└── utils.py utility functions
```
## Saving & Loading
`train.py` and `synthesis.py` have 3 arguments in common, `--checkpooint`, `iteration` and `output`.
1. `output` is the directory for saving results.
During training, checkpoints are saved in `checkpoints/` in `output` and tensorboard log is save in `log/` in `output`. Other possible outputs are saved in `states/` in `outuput`.
During synthesizing, audio files and other possible outputs are save in `synthesis/` in `output`.
So after training and synthesizing with the same output directory, the file structure of the output directory looks like this.
```text
├── checkpoints/ # checkpoint directory (including *.pdparams, *.pdopt and a text file `checkpoint` that records the latest checkpoint)
├── states/ # audio files generated at validation and other possible outputs
├── log/ # tensorboard log
└── synthesis/ # synthesized audio files and other possible outputs
```
2. `--checkpoint` and `--iteration` for loading from existing checkpoint. Loading existing checkpoiont follows the following rule:
If `--checkpoint` is provided, the checkpoint specified by `--checkpoint` is loaded.
If `--checkpoint` is not provided, we try to load the model specified by `--iteration` from the checkpoint directory. If `--iteration` is not provided, we try to load the latested checkpoint from checkpoint directory.
## Train
Train the model using train.py, follow the usage displayed by `python train.py --help`.
```text
usage: train.py [-h] [-c CONFIG] [-s DATA] [-r RESUME] [-o OUTPUT] [-g DEVICE]
usage: train.py [-h] [--config CONFIG] [--data DATA] [--device DEVICE]
[--checkpoint CHECKPOINT | --iteration ITERATION]
output
Train a Deep Voice 3 model with LJSpeech dataset.
positional arguments:
output path to save results
optional arguments:
-h, --help show this help message and exit
-c CONFIG, --config CONFIG
experimrnt config
-s DATA, --data DATA The path of the LJSpeech dataset.
-r RESUME, --resume RESUME
checkpoint to load
-o OUTPUT, --output OUTPUT
The directory to save result.
-g DEVICE, --device DEVICE
device to use
--config CONFIG experimrnt config
--data DATA The path of the LJSpeech dataset.
--device DEVICE device to use
--checkpoint CHECKPOINT checkpoint to resume from.
--iteration ITERATION the iteration of the checkpoint to load from output directory
```
- `--config` is the configuration file to use. The provided `ljspeech.yaml` can be used directly. And you can change some values in the configuration file and train the model with a different config.
- `--data` is the path of the LJSpeech dataset, the extracted folder from the downloaded archive (the folder which contains metadata.txt).
- `--resume` is the path of the checkpoint. If it is provided, the model would load the checkpoint before trainig.
- `--output` is the directory to save results, all results are saved in this directory. The structure of the output directory is shown below.
- `--device` is the device (gpu id) to use for training. `-1` means CPU.
- `--checkpoint` is the path of the checkpoint.
- `--iteration` is the iteration of the checkpoint to load from output directory.
See [Saving-&-Loading](#Saving-&-Loading) for details of checkpoint loading.
- `output` is the directory to save results, all results are saved in this directory. The structure of the output directory is shown below.
```text
├── checkpoints # checkpoint
@ -67,12 +90,14 @@ optional arguments:
└── waveform # waveform (.wav files)
```
- `--device` is the device (gpu id) to use for training. `-1` means CPU.
Example script:
```bash
python train.py --config=configs/ljspeech.yaml --data=./LJSpeech-1.1/ --output=experiment --device=0
python train.py \
--config=configs/ljspeech.yaml \
--data=./LJSpeech-1.1/ \
--device=0 \
experiment
```
You can monitor training log via tensorboard, using the script below.
@ -84,31 +109,50 @@ tensorboard --logdir=.
## Synthesis
```text
usage: synthesis.py [-h] [-c CONFIG] [-g DEVICE] checkpoint text output_path
usage: synthesis.py [-h] [--config CONFIG] [--device DEVICE]
[--checkpoint CHECKPOINT | --iteration ITERATION]
text output
Synthsize waveform from a checkpoint.
Synthsize waveform with a checkpoint.
positional arguments:
checkpoint checkpoint to load.
text text file to synthesize
output_path path to save results
output path to save synthesized audio
optional arguments:
-h, --help show this help message and exit
-c CONFIG, --config CONFIG
experiment config.
-g DEVICE, --device DEVICE
device to use
--config CONFIG experiment config
--device DEVICE device to use
--checkpoint CHECKPOINT checkpoint to resume from
--iteration ITERATION the iteration of the checkpoint to load from output directory
```
- `--config` is the configuration file to use. You should use the same configuration with which you train you model.
- `checkpoint` is the checkpoint to load.
- `text`is the text file to synthesize.
- `output_path` is the directory to save results. The output path contains the generated audio files (`*.wav`) and attention plots (*.png) for each sentence.
- `--device` is the device (gpu id) to use for training. `-1` means CPU.
- `--checkpoint` is the path of the checkpoint.
- `--iteration` is the iteration of the checkpoint to load from output directory.
See [Saving-&-Loading](#Saving-&-Loading) for details of checkpoint loading.
- `text`is the text file to synthesize.
- `output` is the directory to save results. The generated audio files (`*.wav`) and attention plots (*.png) for are save in `synthesis/` in ouput directory.
Example script:
```bash
python synthesis.py --config=configs/ljspeech.yaml --device=0 experiment/checkpoints/model_step_005000000 sentences.txt generated
python synthesis.py \
--config=configs/ljspeech.yaml \
--device=0 \
--checkpoint="experiment/checkpoints/model_step_005000000" \
sentences.txt experiment
```
or
```bash
python synthesis.py \
--config=configs/ljspeech.yaml \
--device=0 \
--iteration=005000000 \
sentences.txt experiment
```

View File

@ -83,7 +83,7 @@ lr_scheduler:
train:
batch_size: 16
epochs: 2000
max_iteration: 2000000
snap_interval: 1000
eval_interval: 10000

View File

@ -25,25 +25,37 @@ import paddle.fluid.dygraph as dg
from tensorboardX import SummaryWriter
from parakeet.g2p import en
from parakeet.utils.layer_tools import summary
from parakeet.modules.weight_norm import WeightNormWrapper
from parakeet.utils.layer_tools import summary
from parakeet.utils import io
from utils import make_model, eval_model, plot_alignment
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Synthsize waveform with a checkpoint.")
parser.add_argument("-c", "--config", type=str, help="experiment config.")
parser.add_argument("checkpoint", type=str, help="checkpoint to load.")
parser.add_argument("--config", type=str, help="experiment config")
parser.add_argument("--device", type=int, default=-1, help="device to use")
g = parser.add_mutually_exclusive_group()
g.add_argument("--checkpoint", type=str, help="checkpoint to resume from")
g.add_argument(
"--iteration",
type=int,
help="the iteration of the checkpoint to load from output directory")
parser.add_argument("text", type=str, help="text file to synthesize")
parser.add_argument("output_path", type=str, help="path to save results")
parser.add_argument(
"-g", "--device", type=int, default=-1, help="device to use")
"output", type=str, help="path to save synthesized audio")
args = parser.parse_args()
with open(args.config, 'rt') as f:
config = ruamel.yaml.safe_load(f)
print("Command Line Args: ")
for k, v in vars(args).items():
print("{}: {}".format(k, v))
if args.device == -1:
place = fluid.CPUPlace()
else:
@ -98,8 +110,14 @@ if __name__ == "__main__":
linear_dim, use_decoder_states, converter_channels, dropout)
summary(dv3)
state, _ = dg.load_dygraph(args.checkpoint)
dv3.set_dict(state)
checkpoint_dir = os.path.join(args.output, "checkpoints")
if args.checkpoint is not None:
iteration = io.load_parameters(
dv3, checkpoint_path=args.checkpoint)
else:
iteration = io.load_parameters(
dv3, checkpoint_dir=checkpoint_dir, iteration=args.iteration)
# WARNING: don't forget to remove weight norm to re-compute each wrapped layer's weight
# removing weight norm also speeds up computation
@ -107,9 +125,6 @@ if __name__ == "__main__":
if isinstance(layer, WeightNormWrapper):
layer.remove_weight_norm()
if not os.path.exists(args.output_path):
os.makedirs(args.output_path)
transform_config = config["transform"]
c = transform_config["replace_pronunciation_prob"]
sample_rate = transform_config["sample_rate"]
@ -123,6 +138,10 @@ if __name__ == "__main__":
power = synthesis_config["power"]
n_iter = synthesis_config["n_iter"]
synthesis_dir = os.path.join(args.output, "synthesis")
if not os.path.exists(synthesis_dir):
os.makedirs(synthesis_dir)
with open(args.text, "rt", encoding="utf-8") as f:
lines = f.readlines()
for idx, line in enumerate(lines):
@ -134,7 +153,9 @@ if __name__ == "__main__":
preemphasis)
plot_alignment(
attn,
os.path.join(args.output_path, "test_{}.png".format(idx)))
os.path.join(synthesis_dir,
"test_{}_step_{}.png".format(idx, iteration)))
sf.write(
os.path.join(args.output_path, "test_{}.wav".format(idx)),
os.path.join(synthesis_dir,
"test_{}_step{}.wav".format(idx, iteration)),
wav, sample_rate)

View File

@ -17,6 +17,8 @@ import os
import argparse
import ruamel.yaml
import numpy as np
import matplotlib
matplotlib.use("agg")
from matplotlib import cm
import matplotlib.pyplot as plt
import tqdm
@ -35,33 +37,40 @@ from parakeet.data import DataCargo, PartialyRandomizedSimilarTimeLengthSampler,
from parakeet.models.deepvoice3 import Encoder, Decoder, Converter, DeepVoice3, ConvSpec
from parakeet.models.deepvoice3.loss import TTSLoss
from parakeet.utils.layer_tools import summary
from parakeet.utils import io
from data import LJSpeechMetaData, DataCollector, Transform
from utils import make_model, eval_model, save_state, make_output_tree, plot_alignment
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Train a deepvoice 3 model with LJSpeech dataset.")
parser.add_argument("-c", "--config", type=str, help="experimrnt config")
description="Train a Deep Voice 3 model with LJSpeech dataset.")
parser.add_argument("--config", type=str, help="experimrnt config")
parser.add_argument(
"-s",
"--data",
type=str,
default="/workspace/datasets/LJSpeech-1.1/",
help="The path of the LJSpeech dataset.")
parser.add_argument("-r", "--resume", type=str, help="checkpoint to load")
parser.add_argument("--device", type=int, default=-1, help="device to use")
g = parser.add_mutually_exclusive_group()
g.add_argument("--checkpoint", type=str, help="checkpoint to resume from.")
g.add_argument(
"--iteration",
type=int,
help="the iteration of the checkpoint to load from output directory")
parser.add_argument(
"-o",
"--output",
type=str,
default="result",
help="The directory to save result.")
parser.add_argument(
"-g", "--device", type=int, default=-1, help="device to use")
"output", type=str, default="experiment", help="path to save results")
args, _ = parser.parse_known_args()
with open(args.config, 'rt') as f:
config = ruamel.yaml.safe_load(f)
print("Command Line Args: ")
for k, v in vars(args).items():
print("{}: {}".format(k, v))
# =========================dataset=========================
# construct meta data
data_root = args.data
@ -151,6 +160,7 @@ if __name__ == "__main__":
query_position_rate, key_position_rate, window_backward,
window_ahead, key_projection, value_projection, downsample_factor,
linear_dim, use_decoder_states, converter_channels, dropout)
summary(dv3)
# =========================loss=========================
loss_config = config["loss"]
@ -195,7 +205,6 @@ if __name__ == "__main__":
n_iter = synthesis_config["n_iter"]
# =========================link(dataloader, paddle)=========================
# CAUTION: it does not return a DataLoader
loader = fluid.io.DataLoader.from_generator(
capacity=10, return_list=True)
loader.set_batch_generator(ljspeech_loader, places=place)
@ -208,24 +217,30 @@ if __name__ == "__main__":
make_output_tree(output_dir)
writer = SummaryWriter(logdir=log_dir)
# load model parameters
resume_path = args.resume
if resume_path is not None:
state, _ = dg.load_dygraph(args.resume)
dv3.set_dict(state)
# load parameters and optimizer, and opdate iterations done sofar
if args.checkpoint is not None:
iteration = io.load_parameters(
dv3, optim, checkpoint_path=args.checkpoint)
else:
iteration = io.load_parameters(
dv3, optim, checkpoint_dir=ckpt_dir, iteration=args.iteration)
# =========================train=========================
epoch = train_config["epochs"]
max_iter = train_config["max_iteration"]
snap_interval = train_config["snap_interval"]
save_interval = train_config["save_interval"]
eval_interval = train_config["eval_interval"]
global_step = 1
global_step = iteration + 1
iterator = iter(tqdm.tqdm(loader))
while global_step <= max_iter:
try:
batch = next(iterator)
except StopIteration as e:
iterator = iter(tqdm.tqdm(loader))
batch = next(iterator)
for j in range(1, 1 + epoch):
epoch_loss = 0.
for i, batch in tqdm.tqdm(enumerate(loader, 1)):
dv3.train() # CAUTION: don't forget to switch to train
dv3.train()
(text_sequences, text_lengths, text_positions, mel_specs,
lin_specs, frames, decoder_positions, done_flags) = batch
downsampled_mel_specs = F.strided_slice(
@ -238,26 +253,25 @@ if __name__ == "__main__":
text_sequences, text_positions, text_lengths, None,
downsampled_mel_specs, decoder_positions)
losses = criterion(mel_outputs, linear_outputs, done,
alignments, downsampled_mel_specs,
lin_specs, done_flags, text_lengths, frames)
losses = criterion(mel_outputs, linear_outputs, done, alignments,
downsampled_mel_specs, lin_specs, done_flags,
text_lengths, frames)
l = losses["loss"]
l.backward()
# record learning rate before updating
writer.add_scalar("learning_rate",
optim._learning_rate.step().numpy(),
global_step)
optim._learning_rate.step().numpy(), global_step)
optim.minimize(l, grad_clip=gradient_clipper)
optim.clear_gradients()
# ==================all kinds of tedious things=================
# record step loss into tensorboard
epoch_loss += l.numpy()[0]
step_loss = {k: v.numpy()[0] for k, v in losses.items()}
tqdm.tqdm.write("global_step: {}\tloss: {}".format(
global_step, step_loss["loss"]))
for k, v in step_loss.items():
writer.add_scalar(k, v, global_step)
# TODO: clean code
# train state saving, the first sentence in the batch
if global_step % snap_interval == 0:
save_state(
@ -290,9 +304,9 @@ if __name__ == "__main__":
]
for idx, sent in enumerate(sentences):
wav, attn = eval_model(
dv3, sent, replace_pronounciation_prob,
min_level_db, ref_level_db, power, n_iter,
win_length, hop_length, preemphasis)
dv3, sent, replace_pronounciation_prob, min_level_db,
ref_level_db, power, n_iter, win_length, hop_length,
preemphasis)
wav_path = os.path.join(
state_dir, "waveform",
"eval_sample_{:09d}.wav".format(global_step))
@ -314,16 +328,6 @@ if __name__ == "__main__":
# save checkpoint
if global_step % save_interval == 0:
dg.save_dygraph(
dv3.state_dict(),
os.path.join(ckpt_dir,
"model_step_{}".format(global_step)))
dg.save_dygraph(
optim.state_dict(),
os.path.join(ckpt_dir,
"model_step_{}".format(global_step)))
io.save_parameters(ckpt_dir, global_step, dv3, optim)
global_step += 1
# epoch report
writer.add_scalar("epoch_average_loss", epoch_loss / i, j)
epoch_loss = 0.

View File

@ -22,41 +22,67 @@ tar xjvf LJSpeech-1.1.tar.bz2
└── utils.py utility functions
```
## Saving & Loading
`train.py` and `synthesis.py` have 3 arguments in common, `--checkpooint`, `iteration` and `output`.
1. `output` is the directory for saving results.
During training, checkpoints are saved in `checkpoints/` in `output` and tensorboard log is save in `log/` in `output`. Other possible outputs are saved in `states/` in `outuput`.
During synthesizing, audio files and other possible outputs are save in `synthesis/` in `output`.
So after training and synthesizing with the same output directory, the file structure of the output directory looks like this.
```text
├── checkpoints/ # checkpoint directory (including *.pdparams, *.pdopt and a text file `checkpoint` that records the latest checkpoint)
├── states/ # audio files generated at validation and other possible outputs
├── log/ # tensorboard log
└── synthesis/ # synthesized audio files and other possible outputs
```
2. `--checkpoint` and `--iteration` for loading from existing checkpoint. Loading existing checkpoiont follows the following rule:
If `--checkpoint` is provided, the checkpoint specified by `--checkpoint` is loaded.
If `--checkpoint` is not provided, we try to load the model specified by `--iteration` from the checkpoint directory. If `--iteration` is not provided, we try to load the latested checkpoint from checkpoint directory.
## Train
Train the model using train.py. For help on usage, try `python train.py --help`.
```text
usage: train.py [-h] [--data DATA] [--config CONFIG] [--output OUTPUT]
[--device DEVICE] [--resume RESUME]
usage: train.py [-h] [--data DATA] [--config CONFIG] [--device DEVICE]
[--checkpoint CHECKPOINT | --iteration ITERATION]
output
Train a WaveNet model with LJSpeech.
positional arguments:
output path to save results
optional arguments:
-h, --help show this help message and exit
--data DATA path of the LJspeech dataset.
--config CONFIG path of the config file.
--output OUTPUT path to save results.
--device DEVICE device to use.
--resume RESUME checkpoint to resume from.
--data DATA path of the LJspeech dataset
--config CONFIG path of the config file
--device DEVICE device to use
--checkpoint CHECKPOINT checkpoint to resume from
--iteration ITERATION the iteration of the checkpoint to load from output directory
```
- `--config` is the configuration file to use. The provided configurations can be used directly. And you can change some values in the configuration file and train the model with a different config.
- `--data` is the path of the LJSpeech dataset, the extracted folder from the downloaded archive (the folder which contains metadata.txt).
- `--resume` is the path of the checkpoint. If it is provided, the model would load the checkpoint before training.
- `--output` is the directory to save results, all result are saved in this directory. The structure of the output directory is shown below.
```text
├── checkpoints # checkpoint
└── log # tensorboard log
```
- `--config` is the configuration file to use. The provided configurations can be used directly. And you can change some values in the configuration file and train the model with a different config.
- `--device` is the device (gpu id) to use for training. `-1` means CPU.
- `--checkpoint` is the path of the checkpoint.
- `--iteration` is the iteration of the checkpoint to load from output directory.
- `output` is the directory to save results, all result are saved in this directory.
See [Saving-&-Loading](#Saving-&-Loading) for details of checkpoint loading.
Example script:
```bash
python train.py --config=./configs/wavenet_single_gaussian.yaml --data=./LJSpeech-1.1/ --output=experiment --device=0
python train.py \
--config=./configs/wavenet_single_gaussian.yaml \
--data=./LJSpeech-1.1/ \
--device=0 \
experiment
```
You can monitor training log via TensorBoard, using the script below.
@ -69,29 +95,50 @@ tensorboard --logdir=.
## Synthesis
```text
usage: synthesis.py [-h] [--data DATA] [--config CONFIG] [--device DEVICE]
checkpoint output
[--checkpoint CHECKPOINT | --iteration ITERATION]
output
Synthesize valid data from LJspeech with a WaveNet model.
Synthesize valid data from LJspeech with a wavenet model.
positional arguments:
checkpoint checkpoint to load.
output path to save results.
output path to save the synthesized audio
optional arguments:
-h, --help show this help message and exit
--data DATA path of the LJspeech dataset.
--config CONFIG path of the config file.
--device DEVICE device to use.
--data DATA path of the LJspeech dataset
--config CONFIG path of the config file
--device DEVICE device to use
--checkpoint CHECKPOINT checkpoint to resume from
--iteration ITERATION the iteration of the checkpoint to load from output directory
```
- `--data` is the path of the LJspeech dataset. In principle, a dataset is not needed for synthesis, but since the input is mel spectrogram, we need to get mel spectrogram from audio files.
- `--config` is the configuration file to use. You should use the same configuration with which you train you model.
- `--data` is the path of the LJspeech dataset. A dataset is not needed for synthesis, but since the input is mel spectrogram, we need to get mel spectrogram from audio files.
- `checkpoint` is the checkpoint to load.
- `output_path` is the directory to save results. The output path contains the generated audio files (`*.wav`).
- `--device` is the device (gpu id) to use for training. `-1` means CPU.
- `--checkpoint` is the checkpoint to load.
- `--iteration` is the iteration of the checkpoint to load from output directory.
- `output` is the directory to save synthesized audio. Audio file is saved in `synthesis/` in `output` directory.
See [Saving-&-Loading](#Saving-&-Loading) for details of checkpoint loading.
Example script:
```bash
python synthesis.py --config=./configs/wavenet_single_gaussian.yaml --data=./LJSpeech-1.1/ --device=0 experiment/checkpoints/step_500000 generated
python synthesis.py \
--config=./configs/wavenet_single_gaussian.yaml \
--data=./LJSpeech-1.1/ \
--device=0 \
--checkpoint="experiment/checkpoints/step-1000000" \
experiment
```
or
```bash
python synthesis.py \
--config=./configs/wavenet_single_gaussian.yaml \
--data=./LJSpeech-1.1/ \
--device=0 \
--iteration=1000000 \
experiment
```

View File

@ -25,22 +25,31 @@ from parakeet.modules.weight_norm import WeightNormWrapper
from parakeet.data import SliceDataset, TransformDataset, DataCargo, SequentialSampler, RandomSampler
from parakeet.models.wavenet import UpsampleNet, WaveNet, ConditionalWavenet
from parakeet.utils.layer_tools import summary
from parakeet.utils import io
from data import LJSpeechMetaData, Transform, DataCollector
from utils import make_output_tree, valid_model, eval_model, save_checkpoint
from utils import make_output_tree, valid_model, eval_model
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Synthesize valid data from LJspeech with a wavenet model.")
parser.add_argument(
"--data", type=str, help="path of the LJspeech dataset.")
parser.add_argument("--config", type=str, help="path of the config file.")
parser.add_argument(
"--device", type=int, default=-1, help="device to use.")
"--data", type=str, help="path of the LJspeech dataset")
parser.add_argument("--config", type=str, help="path of the config file")
parser.add_argument("--device", type=int, default=-1, help="device to use")
g = parser.add_mutually_exclusive_group()
g.add_argument("--checkpoint", type=str, help="checkpoint to resume from")
g.add_argument(
"--iteration",
type=int,
help="the iteration of the checkpoint to load from output directory")
parser.add_argument("checkpoint", type=str, help="checkpoint to load.")
parser.add_argument(
"output", type=str, default="experiment", help="path to save results.")
"output",
type=str,
default="experiment",
help="path to save the synthesized audio")
args = parser.parse_args()
with open(args.config, 'rt') as f:
@ -87,7 +96,8 @@ if __name__ == "__main__":
batch_size=1,
sampler=SequentialSampler(ljspeech_valid))
make_output_tree(args.output)
if not os.path.exists(args.output):
os.makedirs(args.output)
if args.device == -1:
place = fluid.CPUPlace()
@ -111,9 +121,15 @@ if __name__ == "__main__":
model = ConditionalWavenet(encoder, decoder)
summary(model)
model_dict, _ = dg.load_dygraph(args.checkpoint)
print("Loading from {}.pdparams".format(args.checkpoint))
model.set_dict(model_dict)
# load model parameters
checkpoint_dir = os.path.join(args.output, "checkpoints")
if args.checkpoint:
iteration = io.load_parameters(
model, checkpoint_path=args.checkpoint)
else:
iteration = io.load_parameters(
model, checkpoint_dir=checkpoint_dir, iteration=args.iteration)
assert iteration > 0, "A trained model is needed."
# WARNING: don't forget to remove weight norm to re-compute each wrapped layer's weight
# removing weight norm also speeds up computation
@ -129,4 +145,8 @@ if __name__ == "__main__":
capacity=10, return_list=True)
valid_loader.set_batch_generator(valid_cargo, place)
eval_model(model, valid_loader, args.output, sample_rate)
synthesis_dir = os.path.join(args.output, "synthesis")
if not os.path.exists(synthesis_dir):
os.makedirs(synthesis_dir)
eval_model(model, valid_loader, synthesis_dir, iteration, sample_rate)

View File

@ -16,7 +16,7 @@ from __future__ import division
import os
import ruamel.yaml
import argparse
from tqdm import tqdm
import tqdm
from tensorboardX import SummaryWriter
from paddle import fluid
import paddle.fluid.dygraph as dg
@ -24,30 +24,37 @@ import paddle.fluid.dygraph as dg
from parakeet.data import SliceDataset, TransformDataset, DataCargo, SequentialSampler, RandomSampler
from parakeet.models.wavenet import UpsampleNet, WaveNet, ConditionalWavenet
from parakeet.utils.layer_tools import summary
from parakeet.utils import io
from data import LJSpeechMetaData, Transform, DataCollector
from utils import make_output_tree, valid_model, save_checkpoint
from utils import make_output_tree, valid_model
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Train a wavenet model with LJSpeech.")
description="Train a WaveNet model with LJSpeech.")
parser.add_argument(
"--data", type=str, help="path of the LJspeech dataset.")
parser.add_argument("--config", type=str, help="path of the config file.")
"--data", type=str, help="path of the LJspeech dataset")
parser.add_argument("--config", type=str, help="path of the config file")
parser.add_argument("--device", type=int, default=-1, help="device to use")
g = parser.add_mutually_exclusive_group()
g.add_argument("--checkpoint", type=str, help="checkpoint to resume from")
g.add_argument(
"--iteration",
type=int,
help="the iteration of the checkpoint to load from output directory")
parser.add_argument(
"--output",
type=str,
default="experiment",
help="path to save results.")
parser.add_argument(
"--device", type=int, default=-1, help="device to use.")
parser.add_argument(
"--resume", type=str, help="checkpoint to resume from.")
"output", type=str, default="experiment", help="path to save results")
args = parser.parse_args()
with open(args.config, 'rt') as f:
config = ruamel.yaml.safe_load(f)
print("Command Line Args: ")
for k, v in vars(args).items():
print("{}: {}".format(k, v))
ljspeech_meta = LJSpeechMetaData(args.data)
data_config = config["data"]
@ -126,14 +133,6 @@ if __name__ == "__main__":
clipper = fluid.dygraph_grad_clip.GradClipByGlobalNorm(
gradiant_max_norm)
if args.resume:
model_dict, optim_dict = dg.load_dygraph(args.resume)
print("Loading from {}.pdparams".format(args.resume))
model.set_dict(model_dict)
if optim_dict:
optim.set_dict(optim_dict)
print("Loading from {}.pdopt".format(args.resume))
train_loader = fluid.io.DataLoader.from_generator(
capacity=10, return_list=True)
train_loader.set_batch_generator(train_cargo, place)
@ -150,10 +149,26 @@ if __name__ == "__main__":
log_dir = os.path.join(args.output, "log")
writer = SummaryWriter(log_dir)
global_step = 1
# load parameters and optimizer, and opdate iterations done sofar
if args.checkpoint is not None:
iteration = io.load_parameters(
model, optim, checkpoint_path=args.checkpoint)
else:
iteration = io.load_parameters(
model,
optim,
checkpoint_dir=checkpoint_dir,
iteration=args.iteration)
global_step = iteration + 1
iterator = iter(tqdm.tqdm(train_loader))
while global_step <= max_iterations:
epoch_loss = 0.
for i, batch in tqdm(enumerate(train_loader)):
try:
batch = next(iterator)
except StopIteration as e:
iterator = iter(tqdm.tqdm(train_loader))
batch = next(iterator)
audio_clips, mel_specs, audio_starts = batch
model.train()
@ -162,21 +177,20 @@ if __name__ == "__main__":
loss_var.backward()
loss_np = loss_var.numpy()
epoch_loss += loss_np[0]
writer.add_scalar("loss", loss_np[0], global_step)
writer.add_scalar("learning_rate",
optim._learning_rate.step().numpy()[0],
global_step)
optim.minimize(loss_var, grad_clip=clipper)
optim.clear_gradients()
print("loss: {:<8.6f}".format(loss_np[0]))
print("global_step: {}\tloss: {:<8.6f}".format(global_step,
loss_np[0]))
if global_step % snap_interval == 0:
valid_model(model, valid_loader, writer, global_step,
sample_rate)
if global_step % checkpoint_interval == 0:
save_checkpoint(model, optim, checkpoint_dir, global_step)
io.save_parameters(checkpoint_dir, global_step, model, optim)
global_step += 1

View File

@ -49,20 +49,14 @@ def valid_model(model, valid_loader, writer, global_step, sample_rate):
sample_rate)
def eval_model(model, valid_loader, output_dir, sample_rate):
def eval_model(model, valid_loader, output_dir, global_step, sample_rate):
model.eval()
for i, batch in enumerate(valid_loader):
# print("sentence {}".format(i))
path = os.path.join(output_dir, "sentence_{}.wav".format(i))
path = os.path.join(output_dir,
"sentence_{}_step_{}.wav".format(i, global_step))
audio_clips, mel_specs, audio_starts = batch
wav_var = model.synthesis(mel_specs)
wav_np = wav_var.numpy()[0]
sf.write(path, wav_np, samplerate=sample_rate)
print("generated {}".format(path))
def save_checkpoint(model, optim, checkpoint_dir, global_step):
checkpoint_path = os.path.join(checkpoint_dir,
"step_{:09d}".format(global_step))
dg.save_dygraph(model.state_dict(), checkpoint_path)
dg.save_dygraph(optim.state_dict(), checkpoint_path)

View File

@ -52,8 +52,6 @@ def _load_latest_checkpoint(checkpoint_dir):
Args:
checkpoint_dir (str): the directory where checkpoint is saved.
rank (int, optional): the rank of the process in multi-process setting.
Defaults to 0.
Returns:
int: the latest iteration number.
@ -115,7 +113,6 @@ def load_parameters(model,
if iteration is None:
iteration = _load_latest_checkpoint(checkpoint_dir)
if iteration == 0:
# if step-0 exist, it is also loaded
return iteration
checkpoint_path = os.path.join(checkpoint_dir,
"step-{}".format(iteration))