Merge branch 'master' into 'master'

Update save & load for Deep Voice 3, WaveNet and ClariNet; remove the concept of epoch in training.

See merge request !48

This commit is contained in:
commit a3683ac297
@@ -22,47 +22,71 @@ tar xjvf LJSpeech-1.1.tar.bz2
 └── utils.py         utility functions
 ```
 
+## Saving & Loading
+
+`train.py` and `synthesis.py` have 3 arguments in common: `--checkpoint`, `--iteration` and `output`.
+
+1. `output` is the directory for saving results.
+During training, checkpoints are saved in `checkpoints/` in `output` and the tensorboard log is saved in `log/` in `output`. Other possible outputs are saved in `states/` in `output`.
+During synthesis, audio files and other possible outputs are saved in `synthesis/` in `output`.
+So after training and synthesizing with the same output directory, the file structure of the output directory looks like this.
+
+```text
+├── checkpoints/      # checkpoint directory (including *.pdparams, *.pdopt and a text file `checkpoint` that records the latest checkpoint)
+├── states/           # audio files generated at validation and other possible outputs
+├── log/              # tensorboard log
+└── synthesis/        # synthesized audio files and other possible outputs
+```
+
+2. `--checkpoint` and `--iteration` are for loading from an existing checkpoint. Loading follows this rule:
+If `--checkpoint` is provided, the checkpoint specified by `--checkpoint` is loaded.
+If `--checkpoint` is not provided, we try to load the checkpoint specified by `--iteration` from the checkpoint directory. If `--iteration` is not provided either, we try to load the latest checkpoint from the checkpoint directory.
 
 ## Train
 
 Train the model using train.py; follow the usage displayed by `python train.py --help`.
 
 ```text
-usage: train.py [-h] [--config CONFIG] [--device DEVICE] [--output OUTPUT]
-                [--data DATA] [--resume RESUME] [--wavenet WAVENET]
+usage: train.py [-h] [--config CONFIG] [--device DEVICE] [--data DATA]
+                [--checkpoint CHECKPOINT | --iteration ITERATION]
+                [--wavenet WAVENET]
+                output
 
-train a ClariNet model with LJspeech and a trained WaveNet model.
+Train a ClariNet model with LJspeech and a trained WaveNet model.
+
+positional arguments:
+  output                path to save experiment results
 
 optional arguments:
   -h, --help            show this help message and exit
-  --config CONFIG       path of the config file.
-  --device DEVICE       device to use.
-  --output OUTPUT       path to save student.
-  --data DATA           path of LJspeech dataset.
-  --resume RESUME       checkpoint to load from.
-  --wavenet WAVENET     wavenet checkpoint to use.
+  --config CONFIG       path of the config file
+  --device DEVICE       device to use
+  --data DATA           path of LJspeech dataset
+  --checkpoint CHECKPOINT
+                        checkpoint to resume from
+  --iteration ITERATION
+                        the iteration of the checkpoint to load from output
+                        directory
+  --wavenet WAVENET     wavenet checkpoint to use
 ```
 
 - `--config` is the configuration file to use. The provided configurations can be used directly, and you can also change some values in the configuration file and train the model with a different config.
-- `--data` is the path of the LJSpeech dataset, the extracted folder from the downloaded archive (the folder which contains metadata.txt).
-- `--resume` is the path of the checkpoint. If it is provided, the model would load the checkpoint before training.
-- `--output` is the directory to save results, all results are saved in this directory. The structure of the output directory is shown below.
-
-```text
-├── checkpoints      # checkpoint
-├── states           # audio files generated at validation
-└── log              # tensorboard log
-```
-
 - `--device` is the device (gpu id) to use for training. `-1` means CPU.
-- `--wavenet` is the path of the wavenet checkpoint to load. If you do not specify `--resume`, then this must be provided.
+- `--data` is the path of the LJSpeech dataset, the extracted folder from the downloaded archive (the folder which contains `metadata.txt`).
+- `--checkpoint` is the path of the checkpoint.
+- `--iteration` is the iteration of the checkpoint to load from the output directory.
+- `output` is the directory to save results; all results are saved in this directory.
-Before you start training a ClariNet model, you should have trained a WaveNet model with single Gaussian output distribution. Make sure the config of the teacher model matches that of the trained model.
+See [Saving-&-Loading](#Saving-&-Loading) for details of checkpoint loading.
+- `--wavenet` is the path of the wavenet checkpoint to load.
+When you start training a ClariNet model without loading from a ClariNet checkpoint, you should have trained a WaveNet model with single Gaussian output distribution. Make sure the config of the teacher model matches that of the trained wavenet model.
 
 Example script:
 
 ```bash
-python train.py --config=./configs/clarinet_ljspeech.yaml --data=./LJSpeech-1.1/ --output=experiment --device=0 --conditioner=wavenet_checkpoint/conditioner --conditioner=wavenet_checkpoint/teacher
+python train.py \
+    --config=./configs/clarinet_ljspeech.yaml \
+    --data=./LJSpeech-1.1/ \
+    --device=0 \
+    --wavenet="wavenet-step-2000000" \
+    experiment
 ```
 
 You can monitor training log via tensorboard, using the script below.
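The loading rule described in the new "Saving & Loading" section can be summarized as a small resolution function. The sketch below is illustrative only — it is not the parakeet implementation, and the format of the `checkpoint` record file is an assumption:

```python
import os

def resolve_checkpoint(output, checkpoint=None, iteration=None):
    """Hypothetical sketch: decide which checkpoint to load."""
    if checkpoint is not None:
        return checkpoint  # rule 1: an explicit --checkpoint path wins
    checkpoint_dir = os.path.join(output, "checkpoints")
    if iteration is not None:
        # rule 2: look up the checkpoint for --iteration in output/checkpoints
        return os.path.join(checkpoint_dir, "step-{}".format(iteration))
    record = os.path.join(checkpoint_dir, "checkpoint")
    if os.path.exists(record):
        # rule 3: fall back to the latest checkpoint named in the record file
        # (assumed here to hold the checkpoint name on its first line)
        with open(record) as f:
            return os.path.join(checkpoint_dir, f.readline().strip())
    return None  # nothing to load; training starts afresh
```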
@@ -75,29 +99,50 @@ tensorboard --logdir=.
 ## Synthesis
 ```text
 usage: synthesis.py [-h] [--config CONFIG] [--device DEVICE] [--data DATA]
-                    checkpoint output
+                    [--checkpoint CHECKPOINT | --iteration ITERATION]
+                    output
 
-train a ClariNet model with LJspeech and a trained WaveNet model.
+Synthesize audio files from mel spectrogram in the validation set.
 
 positional arguments:
-  checkpoint            checkpoint to load from.
-  output                path to save student.
+  output                path to save the synthesized audio
 
 optional arguments:
   -h, --help            show this help message and exit
-  --config CONFIG       path of the config file.
+  --config CONFIG       path of the config file
   --device DEVICE       device to use.
-  --data DATA           path of LJspeech dataset.
+  --data DATA           path of LJspeech dataset
+  --checkpoint CHECKPOINT
+                        checkpoint to resume from
+  --iteration ITERATION
+                        the iteration of the checkpoint to load from output
+                        directory
 ```
 
 - `--config` is the configuration file to use. You should use the same configuration with which you trained your model.
-- `--data` is the path of the LJspeech dataset. A dataset is not needed for synthesis, but since the input is mel spectrogram, we need to get mel spectrogram from audio files.
-- `checkpoint` is the checkpoint to load.
-- `output_path` is the directory to save results. The output path contains the generated audio files (`*.wav`).
 - `--device` is the device (gpu id) to use for training. `-1` means CPU.
+- `--data` is the path of the LJspeech dataset. In principle, a dataset is not needed for synthesis, but since the input is mel spectrogram, we need to get mel spectrogram from audio files.
+- `--checkpoint` is the checkpoint to load.
+- `--iteration` is the iteration of the checkpoint to load from the output directory.
+- `output` is the directory to save synthesized audio. Audio files are saved in `synthesis/` in the `output` directory.
+See [Saving-&-Loading](#Saving-&-Loading) for details of checkpoint loading.
 
 Example script:
 
 ```bash
-python synthesis.py --config=./configs/wavenet_single_gaussian.yaml --data=./LJSpeech-1.1/ --device=0 experiment/checkpoints/step_500000 generated
+python synthesis.py \
+    --config=./configs/wavenet_single_gaussian.yaml \
+    --data=./LJSpeech-1.1/ \
+    --device=0 \
+    --iteration=500000 \
+    experiment
+```
+
+or
+
+```bash
+python synthesis.py \
+    --config=./configs/wavenet_single_gaussian.yaml \
+    --data=./LJSpeech-1.1/ \
+    --device=0 \
+    --checkpoint="experiment/checkpoints/step-500000" \
+    experiment
 ```
@@ -26,29 +26,41 @@ from tensorboardX import SummaryWriter
 import paddle.fluid.dygraph as dg
 from paddle import fluid
 
+from parakeet.modules.weight_norm import WeightNormWrapper
 from parakeet.models.wavenet import WaveNet, UpsampleNet
 from parakeet.models.clarinet import STFT, Clarinet, ParallelWaveNet
 from parakeet.data import TransformDataset, SliceDataset, RandomSampler, SequentialSampler, DataCargo
 from parakeet.utils.layer_tools import summary, freeze
+from parakeet.utils import io
 
-from utils import valid_model, eval_model, save_checkpoint, load_checkpoint, load_model
+from utils import eval_model
 sys.path.append("../wavenet")
 from data import LJSpeechMetaData, Transform, DataCollector
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(
-        description="synthesize audio files from mel spectrogram in the validation set."
+        description="Synthesize audio files from mel spectrogram in the validation set."
     )
-    parser.add_argument("--config", type=str, help="path of the config file.")
+    parser.add_argument("--config", type=str, help="path of the config file")
     parser.add_argument(
         "--device", type=int, default=-1, help="device to use.")
-    parser.add_argument("--data", type=str, help="path of LJspeech dataset.")
+    parser.add_argument("--data", type=str, help="path of LJspeech dataset")
+
+    g = parser.add_mutually_exclusive_group()
+    g.add_argument("--checkpoint", type=str, help="checkpoint to resume from")
+    g.add_argument(
+        "--iteration",
+        type=int,
+        help="the iteration of the checkpoint to load from output directory")
+
     parser.add_argument(
-        "checkpoint", type=str, help="checkpoint to load from.")
-    parser.add_argument(
-        "output", type=str, default="experiment", help="path to save student.")
+        "output",
+        type=str,
+        default="experiment",
+        help="path to save the synthesized audio")
 
     args = parser.parse_args()
 
     with open(args.config, 'rt') as f:
         config = ruamel.yaml.safe_load(f)
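The `--checkpoint | --iteration` pair above uses argparse's mutually exclusive groups, so passing both flags is rejected before any model code runs. A minimal, self-contained demonstration of the pattern (not code from this commit):

```python
import argparse

parser = argparse.ArgumentParser()
g = parser.add_mutually_exclusive_group()
g.add_argument("--checkpoint", type=str)
g.add_argument("--iteration", type=int)

# Either flag alone parses fine; both together abort with exit code 2.
print(parser.parse_args(["--iteration", "500000"]))
# -> Namespace(checkpoint=None, iteration=500000)
# parser.parse_args(["--checkpoint", "ckpt", "--iteration", "500000"])
# -> error: argument --iteration: not allowed with argument --checkpoint
```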
@@ -136,17 +148,32 @@ if __name__ == "__main__":
     model = Clarinet(upsample_net, teacher, student, stft,
                      student_log_scale_min, lmd)
     summary(model)
-    load_model(model, args.checkpoint)
 
-    # loader
-    train_loader = fluid.io.DataLoader.from_generator(
-        capacity=10, return_list=True)
-    train_loader.set_batch_generator(train_cargo, place)
+    # load parameters
+    if args.checkpoint is not None:
+        # load from args.checkpoint
+        iteration = io.load_parameters(
+            model, checkpoint_path=args.checkpoint)
+    else:
+        # load from "args.output/checkpoints"
+        checkpoint_dir = os.path.join(args.output, "checkpoints")
+        iteration = io.load_parameters(
+            model, checkpoint_dir=checkpoint_dir, iteration=args.iteration)
+    assert iteration > 0, "A trained checkpoint is needed."
+
+    # make generation fast
+    for sublayer in model.sublayers():
+        if isinstance(sublayer, WeightNormWrapper):
+            sublayer.remove_weight_norm()
+
+    # data loader
     valid_loader = fluid.io.DataLoader.from_generator(
         capacity=10, return_list=True)
     valid_loader.set_batch_generator(valid_cargo, place)
 
-    if not os.path.exists(args.output):
-        os.makedirs(args.output)
-    eval_model(model, valid_loader, args.output, sample_rate)
+    # the directory to save audio files
+    synthesis_dir = os.path.join(args.output, "synthesis")
+    if not os.path.exists(synthesis_dir):
+        os.makedirs(synthesis_dir)
+
+    eval_model(model, valid_loader, synthesis_dir, iteration, sample_rate)
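Removing `WeightNormWrapper` before synthesis matters because weight normalization parametrizes each weight as `g * v / ||v||` and recomputes it on every forward pass; folding the product into a plain weight once avoids that cost at each autoregressive step. A conceptual sketch (illustrative only, not the wrapper's actual code):

```python
import numpy as np

# Weight-norm parametrization: w = g * v / ||v|| (one norm per output row).
v = np.random.randn(64, 128)
g = np.random.randn(64, 1)

# What remove_weight_norm() conceptually does: bake the weight once...
w = g * v / np.linalg.norm(v, axis=1, keepdims=True)
# ...so later forward passes use `w` directly instead of recomputing it.
```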
@@ -30,31 +30,46 @@ from parakeet.models.wavenet import WaveNet, UpsampleNet
 from parakeet.models.clarinet import STFT, Clarinet, ParallelWaveNet
 from parakeet.data import TransformDataset, SliceDataset, RandomSampler, SequentialSampler, DataCargo
 from parakeet.utils.layer_tools import summary, freeze
+from parakeet.utils import io
 
-from utils import make_output_tree, valid_model, save_checkpoint, load_checkpoint, load_wavenet
+from utils import make_output_tree, eval_model, load_wavenet
+
+# import dataset from wavenet
 sys.path.append("../wavenet")
 from data import LJSpeechMetaData, Transform, DataCollector
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(
-        description="train a clarinet model with LJspeech and a trained wavenet model."
+        description="Train a ClariNet model with LJspeech and a trained WaveNet model."
     )
-    parser.add_argument("--config", type=str, help="path of the config file.")
-    parser.add_argument(
-        "--device", type=int, default=-1, help="device to use.")
+    parser.add_argument("--config", type=str, help="path of the config file")
+    parser.add_argument("--device", type=int, default=-1, help="device to use")
+    parser.add_argument("--data", type=str, help="path of LJspeech dataset")
+
+    g = parser.add_mutually_exclusive_group()
+    g.add_argument("--checkpoint", type=str, help="checkpoint to resume from")
+    g.add_argument(
+        "--iteration",
+        type=int,
+        help="the iteration of the checkpoint to load from output directory")
 
     parser.add_argument(
-        "--output",
+        "--wavenet", type=str, help="wavenet checkpoint to use")
+
+    parser.add_argument(
+        "output",
         type=str,
         default="experiment",
-        help="path to save student.")
-    parser.add_argument("--data", type=str, help="path of LJspeech dataset.")
-    parser.add_argument("--resume", type=str, help="checkpoint to load from.")
-    parser.add_argument(
-        "--wavenet", type=str, help="wavenet checkpoint to use.")
+        help="path to save experiment results")
 
     args = parser.parse_args()
     with open(args.config, 'rt') as f:
         config = ruamel.yaml.safe_load(f)
 
+    print("Command Line args: ")
+    for k, v in vars(args).items():
+        print("{}: {}".format(k, v))
+
     ljspeech_meta = LJSpeechMetaData(args.data)
 
     data_config = config["data"]
@@ -154,12 +169,28 @@ if __name__ == "__main__":
     clipper = fluid.dygraph_grad_clip.GradClipByGlobalNorm(
         gradiant_max_norm)
 
-    assert args.wavenet or args.resume, "you should load from a trained wavenet or resume training; training without a trained wavenet is not recommended."
-    if args.wavenet:
-        load_wavenet(model, args.wavenet)
+    # train
+    max_iterations = train_config["max_iterations"]
+    checkpoint_interval = train_config["checkpoint_interval"]
+    eval_interval = train_config["eval_interval"]
+    checkpoint_dir = os.path.join(args.output, "checkpoints")
+    state_dir = os.path.join(args.output, "states")
+    log_dir = os.path.join(args.output, "log")
+    writer = SummaryWriter(log_dir)
 
-    if args.resume:
-        load_checkpoint(model, optim, args.resume)
+    if args.checkpoint is not None:
+        iteration = io.load_parameters(
+            model, optim, checkpoint_path=args.checkpoint)
+    else:
+        iteration = io.load_parameters(
+            model,
+            optim,
+            checkpoint_dir=checkpoint_dir,
+            iteration=args.iteration)
+
+    if iteration == 0:
+        assert args.wavenet is not None, "When training afresh, a trained wavenet model should be provided."
+        load_wavenet(model, args.wavenet)
 
     # loader
     train_loader = fluid.io.DataLoader.from_generator(
@@ -170,21 +201,16 @@ if __name__ == "__main__":
         capacity=10, return_list=True)
     valid_loader.set_batch_generator(valid_cargo, place)
 
-    # train
-    max_iterations = train_config["max_iterations"]
-    checkpoint_interval = train_config["checkpoint_interval"]
-    eval_interval = train_config["eval_interval"]
-    checkpoint_dir = os.path.join(args.output, "checkpoints")
-    state_dir = os.path.join(args.output, "states")
-    log_dir = os.path.join(args.output, "log")
-    writer = SummaryWriter(log_dir)
-
     # training loop
-    global_step = 1
-    global_epoch = 1
-    while global_step < max_iterations:
-        epoch_loss = 0.
-        for j, batch in tqdm(enumerate(train_loader), desc="[train]"):
+    global_step = iteration + 1
+    iterator = iter(tqdm(train_loader))
+    while global_step <= max_iterations:
+        try:
+            batch = next(iterator)
+        except StopIteration as e:
+            iterator = iter(tqdm(train_loader))
+            batch = next(iterator)
+
         audios, mels, audio_starts = batch
         model.train()
         loss_dict = model(
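The rewritten loop above replaces epoch bookkeeping with a data loader that is simply re-wound whenever it runs out, so progress is measured in iterations alone. The same pattern, extracted into a standalone generator (a sketch, not code from this commit):

```python
def endless_batches(loader):
    """Yield batches forever, restarting the loader when it is exhausted."""
    iterator = iter(loader)
    while True:
        try:
            yield next(iterator)
        except StopIteration:
            iterator = iter(loader)  # rewind and continue seamlessly
            yield next(iterator)

# Usage sketch: the training loop then only tracks global_step.
# for global_step, batch in enumerate(endless_batches(train_loader), start=1):
#     if global_step > max_iterations:
#         break
#     ...
```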
@@ -199,8 +225,8 @@ if __name__ == "__main__":
 
         l = loss_dict["loss"]
         step_loss = l.numpy()[0]
-        print("[train] loss: {:<8.6f}".format(step_loss))
-        epoch_loss += step_loss
+        print("[train] global_step: {} loss: {:<8.6f}".format(
+            global_step, step_loss))
 
         l.backward()
         optim.minimize(l, grad_clip=clipper)
@@ -208,14 +234,9 @@ if __name__ == "__main__":
 
         if global_step % eval_interval == 0:
             # evaluate on valid dataset
-            valid_model(model, valid_loader, state_dir, global_step,
+            eval_model(model, valid_loader, state_dir, global_step,
                        sample_rate)
         if global_step % checkpoint_interval == 0:
-            save_checkpoint(model, optim, checkpoint_dir, global_step)
+            io.save_parameters(checkpoint_dir, global_step, model, optim)
 
         global_step += 1
-
-        # epoch loss
-        average_loss = epoch_loss / j
-        writer.add_scalar("average_loss", average_loss, global_epoch)
-        global_epoch += 1
@@ -32,12 +32,12 @@ def make_output_tree(output_dir):
         os.makedirs(state_dir)
 
 
-def valid_model(model, valid_loader, output_dir, global_step, sample_rate):
+def eval_model(model, valid_loader, output_dir, iteration, sample_rate):
     model.eval()
     for i, batch in enumerate(valid_loader):
         # print("sentence {}".format(i))
         path = os.path.join(output_dir,
-                            "step_{}_sentence_{}.wav".format(global_step, i))
+                            "sentence_{}_step_{}.wav".format(i, iteration))
         audio_clips, mel_specs, audio_starts = batch
         wav_var = model.synthesis(mel_specs)
         wav_np = wav_var.numpy()[0]
@@ -45,42 +45,6 @@ def valid_model(model, valid_loader, output_dir, global_step, sample_rate):
         print("generated {}".format(path))
 
 
-def eval_model(model, valid_loader, output_dir, sample_rate):
-    model.eval()
-    for i, batch in enumerate(valid_loader):
-        # print("sentence {}".format(i))
-        path = os.path.join(output_dir, "sentence_{}.wav".format(i))
-        audio_clips, mel_specs, audio_starts = batch
-        wav_var = model.synthesis(mel_specs)
-        wav_np = wav_var.numpy()[0]
-        sf.write(path, wav_np, samplerate=sample_rate)
-        print("generated {}".format(path))
-
-
-def save_checkpoint(model, optim, checkpoint_dir, global_step):
-    path = os.path.join(checkpoint_dir, "step_{}".format(global_step))
-    dg.save_dygraph(model.state_dict(), path)
-    print("saving model to {}".format(path + ".pdparams"))
-    if optim:
-        dg.save_dygraph(optim.state_dict(), path)
-        print("saving optimizer to {}".format(path + ".pdopt"))
-
-
-def load_model(model, path):
-    model_dict, _ = dg.load_dygraph(path)
-    model.set_dict(model_dict)
-    print("loaded model from {}.pdparams".format(path))
-
-
-def load_checkpoint(model, optim, path):
-    model_dict, optim_dict = dg.load_dygraph(path)
-    model.set_dict(model_dict)
-    print("loaded model from {}.pdparams".format(path))
-    if optim_dict:
-        optim.set_dict(optim_dict)
-        print("loaded optimizer from {}.pdparams".format(path))
-
-
 def load_wavenet(model, path):
     wavenet_dict, _ = dg.load_dygraph(path)
     encoder_dict = OrderedDict()
@@ -30,32 +30,55 @@ The model consists of an encoder, a decoder and a converter (and a speaker embedding for multispeaker models).
 └── utils.py         utility functions
 ```
 
+## Saving & Loading
+
+`train.py` and `synthesis.py` have 3 arguments in common: `--checkpoint`, `--iteration` and `output`.
+
+1. `output` is the directory for saving results.
+During training, checkpoints are saved in `checkpoints/` in `output` and the tensorboard log is saved in `log/` in `output`. Other possible outputs are saved in `states/` in `output`.
+During synthesis, audio files and other possible outputs are saved in `synthesis/` in `output`.
+So after training and synthesizing with the same output directory, the file structure of the output directory looks like this.
+
+```text
+├── checkpoints/      # checkpoint directory (including *.pdparams, *.pdopt and a text file `checkpoint` that records the latest checkpoint)
+├── states/           # audio files generated at validation and other possible outputs
+├── log/              # tensorboard log
+└── synthesis/        # synthesized audio files and other possible outputs
+```
+
+2. `--checkpoint` and `--iteration` are for loading from an existing checkpoint. Loading follows this rule:
+If `--checkpoint` is provided, the checkpoint specified by `--checkpoint` is loaded.
+If `--checkpoint` is not provided, we try to load the checkpoint specified by `--iteration` from the checkpoint directory. If `--iteration` is not provided either, we try to load the latest checkpoint from the checkpoint directory.
 
 ## Train
 
 Train the model using train.py; follow the usage displayed by `python train.py --help`.
 
 ```text
-usage: train.py [-h] [-c CONFIG] [-s DATA] [-r RESUME] [-o OUTPUT] [-g DEVICE]
+usage: train.py [-h] [--config CONFIG] [--data DATA] [--device DEVICE]
+                [--checkpoint CHECKPOINT | --iteration ITERATION]
+                output
 
 Train a Deep Voice 3 model with LJSpeech dataset.
 
+positional arguments:
+  output                path to save results
+
 optional arguments:
   -h, --help            show this help message and exit
-  -c CONFIG, --config CONFIG
-                        experimrnt config
-  -s DATA, --data DATA  The path of the LJSpeech dataset.
-  -r RESUME, --resume RESUME
-                        checkpoint to load
-  -o OUTPUT, --output OUTPUT
-                        The directory to save result.
-  -g DEVICE, --device DEVICE
-                        device to use
+  --config CONFIG       experiment config
+  --data DATA           The path of the LJSpeech dataset.
+  --device DEVICE       device to use
+  --checkpoint CHECKPOINT
+                        checkpoint to resume from
+  --iteration ITERATION
+                        the iteration of the checkpoint to load from output
+                        directory
 ```
 
 - `--config` is the configuration file to use. The provided `ljspeech.yaml` can be used directly, and you can also change some values in the configuration file and train the model with a different config.
 - `--data` is the path of the LJSpeech dataset, the extracted folder from the downloaded archive (the folder which contains `metadata.txt`).
-- `--resume` is the path of the checkpoint. If it is provided, the model would load the checkpoint before training.
-- `--output` is the directory to save results, all results are saved in this directory. The structure of the output directory is shown below.
+- `--device` is the device (gpu id) to use for training. `-1` means CPU.
+- `--checkpoint` is the path of the checkpoint.
+- `--iteration` is the iteration of the checkpoint to load from the output directory.
+See [Saving-&-Loading](#Saving-&-Loading) for details of checkpoint loading.
+- `output` is the directory to save results; all results are saved in this directory. The structure of the output directory is shown below.
 
 ```text
 ├── checkpoints      # checkpoint
@@ -67,12 +90,14 @@ optional arguments:
 └── waveform         # waveform (.wav files)
 ```
 
-- `--device` is the device (gpu id) to use for training. `-1` means CPU.
-
 Example script:
 
 ```bash
-python train.py --config=configs/ljspeech.yaml --data=./LJSpeech-1.1/ --output=experiment --device=0
+python train.py \
+    --config=configs/ljspeech.yaml \
+    --data=./LJSpeech-1.1/ \
+    --device=0 \
+    experiment
 ```
 
 You can monitor training log via tensorboard, using the script below.
@@ -84,31 +109,50 @@ tensorboard --logdir=.
 
 ## Synthesis
 ```text
-usage: synthesis.py [-h] [-c CONFIG] [-g DEVICE] checkpoint text output_path
+usage: synthesis.py [-h] [--config CONFIG] [--device DEVICE]
+                    [--checkpoint CHECKPOINT | --iteration ITERATION]
+                    text output
 
-Synthsize waveform from a checkpoint.
+Synthesize waveform with a checkpoint.
 
 positional arguments:
-  checkpoint            checkpoint to load.
   text                  text file to synthesize
-  output_path           path to save results
+  output                path to save synthesized audio
 
 optional arguments:
   -h, --help            show this help message and exit
-  -c CONFIG, --config CONFIG
-                        experiment config.
-  -g DEVICE, --device DEVICE
-                        device to use
+  --config CONFIG       experiment config
+  --device DEVICE       device to use
+  --checkpoint CHECKPOINT
+                        checkpoint to resume from
+  --iteration ITERATION
+                        the iteration of the checkpoint to load from output
+                        directory
 ```
 
 - `--config` is the configuration file to use. You should use the same configuration with which you trained your model.
-- `checkpoint` is the checkpoint to load.
-- `text`is the text file to synthesize.
-- `output_path` is the directory to save results. The output path contains the generated audio files (`*.wav`) and attention plots (*.png) for each sentence.
 - `--device` is the device (gpu id) to use for training. `-1` means CPU.
+- `--checkpoint` is the path of the checkpoint.
+- `--iteration` is the iteration of the checkpoint to load from the output directory.
+See [Saving-&-Loading](#Saving-&-Loading) for details of checkpoint loading.
+- `text` is the text file to synthesize.
+- `output` is the directory to save results. The generated audio files (`*.wav`) and attention plots (`*.png`) are saved in `synthesis/` in the output directory.
 
 Example script:
 
 ```bash
-python synthesis.py --config=configs/ljspeech.yaml --device=0 experiment/checkpoints/model_step_005000000 sentences.txt generated
+python synthesis.py \
+    --config=configs/ljspeech.yaml \
+    --device=0 \
+    --checkpoint="experiment/checkpoints/model_step_005000000" \
+    sentences.txt experiment
+```
+
+or
+
+```bash
+python synthesis.py \
+    --config=configs/ljspeech.yaml \
+    --device=0 \
+    --iteration=005000000 \
+    sentences.txt experiment
 ```
@@ -83,7 +83,7 @@ lr_scheduler:
 
 train:
   batch_size: 16
-  epochs: 2000
+  max_iteration: 2000000
 
   snap_interval: 1000
   eval_interval: 10000
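The budget change from `epochs: 2000` to `max_iteration: 2000000` keeps the amount of training roughly comparable. Back-of-the-envelope arithmetic (illustrative only; the exact step count per epoch depends on the actual train/valid split):

```python
clips = 13100          # utterances in LJSpeech-1.1
batch_size = 16        # from the config above
steps_per_epoch = clips // batch_size   # 818 steps
print(steps_per_epoch * 2000)           # 1,636,000 steps for 2000 epochs,
                                        # the same order as max_iteration
```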
@@ -25,25 +25,37 @@ import paddle.fluid.dygraph as dg
 from tensorboardX import SummaryWriter
 
 from parakeet.g2p import en
-from parakeet.utils.layer_tools import summary
 from parakeet.modules.weight_norm import WeightNormWrapper
+from parakeet.utils.layer_tools import summary
+from parakeet.utils import io
 
 from utils import make_model, eval_model, plot_alignment
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(
         description="Synthesize waveform with a checkpoint.")
-    parser.add_argument("-c", "--config", type=str, help="experiment config.")
-    parser.add_argument("checkpoint", type=str, help="checkpoint to load.")
+    parser.add_argument("--config", type=str, help="experiment config")
+    parser.add_argument("--device", type=int, default=-1, help="device to use")
+
+    g = parser.add_mutually_exclusive_group()
+    g.add_argument("--checkpoint", type=str, help="checkpoint to resume from")
+    g.add_argument(
+        "--iteration",
+        type=int,
+        help="the iteration of the checkpoint to load from output directory")
+
     parser.add_argument("text", type=str, help="text file to synthesize")
-    parser.add_argument("output_path", type=str, help="path to save results")
     parser.add_argument(
-        "-g", "--device", type=int, default=-1, help="device to use")
+        "output", type=str, help="path to save synthesized audio")
 
     args = parser.parse_args()
     with open(args.config, 'rt') as f:
         config = ruamel.yaml.safe_load(f)
 
+    print("Command Line Args: ")
+    for k, v in vars(args).items():
+        print("{}: {}".format(k, v))
+
     if args.device == -1:
         place = fluid.CPUPlace()
     else:
@@ -98,8 +110,14 @@ if __name__ == "__main__":
         linear_dim, use_decoder_states, converter_channels, dropout)
 
     summary(dv3)
-    state, _ = dg.load_dygraph(args.checkpoint)
-    dv3.set_dict(state)
+
+    checkpoint_dir = os.path.join(args.output, "checkpoints")
+    if args.checkpoint is not None:
+        iteration = io.load_parameters(
+            dv3, checkpoint_path=args.checkpoint)
+    else:
+        iteration = io.load_parameters(
+            dv3, checkpoint_dir=checkpoint_dir, iteration=args.iteration)
 
     # WARNING: don't forget to remove weight norm to re-compute each wrapped layer's weight
     # removing weight norm also speeds up computation
@@ -107,9 +125,6 @@ if __name__ == "__main__":
         if isinstance(layer, WeightNormWrapper):
             layer.remove_weight_norm()
 
-    if not os.path.exists(args.output_path):
-        os.makedirs(args.output_path)
-
     transform_config = config["transform"]
     c = transform_config["replace_pronunciation_prob"]
     sample_rate = transform_config["sample_rate"]
|
||||||
power = synthesis_config["power"]
|
power = synthesis_config["power"]
|
||||||
n_iter = synthesis_config["n_iter"]
|
n_iter = synthesis_config["n_iter"]
|
||||||
|
|
||||||
|
synthesis_dir = os.path.join(args.output, "synthesis")
|
||||||
|
if not os.path.exists(synthesis_dir):
|
||||||
|
os.makedirs(synthesis_dir)
|
||||||
|
|
||||||
with open(args.text, "rt", encoding="utf-8") as f:
|
with open(args.text, "rt", encoding="utf-8") as f:
|
||||||
lines = f.readlines()
|
lines = f.readlines()
|
||||||
for idx, line in enumerate(lines):
|
for idx, line in enumerate(lines):
|
||||||
|
@ -134,7 +153,9 @@ if __name__ == "__main__":
|
||||||
preemphasis)
|
preemphasis)
|
||||||
plot_alignment(
|
plot_alignment(
|
||||||
attn,
|
attn,
|
||||||
os.path.join(args.output_path, "test_{}.png".format(idx)))
|
os.path.join(synthesis_dir,
|
||||||
|
"test_{}_step_{}.png".format(idx, iteration)))
|
||||||
sf.write(
|
sf.write(
|
||||||
os.path.join(args.output_path, "test_{}.wav".format(idx)),
|
os.path.join(synthesis_dir,
|
||||||
|
"test_{}_step{}.wav".format(idx, iteration)),
|
||||||
wav, sample_rate)
|
wav, sample_rate)
|
||||||
|
|
|
@@ -17,6 +17,8 @@ import os
 import argparse
 import ruamel.yaml
 import numpy as np
+import matplotlib
+matplotlib.use("agg")
 from matplotlib import cm
 import matplotlib.pyplot as plt
 import tqdm
|
||||||
from parakeet.models.deepvoice3 import Encoder, Decoder, Converter, DeepVoice3, ConvSpec
|
from parakeet.models.deepvoice3 import Encoder, Decoder, Converter, DeepVoice3, ConvSpec
|
||||||
from parakeet.models.deepvoice3.loss import TTSLoss
|
from parakeet.models.deepvoice3.loss import TTSLoss
|
||||||
from parakeet.utils.layer_tools import summary
|
from parakeet.utils.layer_tools import summary
|
||||||
|
from parakeet.utils import io
|
||||||
|
|
||||||
from data import LJSpeechMetaData, DataCollector, Transform
|
from data import LJSpeechMetaData, DataCollector, Transform
|
||||||
from utils import make_model, eval_model, save_state, make_output_tree, plot_alignment
|
from utils import make_model, eval_model, save_state, make_output_tree, plot_alignment
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
parser = argparse.ArgumentParser(
|
parser = argparse.ArgumentParser(
|
||||||
description="Train a deepvoice 3 model with LJSpeech dataset.")
|
description="Train a Deep Voice 3 model with LJSpeech dataset.")
|
||||||
parser.add_argument("-c", "--config", type=str, help="experimrnt config")
|
parser.add_argument("--config", type=str, help="experimrnt config")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"-s",
|
|
||||||
"--data",
|
"--data",
|
||||||
type=str,
|
type=str,
|
||||||
default="/workspace/datasets/LJSpeech-1.1/",
|
default="/workspace/datasets/LJSpeech-1.1/",
|
||||||
help="The path of the LJSpeech dataset.")
|
help="The path of the LJSpeech dataset.")
|
||||||
parser.add_argument("-r", "--resume", type=str, help="checkpoint to load")
|
parser.add_argument("--device", type=int, default=-1, help="device to use")
|
||||||
|
|
||||||
|
g = parser.add_mutually_exclusive_group()
|
||||||
|
g.add_argument("--checkpoint", type=str, help="checkpoint to resume from.")
|
||||||
|
g.add_argument(
|
||||||
|
"--iteration",
|
||||||
|
type=int,
|
||||||
|
help="the iteration of the checkpoint to load from output directory")
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"-o",
|
"output", type=str, default="experiment", help="path to save results")
|
||||||
"--output",
|
|
||||||
type=str,
|
|
||||||
default="result",
|
|
||||||
help="The directory to save result.")
|
|
||||||
parser.add_argument(
|
|
||||||
"-g", "--device", type=int, default=-1, help="device to use")
|
|
||||||
args, _ = parser.parse_known_args()
|
args, _ = parser.parse_known_args()
|
||||||
with open(args.config, 'rt') as f:
|
with open(args.config, 'rt') as f:
|
||||||
config = ruamel.yaml.safe_load(f)
|
config = ruamel.yaml.safe_load(f)
|
||||||
|
|
||||||
|
print("Command Line Args: ")
|
||||||
|
for k, v in vars(args).items():
|
||||||
|
print("{}: {}".format(k, v))
|
||||||
|
|
||||||
# =========================dataset=========================
|
# =========================dataset=========================
|
||||||
# construct meta data
|
# construct meta data
|
||||||
data_root = args.data
|
data_root = args.data
|
||||||
|
@ -151,6 +160,7 @@ if __name__ == "__main__":
|
||||||
query_position_rate, key_position_rate, window_backward,
|
query_position_rate, key_position_rate, window_backward,
|
||||||
window_ahead, key_projection, value_projection, downsample_factor,
|
window_ahead, key_projection, value_projection, downsample_factor,
|
||||||
linear_dim, use_decoder_states, converter_channels, dropout)
|
linear_dim, use_decoder_states, converter_channels, dropout)
|
||||||
|
summary(dv3)
|
||||||
|
|
||||||
# =========================loss=========================
|
# =========================loss=========================
|
||||||
loss_config = config["loss"]
|
loss_config = config["loss"]
|
||||||
|
@ -195,7 +205,6 @@ if __name__ == "__main__":
|
||||||
n_iter = synthesis_config["n_iter"]
|
n_iter = synthesis_config["n_iter"]
|
||||||
|
|
||||||
# =========================link(dataloader, paddle)=========================
|
# =========================link(dataloader, paddle)=========================
|
||||||
# CAUTION: it does not return a DataLoader
|
|
||||||
loader = fluid.io.DataLoader.from_generator(
|
loader = fluid.io.DataLoader.from_generator(
|
||||||
capacity=10, return_list=True)
|
capacity=10, return_list=True)
|
||||||
loader.set_batch_generator(ljspeech_loader, places=place)
|
loader.set_batch_generator(ljspeech_loader, places=place)
|
||||||
|
@ -208,24 +217,30 @@ if __name__ == "__main__":
|
||||||
make_output_tree(output_dir)
|
make_output_tree(output_dir)
|
||||||
writer = SummaryWriter(logdir=log_dir)
|
writer = SummaryWriter(logdir=log_dir)
|
||||||
|
|
||||||
# load model parameters
|
# load parameters and optimizer, and opdate iterations done sofar
|
||||||
resume_path = args.resume
|
if args.checkpoint is not None:
|
||||||
if resume_path is not None:
|
iteration = io.load_parameters(
|
||||||
state, _ = dg.load_dygraph(args.resume)
|
dv3, optim, checkpoint_path=args.checkpoint)
|
||||||
dv3.set_dict(state)
|
else:
|
||||||
|
iteration = io.load_parameters(
|
||||||
|
dv3, optim, checkpoint_dir=ckpt_dir, iteration=args.iteration)
|
||||||
|
|
||||||
# =========================train=========================
|
# =========================train=========================
|
||||||
epoch = train_config["epochs"]
|
max_iter = train_config["max_iteration"]
|
||||||
snap_interval = train_config["snap_interval"]
|
snap_interval = train_config["snap_interval"]
|
||||||
save_interval = train_config["save_interval"]
|
save_interval = train_config["save_interval"]
|
||||||
eval_interval = train_config["eval_interval"]
|
eval_interval = train_config["eval_interval"]
|
||||||
|
|
||||||
global_step = 1
|
global_step = iteration + 1
|
||||||
|
iterator = iter(tqdm.tqdm(loader))
|
||||||
|
while global_step <= max_iter:
|
||||||
|
try:
|
||||||
|
batch = next(iterator)
|
||||||
|
except StopIteration as e:
|
||||||
|
iterator = iter(tqdm.tqdm(loader))
|
||||||
|
batch = next(iterator)
|
||||||
|
|
||||||
for j in range(1, 1 + epoch):
|
dv3.train()
|
||||||
epoch_loss = 0.
|
|
||||||
for i, batch in tqdm.tqdm(enumerate(loader, 1)):
|
|
||||||
dv3.train() # CAUTION: don't forget to switch to train
|
|
||||||
(text_sequences, text_lengths, text_positions, mel_specs,
|
(text_sequences, text_lengths, text_positions, mel_specs,
|
||||||
lin_specs, frames, decoder_positions, done_flags) = batch
|
lin_specs, frames, decoder_positions, done_flags) = batch
|
||||||
downsampled_mel_specs = F.strided_slice(
|
downsampled_mel_specs = F.strided_slice(
|
||||||
|
@ -238,26 +253,25 @@ if __name__ == "__main__":
|
||||||
text_sequences, text_positions, text_lengths, None,
|
text_sequences, text_positions, text_lengths, None,
|
||||||
downsampled_mel_specs, decoder_positions)
|
downsampled_mel_specs, decoder_positions)
|
||||||
|
|
||||||
losses = criterion(mel_outputs, linear_outputs, done,
|
losses = criterion(mel_outputs, linear_outputs, done, alignments,
|
||||||
alignments, downsampled_mel_specs,
|
downsampled_mel_specs, lin_specs, done_flags,
|
||||||
lin_specs, done_flags, text_lengths, frames)
|
text_lengths, frames)
|
||||||
l = losses["loss"]
|
l = losses["loss"]
|
||||||
l.backward()
|
l.backward()
|
||||||
# record learning rate before updating
|
# record learning rate before updating
|
||||||
writer.add_scalar("learning_rate",
|
writer.add_scalar("learning_rate",
|
||||||
optim._learning_rate.step().numpy(),
|
optim._learning_rate.step().numpy(), global_step)
|
||||||
global_step)
|
|
||||||
optim.minimize(l, grad_clip=gradient_clipper)
|
optim.minimize(l, grad_clip=gradient_clipper)
|
||||||
optim.clear_gradients()
|
optim.clear_gradients()
|
||||||
|
|
||||||
# ==================all kinds of tedious things=================
|
# ==================all kinds of tedious things=================
|
||||||
# record step loss into tensorboard
|
# record step loss into tensorboard
|
||||||
epoch_loss += l.numpy()[0]
|
|
||||||
step_loss = {k: v.numpy()[0] for k, v in losses.items()}
|
step_loss = {k: v.numpy()[0] for k, v in losses.items()}
|
||||||
|
tqdm.tqdm.write("global_step: {}\tloss: {}".format(
|
||||||
|
global_step, step_loss["loss"]))
|
||||||
for k, v in step_loss.items():
|
for k, v in step_loss.items():
|
||||||
writer.add_scalar(k, v, global_step)
|
writer.add_scalar(k, v, global_step)
|
||||||
|
|
||||||
# TODO: clean code
|
|
||||||
# train state saving, the first sentence in the batch
|
# train state saving, the first sentence in the batch
|
||||||
if global_step % snap_interval == 0:
|
if global_step % snap_interval == 0:
|
||||||
save_state(
|
save_state(
|
||||||
|
@ -290,9 +304,9 @@ if __name__ == "__main__":
|
||||||
]
|
]
|
||||||
for idx, sent in enumerate(sentences):
|
for idx, sent in enumerate(sentences):
|
||||||
wav, attn = eval_model(
|
wav, attn = eval_model(
|
||||||
dv3, sent, replace_pronounciation_prob,
|
dv3, sent, replace_pronounciation_prob, min_level_db,
|
||||||
min_level_db, ref_level_db, power, n_iter,
|
ref_level_db, power, n_iter, win_length, hop_length,
|
||||||
win_length, hop_length, preemphasis)
|
preemphasis)
|
||||||
wav_path = os.path.join(
|
wav_path = os.path.join(
|
||||||
state_dir, "waveform",
|
state_dir, "waveform",
|
||||||
"eval_sample_{:09d}.wav".format(global_step))
|
"eval_sample_{:09d}.wav".format(global_step))
|
||||||
|
@ -314,16 +328,6 @@ if __name__ == "__main__":
|
||||||
|
|
||||||
# save checkpoint
|
# save checkpoint
|
||||||
if global_step % save_interval == 0:
|
if global_step % save_interval == 0:
|
||||||
dg.save_dygraph(
|
io.save_parameters(ckpt_dir, global_step, dv3, optim)
|
||||||
dv3.state_dict(),
|
|
||||||
os.path.join(ckpt_dir,
|
|
||||||
"model_step_{}".format(global_step)))
|
|
||||||
dg.save_dygraph(
|
|
||||||
optim.state_dict(),
|
|
||||||
os.path.join(ckpt_dir,
|
|
||||||
"model_step_{}".format(global_step)))
|
|
||||||
|
|
||||||
global_step += 1
|
global_step += 1
|
||||||
# epoch report
|
|
||||||
writer.add_scalar("epoch_average_loss", epoch_loss / i, j)
|
|
||||||
epoch_loss = 0.
|
|
||||||
|
|
|
@ -22,41 +22,67 @@ tar xjvf LJSpeech-1.1.tar.bz2
|
||||||
└── utils.py utility functions
|
└── utils.py utility functions
|
||||||
```
|
```
|
||||||
|
|
||||||
|
## Saving & Loading
|
||||||
|
`train.py` and `synthesis.py` have 3 arguments in common, `--checkpooint`, `iteration` and `output`.
|
||||||
|
|
||||||
|
1. `output` is the directory for saving results.
|
||||||
|
During training, checkpoints are saved in `checkpoints/` in `output` and tensorboard log is save in `log/` in `output`. Other possible outputs are saved in `states/` in `outuput`.
|
||||||
|
During synthesizing, audio files and other possible outputs are save in `synthesis/` in `output`.
|
||||||
|
So after training and synthesizing with the same output directory, the file structure of the output directory looks like this.
|
||||||
|
|
||||||
|
```text
|
||||||
|
├── checkpoints/ # checkpoint directory (including *.pdparams, *.pdopt and a text file `checkpoint` that records the latest checkpoint)
|
||||||
|
├── states/ # audio files generated at validation and other possible outputs
|
||||||
|
├── log/ # tensorboard log
|
||||||
|
└── synthesis/ # synthesized audio files and other possible outputs
|
||||||
|
```
|
||||||
|
|
||||||
|
2. `--checkpoint` and `--iteration` for loading from existing checkpoint. Loading existing checkpoiont follows the following rule:
|
||||||
|
If `--checkpoint` is provided, the checkpoint specified by `--checkpoint` is loaded.
|
||||||
|
If `--checkpoint` is not provided, we try to load the model specified by `--iteration` from the checkpoint directory. If `--iteration` is not provided, we try to load the latested checkpoint from checkpoint directory.
|
||||||
|
|
||||||
## Train
|
## Train
|
||||||
|
|
||||||
Train the model using train.py. For help on usage, try `python train.py --help`.
|
Train the model using train.py. For help on usage, try `python train.py --help`.
|
||||||
|
|
||||||
```text
|
```text
|
||||||
usage: train.py [-h] [--data DATA] [--config CONFIG] [--output OUTPUT]
|
usage: train.py [-h] [--data DATA] [--config CONFIG] [--device DEVICE]
|
||||||
[--device DEVICE] [--resume RESUME]
|
[--checkpoint CHECKPOINT | --iteration ITERATION]
|
||||||
|
output
|
||||||
|
|
||||||
Train a WaveNet model with LJSpeech.
|
Train a WaveNet model with LJSpeech.
|
||||||
|
|
||||||
|
positional arguments:
|
||||||
|
output path to save results
|
||||||
|
|
||||||
optional arguments:
|
optional arguments:
|
||||||
-h, --help show this help message and exit
|
-h, --help show this help message and exit
|
||||||
--data DATA path of the LJspeech dataset.
|
--data DATA path of the LJspeech dataset
|
||||||
--config CONFIG path of the config file.
|
--config CONFIG path of the config file
|
||||||
--output OUTPUT path to save results.
|
--device DEVICE device to use
|
||||||
--device DEVICE device to use.
|
--checkpoint CHECKPOINT checkpoint to resume from
|
||||||
--resume RESUME checkpoint to resume from.
|
--iteration ITERATION the iteration of the checkpoint to load from output directory
|
||||||
```
|
```
|
||||||
|
|
||||||
- `--config` is the configuration file to use. The provided configurations can be used directly. And you can change some values in the configuration file and train the model with a different config.
|
|
||||||
- `--data` is the path of the LJSpeech dataset, the extracted folder from the downloaded archive (the folder which contains metadata.txt).
|
- `--data` is the path of the LJSpeech dataset, the extracted folder from the downloaded archive (the folder which contains metadata.txt).
|
||||||
- `--resume` is the path of the checkpoint. If it is provided, the model would load the checkpoint before training.
|
- `--config` is the configuration file to use. The provided configurations can be used directly. And you can change some values in the configuration file and train the model with a different config.
|
||||||
- `--output` is the directory to save results, all result are saved in this directory. The structure of the output directory is shown below.
|
|
||||||
|
|
||||||
```text
|
|
||||||
├── checkpoints # checkpoint
|
|
||||||
└── log # tensorboard log
|
|
||||||
```
|
|
||||||
|
|
||||||
- `--device` is the device (gpu id) to use for training. `-1` means CPU.
|
- `--device` is the device (gpu id) to use for training. `-1` means CPU.
|
||||||
|
|
||||||
|
- `--checkpoint` is the path of the checkpoint.
|
||||||
|
- `--iteration` is the iteration of the checkpoint to load from output directory.
|
||||||
|
- `output` is the directory to save results, all result are saved in this directory.
|
||||||
|
|
||||||
|
See [Saving-&-Loading](#Saving-&-Loading) for details of checkpoint loading.
|
||||||
|
|
||||||
|
|
||||||
Example script:
|
Example script:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
python train.py --config=./configs/wavenet_single_gaussian.yaml --data=./LJSpeech-1.1/ --output=experiment --device=0
|
python train.py \
|
||||||
|
--config=./configs/wavenet_single_gaussian.yaml \
|
||||||
|
--data=./LJSpeech-1.1/ \
|
||||||
|
--device=0 \
|
||||||
|
experiment
|
||||||
```
|
```
|
||||||
|
|
||||||
You can monitor training log via TensorBoard, using the script below.
|
You can monitor training log via TensorBoard, using the script below.
|
||||||
|
@ -69,29 +95,50 @@ tensorboard --logdir=.
|
||||||
## Synthesis
|
## Synthesis
|
||||||
```text
|
```text
|
||||||
usage: synthesis.py [-h] [--data DATA] [--config CONFIG] [--device DEVICE]
|
usage: synthesis.py [-h] [--data DATA] [--config CONFIG] [--device DEVICE]
|
||||||
checkpoint output
|
[--checkpoint CHECKPOINT | --iteration ITERATION]
|
||||||
|
output
|
||||||
|
|
||||||
Synthesize valid data from LJspeech with a WaveNet model.
|
Synthesize valid data from LJspeech with a wavenet model.
|
||||||
|
|
||||||
positional arguments:
|
positional arguments:
|
||||||
checkpoint checkpoint to load.
|
output path to save the synthesized audio
|
||||||
output path to save results.
|
|
||||||
|
|
||||||
optional arguments:
|
optional arguments:
|
||||||
-h, --help show this help message and exit
|
-h, --help show this help message and exit
|
||||||
--data DATA path of the LJspeech dataset.
|
--data DATA path of the LJspeech dataset
|
||||||
--config CONFIG path of the config file.
|
--config CONFIG path of the config file
|
||||||
--device DEVICE device to use.
|
--device DEVICE device to use
|
||||||
|
--checkpoint CHECKPOINT checkpoint to resume from
|
||||||
|
--iteration ITERATION the iteration of the checkpoint to load from output directory
|
||||||
```
|
```
|

- `--data` is the path of the LJspeech dataset. In principle, a dataset is not needed for synthesis, but since the input is the mel spectrogram, we need to get mel spectrograms from audio files (see the sketch after this list).
- `--config` is the configuration file to use. You should use the same configuration with which you trained your model.
- `--device` is the device (gpu id) to use. `-1` means CPU.
- `--checkpoint` is the checkpoint to load.
- `--iteration` is the iteration of the checkpoint to load from the output directory.
- `output` is the directory to save synthesized audio. Audio files are saved in `synthesis/` under the `output` directory.

See [Saving & Loading](#saving--loading) for details of checkpoint loading.
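
Since the vocoder is conditioned on mel spectrograms, synthesis still reads audio files to build its inputs. The sketch below shows the general idea using `librosa`, purely for illustration; the example code computes its features with its own `Transform` in `data.py`, and the parameter values here are placeholders rather than the values from the provided configs.

```python
import librosa
import numpy as np

# illustrative parameters only; the real values come from the config file
wav, sr = librosa.load("LJSpeech-1.1/wavs/LJ001-0001.wav", sr=22050)
mel = librosa.feature.melspectrogram(
    y=wav, sr=sr, n_fft=1024, hop_length=256, n_mels=80)
# log compression of the mel energies is typical for vocoder conditioning
log_mel = np.log(np.clip(mel, 1e-5, None))
```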

Example script:

```bash
python synthesis.py \
    --config=./configs/wavenet_single_gaussian.yaml \
    --data=./LJSpeech-1.1/ \
    --device=0 \
    --checkpoint="experiment/checkpoints/step-1000000" \
    experiment
```

or

```bash
python synthesis.py \
    --config=./configs/wavenet_single_gaussian.yaml \
    --data=./LJSpeech-1.1/ \
    --device=0 \
    --iteration=1000000 \
    experiment
```

@@ -25,22 +25,31 @@ from parakeet.modules.weight_norm import WeightNormWrapper
from parakeet.data import SliceDataset, TransformDataset, DataCargo, SequentialSampler, RandomSampler
from parakeet.models.wavenet import UpsampleNet, WaveNet, ConditionalWavenet
from parakeet.utils.layer_tools import summary
from parakeet.utils import io

from data import LJSpeechMetaData, Transform, DataCollector
from utils import make_output_tree, valid_model, eval_model

if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Synthesize valid data from LJspeech with a wavenet model.")
    parser.add_argument(
        "--data", type=str, help="path of the LJspeech dataset")
    parser.add_argument("--config", type=str, help="path of the config file")
    parser.add_argument("--device", type=int, default=-1, help="device to use")
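
    # --checkpoint and --iteration are mutually exclusive: an explicit path
    # takes precedence, --iteration selects checkpoints/step-<iteration> under
    # the output directory, and with neither flag the latest checkpoint is used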
    g = parser.add_mutually_exclusive_group()
    g.add_argument("--checkpoint", type=str, help="checkpoint to resume from")
    g.add_argument(
        "--iteration",
        type=int,
        help="the iteration of the checkpoint to load from output directory")

    parser.add_argument(
        "output",
        type=str,
        default="experiment",
        help="path to save the synthesized audio")

    args = parser.parse_args()
    with open(args.config, 'rt') as f:

@@ -87,7 +96,8 @@ if __name__ == "__main__":
        batch_size=1,
        sampler=SequentialSampler(ljspeech_valid))

    if not os.path.exists(args.output):
        os.makedirs(args.output)

    if args.device == -1:
        place = fluid.CPUPlace()
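    # -1 maps to CPU execution; for any other id, the branch not shown here
    # presumably builds fluid.CUDAPlace(args.device)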

@@ -111,9 +121,15 @@ if __name__ == "__main__":
    model = ConditionalWavenet(encoder, decoder)
    summary(model)

    # load model parameters
    checkpoint_dir = os.path.join(args.output, "checkpoints")
    if args.checkpoint:
        iteration = io.load_parameters(
            model, checkpoint_path=args.checkpoint)
    else:
        iteration = io.load_parameters(
            model, checkpoint_dir=checkpoint_dir, iteration=args.iteration)
    assert iteration > 0, "A trained model is needed."
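    # (io.load_parameters returns the iteration the loaded checkpoint was
    # saved at, and 0 when nothing could be loaded, hence the assert above)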

    # WARNING: don't forget to remove weight norm to re-compute each wrapped layer's weight
    # removing weight norm also speeds up computation

@@ -129,4 +145,8 @@ if __name__ == "__main__":
        capacity=10, return_list=True)
    valid_loader.set_batch_generator(valid_cargo, place)

    synthesis_dir = os.path.join(args.output, "synthesis")
    if not os.path.exists(synthesis_dir):
        os.makedirs(synthesis_dir)

    eval_model(model, valid_loader, synthesis_dir, iteration, sample_rate)

@@ -16,7 +16,7 @@ from __future__ import division
import os
import ruamel.yaml
import argparse
import tqdm
from tensorboardX import SummaryWriter
from paddle import fluid
import paddle.fluid.dygraph as dg

@@ -24,30 +24,37 @@ import paddle.fluid.dygraph as dg
from parakeet.data import SliceDataset, TransformDataset, DataCargo, SequentialSampler, RandomSampler
from parakeet.models.wavenet import UpsampleNet, WaveNet, ConditionalWavenet
from parakeet.utils.layer_tools import summary
from parakeet.utils import io

from data import LJSpeechMetaData, Transform, DataCollector
from utils import make_output_tree, valid_model

if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Train a WaveNet model with LJSpeech.")
    parser.add_argument(
        "--data", type=str, help="path of the LJspeech dataset")
    parser.add_argument("--config", type=str, help="path of the config file")
    parser.add_argument("--device", type=int, default=-1, help="device to use")

    g = parser.add_mutually_exclusive_group()
    g.add_argument("--checkpoint", type=str, help="checkpoint to resume from")
    g.add_argument(
        "--iteration",
        type=int,
        help="the iteration of the checkpoint to load from output directory")

    parser.add_argument(
        "output", type=str, default="experiment", help="path to save results")

    args = parser.parse_args()
    with open(args.config, 'rt') as f:
        config = ruamel.yaml.safe_load(f)

    print("Command Line Args: ")
    for k, v in vars(args).items():
        print("{}: {}".format(k, v))

    ljspeech_meta = LJSpeechMetaData(args.data)

    data_config = config["data"]

@@ -126,14 +133,6 @@ if __name__ == "__main__":
    clipper = fluid.dygraph_grad_clip.GradClipByGlobalNorm(
        gradiant_max_norm)

    train_loader = fluid.io.DataLoader.from_generator(
        capacity=10, return_list=True)
    train_loader.set_batch_generator(train_cargo, place)

@@ -150,10 +149,26 @@ if __name__ == "__main__":
    log_dir = os.path.join(args.output, "log")
    writer = SummaryWriter(log_dir)

    # load parameters and optimizer, and update the number of iterations done so far
    if args.checkpoint is not None:
        iteration = io.load_parameters(
            model, optim, checkpoint_path=args.checkpoint)
    else:
        iteration = io.load_parameters(
            model,
            optim,
            checkpoint_dir=checkpoint_dir,
            iteration=args.iteration)

    global_step = iteration + 1
    iterator = iter(tqdm.tqdm(train_loader))
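    # resume semantics: training continues at the step after the loaded
    # checkpoint; if nothing was loaded, iteration is 0 and training starts
    # from step 1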
    while global_step <= max_iterations:
        try:
            batch = next(iterator)
        except StopIteration as e:
            iterator = iter(tqdm.tqdm(train_loader))
            batch = next(iterator)
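        # there is no epoch counter: when the loader is exhausted, it is
        # simply re-created, and progress is tracked purely by global_step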

        audio_clips, mel_specs, audio_starts = batch

        model.train()

@@ -162,21 +177,20 @@ if __name__ == "__main__":
        loss_var.backward()
        loss_np = loss_var.numpy()

        writer.add_scalar("loss", loss_np[0], global_step)
        writer.add_scalar("learning_rate",
                          optim._learning_rate.step().numpy()[0],
                          global_step)
        optim.minimize(loss_var, grad_clip=clipper)
        optim.clear_gradients()
        print("global_step: {}\tloss: {:<8.6f}".format(global_step,
                                                       loss_np[0]))

        if global_step % snap_interval == 0:
            valid_model(model, valid_loader, writer, global_step,
                        sample_rate)

        if global_step % checkpoint_interval == 0:
            io.save_parameters(checkpoint_dir, global_step, model, optim)
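            # each snapshot writes step-<global_step>.pdparams and .pdopt into
            # checkpoint_dir; the `checkpoint` record file is presumably
            # refreshed too, so the latest run can be resumed automatically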

        global_step += 1

@@ -49,20 +49,14 @@ def valid_model(model, valid_loader, writer, global_step, sample_rate):
                    sample_rate)


def eval_model(model, valid_loader, output_dir, global_step, sample_rate):
    model.eval()
    for i, batch in enumerate(valid_loader):
        # print("sentence {}".format(i))
        path = os.path.join(output_dir,
                            "sentence_{}_step_{}.wav".format(i, global_step))
        audio_clips, mel_specs, audio_starts = batch
        wav_var = model.synthesis(mel_specs)
        wav_np = wav_var.numpy()[0]
        sf.write(path, wav_np, samplerate=sample_rate)
        print("generated {}".format(path))

@@ -52,8 +52,6 @@ def _load_latest_checkpoint(checkpoint_dir):

    Args:
        checkpoint_dir (str): the directory where checkpoint is saved.

    Returns:
        int: the latest iteration number.

@@ -115,7 +113,6 @@ def load_parameters(model,
    if iteration is None:
        iteration = _load_latest_checkpoint(checkpoint_dir)
    if iteration == 0:
        return iteration
    checkpoint_path = os.path.join(checkpoint_dir,
                                   "step-{}".format(iteration))