refactor for deep voice 3, update wavenet and clarinet to use enable_dygraph
This commit is contained in:
parent
a4dd5acc2f
commit
6aac18278e
|
@ -25,6 +25,7 @@ from tensorboardX import SummaryWriter
|
|||
|
||||
import paddle.fluid.dygraph as dg
|
||||
from paddle import fluid
|
||||
fluid.require_version('1.8.0')
|
||||
|
||||
from parakeet.modules.weight_norm import WeightNormWrapper
|
||||
from parakeet.models.wavenet import WaveNet, UpsampleNet
|
||||
|
@ -64,6 +65,13 @@ if __name__ == "__main__":
|
|||
with open(args.config, 'rt') as f:
|
||||
config = ruamel.yaml.safe_load(f)
|
||||
|
||||
if args.device == -1:
|
||||
place = fluid.CPUPlace()
|
||||
else:
|
||||
place = fluid.CUDAPlace(args.device)
|
||||
|
||||
dg.enable_dygraph(place)
|
||||
|
||||
ljspeech_meta = LJSpeechMetaData(args.data)
|
||||
|
||||
data_config = config["data"]
|
||||
|
@ -105,75 +113,68 @@ if __name__ == "__main__":
|
|||
batch_size=1,
|
||||
sampler=SequentialSampler(ljspeech_valid))
|
||||
|
||||
if args.device == -1:
|
||||
place = fluid.CPUPlace()
|
||||
# conditioner(upsampling net)
|
||||
conditioner_config = config["conditioner"]
|
||||
upsampling_factors = conditioner_config["upsampling_factors"]
|
||||
upsample_net = UpsampleNet(upscale_factors=upsampling_factors)
|
||||
freeze(upsample_net)
|
||||
|
||||
residual_channels = teacher_config["residual_channels"]
|
||||
loss_type = teacher_config["loss_type"]
|
||||
output_dim = teacher_config["output_dim"]
|
||||
log_scale_min = teacher_config["log_scale_min"]
|
||||
assert loss_type == "mog" and output_dim == 3, \
|
||||
"the teacher wavenet should be a wavenet with single gaussian output"
|
||||
|
||||
teacher = WaveNet(n_loop, n_layer, residual_channels, output_dim, n_mels,
|
||||
filter_size, loss_type, log_scale_min)
|
||||
# load & freeze upsample_net & teacher
|
||||
freeze(teacher)
|
||||
|
||||
student_config = config["student"]
|
||||
n_loops = student_config["n_loops"]
|
||||
n_layers = student_config["n_layers"]
|
||||
student_residual_channels = student_config["residual_channels"]
|
||||
student_filter_size = student_config["filter_size"]
|
||||
student_log_scale_min = student_config["log_scale_min"]
|
||||
student = ParallelWaveNet(n_loops, n_layers, student_residual_channels,
|
||||
n_mels, student_filter_size)
|
||||
|
||||
stft_config = config["stft"]
|
||||
stft = STFT(
|
||||
n_fft=stft_config["n_fft"],
|
||||
hop_length=stft_config["hop_length"],
|
||||
win_length=stft_config["win_length"])
|
||||
|
||||
lmd = config["loss"]["lmd"]
|
||||
model = Clarinet(upsample_net, teacher, student, stft,
|
||||
student_log_scale_min, lmd)
|
||||
summary(model)
|
||||
|
||||
# load parameters
|
||||
if args.checkpoint is not None:
|
||||
# load from args.checkpoint
|
||||
iteration = io.load_parameters(model, checkpoint_path=args.checkpoint)
|
||||
else:
|
||||
place = fluid.CUDAPlace(args.device)
|
||||
# load from "args.output/checkpoints"
|
||||
checkpoint_dir = os.path.join(args.output, "checkpoints")
|
||||
iteration = io.load_parameters(
|
||||
model, checkpoint_dir=checkpoint_dir, iteration=args.iteration)
|
||||
assert iteration > 0, "A trained checkpoint is needed."
|
||||
|
||||
with dg.guard(place):
|
||||
# conditioner(upsampling net)
|
||||
conditioner_config = config["conditioner"]
|
||||
upsampling_factors = conditioner_config["upsampling_factors"]
|
||||
upsample_net = UpsampleNet(upscale_factors=upsampling_factors)
|
||||
freeze(upsample_net)
|
||||
# make generation fast
|
||||
for sublayer in model.sublayers():
|
||||
if isinstance(sublayer, WeightNormWrapper):
|
||||
sublayer.remove_weight_norm()
|
||||
|
||||
residual_channels = teacher_config["residual_channels"]
|
||||
loss_type = teacher_config["loss_type"]
|
||||
output_dim = teacher_config["output_dim"]
|
||||
log_scale_min = teacher_config["log_scale_min"]
|
||||
assert loss_type == "mog" and output_dim == 3, \
|
||||
"the teacher wavenet should be a wavenet with single gaussian output"
|
||||
# data loader
|
||||
valid_loader = fluid.io.DataLoader.from_generator(
|
||||
capacity=10, return_list=True)
|
||||
valid_loader.set_batch_generator(valid_cargo, place)
|
||||
|
||||
teacher = WaveNet(n_loop, n_layer, residual_channels, output_dim,
|
||||
n_mels, filter_size, loss_type, log_scale_min)
|
||||
# load & freeze upsample_net & teacher
|
||||
freeze(teacher)
|
||||
# the directory to save audio files
|
||||
synthesis_dir = os.path.join(args.output, "synthesis")
|
||||
if not os.path.exists(synthesis_dir):
|
||||
os.makedirs(synthesis_dir)
|
||||
|
||||
student_config = config["student"]
|
||||
n_loops = student_config["n_loops"]
|
||||
n_layers = student_config["n_layers"]
|
||||
student_residual_channels = student_config["residual_channels"]
|
||||
student_filter_size = student_config["filter_size"]
|
||||
student_log_scale_min = student_config["log_scale_min"]
|
||||
student = ParallelWaveNet(n_loops, n_layers, student_residual_channels,
|
||||
n_mels, student_filter_size)
|
||||
|
||||
stft_config = config["stft"]
|
||||
stft = STFT(
|
||||
n_fft=stft_config["n_fft"],
|
||||
hop_length=stft_config["hop_length"],
|
||||
win_length=stft_config["win_length"])
|
||||
|
||||
lmd = config["loss"]["lmd"]
|
||||
model = Clarinet(upsample_net, teacher, student, stft,
|
||||
student_log_scale_min, lmd)
|
||||
summary(model)
|
||||
|
||||
# load parameters
|
||||
if args.checkpoint is not None:
|
||||
# load from args.checkpoint
|
||||
iteration = io.load_parameters(
|
||||
model, checkpoint_path=args.checkpoint)
|
||||
else:
|
||||
# load from "args.output/checkpoints"
|
||||
checkpoint_dir = os.path.join(args.output, "checkpoints")
|
||||
iteration = io.load_parameters(
|
||||
model, checkpoint_dir=checkpoint_dir, iteration=args.iteration)
|
||||
assert iteration > 0, "A trained checkpoint is needed."
|
||||
|
||||
# make generation fast
|
||||
for sublayer in model.sublayers():
|
||||
if isinstance(sublayer, WeightNormWrapper):
|
||||
sublayer.remove_weight_norm()
|
||||
|
||||
# data loader
|
||||
valid_loader = fluid.io.DataLoader.from_generator(
|
||||
capacity=10, return_list=True)
|
||||
valid_loader.set_batch_generator(valid_cargo, place)
|
||||
|
||||
# the directory to save audio files
|
||||
synthesis_dir = os.path.join(args.output, "synthesis")
|
||||
if not os.path.exists(synthesis_dir):
|
||||
os.makedirs(synthesis_dir)
|
||||
|
||||
eval_model(model, valid_loader, synthesis_dir, iteration, sample_rate)
|
||||
eval_model(model, valid_loader, synthesis_dir, iteration, sample_rate)
|
||||
|
|
|
@ -25,10 +25,11 @@ from tensorboardX import SummaryWriter
|
|||
|
||||
import paddle.fluid.dygraph as dg
|
||||
from paddle import fluid
|
||||
fluid.require_version('1.8.0')
|
||||
|
||||
from parakeet.models.wavenet import WaveNet, UpsampleNet
|
||||
from parakeet.models.clarinet import STFT, Clarinet, ParallelWaveNet
|
||||
from parakeet.data import TransformDataset, SliceDataset, RandomSampler, SequentialSampler, DataCargo
|
||||
from parakeet.data import TransformDataset, SliceDataset, CacheDataset, RandomSampler, SequentialSampler, DataCargo
|
||||
from parakeet.utils.layer_tools import summary, freeze
|
||||
from parakeet.utils import io
|
||||
|
||||
|
@ -66,6 +67,13 @@ if __name__ == "__main__":
|
|||
with open(args.config, 'rt') as f:
|
||||
config = ruamel.yaml.safe_load(f)
|
||||
|
||||
if args.device == -1:
|
||||
place = fluid.CPUPlace()
|
||||
else:
|
||||
place = fluid.CUDAPlace(args.device)
|
||||
|
||||
dg.enable_dygraph(place)
|
||||
|
||||
print("Command Line args: ")
|
||||
for k, v in vars(args).items():
|
||||
print("{}: {}".format(k, v))
|
||||
|
@ -83,8 +91,9 @@ if __name__ == "__main__":
|
|||
ljspeech = TransformDataset(ljspeech_meta, transform)
|
||||
|
||||
valid_size = data_config["valid_size"]
|
||||
ljspeech_valid = SliceDataset(ljspeech, 0, valid_size)
|
||||
ljspeech_train = SliceDataset(ljspeech, valid_size, len(ljspeech))
|
||||
ljspeech_valid = CacheDataset(SliceDataset(ljspeech, 0, valid_size))
|
||||
ljspeech_train = CacheDataset(
|
||||
SliceDataset(ljspeech, valid_size, len(ljspeech)))
|
||||
|
||||
teacher_config = config["teacher"]
|
||||
n_loop = teacher_config["n_loop"]
|
||||
|
@ -113,130 +122,122 @@ if __name__ == "__main__":
|
|||
|
||||
make_output_tree(args.output)
|
||||
|
||||
if args.device == -1:
|
||||
place = fluid.CPUPlace()
|
||||
# conditioner(upsampling net)
|
||||
conditioner_config = config["conditioner"]
|
||||
upsampling_factors = conditioner_config["upsampling_factors"]
|
||||
upsample_net = UpsampleNet(upscale_factors=upsampling_factors)
|
||||
freeze(upsample_net)
|
||||
|
||||
residual_channels = teacher_config["residual_channels"]
|
||||
loss_type = teacher_config["loss_type"]
|
||||
output_dim = teacher_config["output_dim"]
|
||||
log_scale_min = teacher_config["log_scale_min"]
|
||||
assert loss_type == "mog" and output_dim == 3, \
|
||||
"the teacher wavenet should be a wavenet with single gaussian output"
|
||||
|
||||
teacher = WaveNet(n_loop, n_layer, residual_channels, output_dim, n_mels,
|
||||
filter_size, loss_type, log_scale_min)
|
||||
freeze(teacher)
|
||||
|
||||
student_config = config["student"]
|
||||
n_loops = student_config["n_loops"]
|
||||
n_layers = student_config["n_layers"]
|
||||
student_residual_channels = student_config["residual_channels"]
|
||||
student_filter_size = student_config["filter_size"]
|
||||
student_log_scale_min = student_config["log_scale_min"]
|
||||
student = ParallelWaveNet(n_loops, n_layers, student_residual_channels,
|
||||
n_mels, student_filter_size)
|
||||
|
||||
stft_config = config["stft"]
|
||||
stft = STFT(
|
||||
n_fft=stft_config["n_fft"],
|
||||
hop_length=stft_config["hop_length"],
|
||||
win_length=stft_config["win_length"])
|
||||
|
||||
lmd = config["loss"]["lmd"]
|
||||
model = Clarinet(upsample_net, teacher, student, stft,
|
||||
student_log_scale_min, lmd)
|
||||
summary(model)
|
||||
|
||||
# optim
|
||||
train_config = config["train"]
|
||||
learning_rate = train_config["learning_rate"]
|
||||
anneal_rate = train_config["anneal_rate"]
|
||||
anneal_interval = train_config["anneal_interval"]
|
||||
lr_scheduler = dg.ExponentialDecay(
|
||||
learning_rate, anneal_interval, anneal_rate, staircase=True)
|
||||
gradiant_max_norm = train_config["gradient_max_norm"]
|
||||
optim = fluid.optimizer.Adam(
|
||||
lr_scheduler,
|
||||
parameter_list=model.parameters(),
|
||||
grad_clip=fluid.clip.ClipByGlobalNorm(gradiant_max_norm))
|
||||
|
||||
# train
|
||||
max_iterations = train_config["max_iterations"]
|
||||
checkpoint_interval = train_config["checkpoint_interval"]
|
||||
eval_interval = train_config["eval_interval"]
|
||||
checkpoint_dir = os.path.join(args.output, "checkpoints")
|
||||
state_dir = os.path.join(args.output, "states")
|
||||
log_dir = os.path.join(args.output, "log")
|
||||
writer = SummaryWriter(log_dir)
|
||||
|
||||
if args.checkpoint is not None:
|
||||
iteration = io.load_parameters(
|
||||
model, optim, checkpoint_path=args.checkpoint)
|
||||
else:
|
||||
place = fluid.CUDAPlace(args.device)
|
||||
iteration = io.load_parameters(
|
||||
model,
|
||||
optim,
|
||||
checkpoint_dir=checkpoint_dir,
|
||||
iteration=args.iteration)
|
||||
|
||||
with dg.guard(place):
|
||||
# conditioner(upsampling net)
|
||||
conditioner_config = config["conditioner"]
|
||||
upsampling_factors = conditioner_config["upsampling_factors"]
|
||||
upsample_net = UpsampleNet(upscale_factors=upsampling_factors)
|
||||
freeze(upsample_net)
|
||||
if iteration == 0:
|
||||
assert args.wavenet is not None, "When training afresh, a trained wavenet model should be provided."
|
||||
load_wavenet(model, args.wavenet)
|
||||
|
||||
residual_channels = teacher_config["residual_channels"]
|
||||
loss_type = teacher_config["loss_type"]
|
||||
output_dim = teacher_config["output_dim"]
|
||||
log_scale_min = teacher_config["log_scale_min"]
|
||||
assert loss_type == "mog" and output_dim == 3, \
|
||||
"the teacher wavenet should be a wavenet with single gaussian output"
|
||||
# loader
|
||||
train_loader = fluid.io.DataLoader.from_generator(
|
||||
capacity=10, return_list=True)
|
||||
train_loader.set_batch_generator(train_cargo, place)
|
||||
|
||||
teacher = WaveNet(n_loop, n_layer, residual_channels, output_dim,
|
||||
n_mels, filter_size, loss_type, log_scale_min)
|
||||
freeze(teacher)
|
||||
valid_loader = fluid.io.DataLoader.from_generator(
|
||||
capacity=10, return_list=True)
|
||||
valid_loader.set_batch_generator(valid_cargo, place)
|
||||
|
||||
student_config = config["student"]
|
||||
n_loops = student_config["n_loops"]
|
||||
n_layers = student_config["n_layers"]
|
||||
student_residual_channels = student_config["residual_channels"]
|
||||
student_filter_size = student_config["filter_size"]
|
||||
student_log_scale_min = student_config["log_scale_min"]
|
||||
student = ParallelWaveNet(n_loops, n_layers, student_residual_channels,
|
||||
n_mels, student_filter_size)
|
||||
# training loop
|
||||
global_step = iteration + 1
|
||||
iterator = iter(tqdm(train_loader))
|
||||
while global_step <= max_iterations:
|
||||
try:
|
||||
batch = next(iterator)
|
||||
except StopIteration as e:
|
||||
iterator = iter(tqdm(train_loader))
|
||||
batch = next(iterator)
|
||||
|
||||
stft_config = config["stft"]
|
||||
stft = STFT(
|
||||
n_fft=stft_config["n_fft"],
|
||||
hop_length=stft_config["hop_length"],
|
||||
win_length=stft_config["win_length"])
|
||||
audios, mels, audio_starts = batch
|
||||
model.train()
|
||||
loss_dict = model(
|
||||
audios, mels, audio_starts, clip_kl=global_step > 500)
|
||||
|
||||
lmd = config["loss"]["lmd"]
|
||||
model = Clarinet(upsample_net, teacher, student, stft,
|
||||
student_log_scale_min, lmd)
|
||||
summary(model)
|
||||
writer.add_scalar("learning_rate",
|
||||
optim._learning_rate.step().numpy()[0], global_step)
|
||||
for k, v in loss_dict.items():
|
||||
writer.add_scalar("loss/{}".format(k), v.numpy()[0], global_step)
|
||||
|
||||
# optim
|
||||
train_config = config["train"]
|
||||
learning_rate = train_config["learning_rate"]
|
||||
anneal_rate = train_config["anneal_rate"]
|
||||
anneal_interval = train_config["anneal_interval"]
|
||||
lr_scheduler = dg.ExponentialDecay(
|
||||
learning_rate, anneal_interval, anneal_rate, staircase=True)
|
||||
gradiant_max_norm = train_config["gradient_max_norm"]
|
||||
optim = fluid.optimizer.Adam(
|
||||
lr_scheduler,
|
||||
parameter_list=model.parameters(),
|
||||
grad_clip=fluid.clip.ClipByGlobalNorm(gradiant_max_norm))
|
||||
l = loss_dict["loss"]
|
||||
step_loss = l.numpy()[0]
|
||||
print("[train] global_step: {} loss: {:<8.6f}".format(global_step,
|
||||
step_loss))
|
||||
|
||||
# train
|
||||
max_iterations = train_config["max_iterations"]
|
||||
checkpoint_interval = train_config["checkpoint_interval"]
|
||||
eval_interval = train_config["eval_interval"]
|
||||
checkpoint_dir = os.path.join(args.output, "checkpoints")
|
||||
state_dir = os.path.join(args.output, "states")
|
||||
log_dir = os.path.join(args.output, "log")
|
||||
writer = SummaryWriter(log_dir)
|
||||
l.backward()
|
||||
optim.minimize(l)
|
||||
optim.clear_gradients()
|
||||
|
||||
if args.checkpoint is not None:
|
||||
iteration = io.load_parameters(
|
||||
model, optim, checkpoint_path=args.checkpoint)
|
||||
else:
|
||||
iteration = io.load_parameters(
|
||||
model,
|
||||
optim,
|
||||
checkpoint_dir=checkpoint_dir,
|
||||
iteration=args.iteration)
|
||||
if global_step % eval_interval == 0:
|
||||
# evaluate on valid dataset
|
||||
eval_model(model, valid_loader, state_dir, global_step,
|
||||
sample_rate)
|
||||
if global_step % checkpoint_interval == 0:
|
||||
io.save_parameters(checkpoint_dir, global_step, model, optim)
|
||||
|
||||
if iteration == 0:
|
||||
assert args.wavenet is not None, "When training afresh, a trained wavenet model should be provided."
|
||||
load_wavenet(model, args.wavenet)
|
||||
|
||||
# loader
|
||||
train_loader = fluid.io.DataLoader.from_generator(
|
||||
capacity=10, return_list=True)
|
||||
train_loader.set_batch_generator(train_cargo, place)
|
||||
|
||||
valid_loader = fluid.io.DataLoader.from_generator(
|
||||
capacity=10, return_list=True)
|
||||
valid_loader.set_batch_generator(valid_cargo, place)
|
||||
|
||||
# training loop
|
||||
global_step = iteration + 1
|
||||
iterator = iter(tqdm(train_loader))
|
||||
while global_step <= max_iterations:
|
||||
try:
|
||||
batch = next(iterator)
|
||||
except StopIteration as e:
|
||||
iterator = iter(tqdm(train_loader))
|
||||
batch = next(iterator)
|
||||
|
||||
audios, mels, audio_starts = batch
|
||||
model.train()
|
||||
loss_dict = model(
|
||||
audios, mels, audio_starts, clip_kl=global_step > 500)
|
||||
|
||||
writer.add_scalar("learning_rate",
|
||||
optim._learning_rate.step().numpy()[0],
|
||||
global_step)
|
||||
for k, v in loss_dict.items():
|
||||
writer.add_scalar("loss/{}".format(k),
|
||||
v.numpy()[0], global_step)
|
||||
|
||||
l = loss_dict["loss"]
|
||||
step_loss = l.numpy()[0]
|
||||
print("[train] global_step: {} loss: {:<8.6f}".format(global_step,
|
||||
step_loss))
|
||||
|
||||
l.backward()
|
||||
optim.minimize(l)
|
||||
optim.clear_gradients()
|
||||
|
||||
if global_step % eval_interval == 0:
|
||||
# evaluate on valid dataset
|
||||
eval_model(model, valid_loader, state_dir, global_step,
|
||||
sample_rate)
|
||||
if global_step % checkpoint_interval == 0:
|
||||
io.save_parameters(checkpoint_dir, global_step, model, optim)
|
||||
|
||||
global_step += 1
|
||||
global_step += 1
|
||||
|
|
|
@ -23,6 +23,7 @@ The model consists of an encoder, a decoder and a converter (and a speaker embed
|
|||
|
||||
```text
|
||||
├── data.py data_processing
|
||||
├── model.py function to create model, criterion and optimizer
|
||||
├── configs/ (example) configuration files
|
||||
├── sentences.txt sample sentences
|
||||
├── synthesis.py script to synthesize waveform from text
|
||||
|
@ -34,19 +35,20 @@ The model consists of an encoder, a decoder and a converter (and a speaker embed
|
|||
`train.py` and `synthesis.py` have 3 arguments in common, `--checkpooint`, `iteration` and `output`.
|
||||
|
||||
1. `output` is the directory for saving results.
|
||||
During training, checkpoints are saved in `checkpoints/` in `output` and tensorboard log is save in `log/` in `output`. Other possible outputs are saved in `states/` in `outuput`.
|
||||
During synthesizing, audio files and other possible outputs are save in `synthesis/` in `output`.
|
||||
During training, checkpoints are saved in `checkpoints/` in `output` and tensorboard log is save in `log/` in `output`. States for training including alignment plots, spectrogram plots and generated audio files are saved in `states/` in `outuput`. In addition, we periodically evaluate the model with several given sentences, the alignment plots and generated audio files are save in `eval/` in `output`.
|
||||
During synthesizing, audio files and the alignment plots are save in `synthesis/` in `output`.
|
||||
So after training and synthesizing with the same output directory, the file structure of the output directory looks like this.
|
||||
|
||||
```text
|
||||
├── checkpoints/ # checkpoint directory (including *.pdparams, *.pdopt and a text file `checkpoint` that records the latest checkpoint)
|
||||
├── states/ # audio files generated at validation and other possible outputs
|
||||
├── states/ # alignment plots, spectrogram plots and generated wavs at training
|
||||
├── log/ # tensorboard log
|
||||
└── synthesis/ # synthesized audio files and other possible outputs
|
||||
├── eval/ # audio files an alignment plots generated at evaluation during training
|
||||
└── synthesis/ # synthesized audio files and alignment plots
|
||||
```
|
||||
|
||||
2. `--checkpoint` and `--iteration` for loading from existing checkpoint. Loading existing checkpoiont follows the following rule:
|
||||
If `--checkpoint` is provided, the checkpoint specified by `--checkpoint` is loaded.
|
||||
If `--checkpoint` is provided, the path of the checkpoint specified by `--checkpoint` is loaded.
|
||||
If `--checkpoint` is not provided, we try to load the model specified by `--iteration` from the checkpoint directory. If `--iteration` is not provided, we try to load the latested checkpoint from checkpoint directory.
|
||||
|
||||
## Train
|
||||
|
@ -100,6 +102,18 @@ python train.py \
|
|||
experiment
|
||||
```
|
||||
|
||||
To train the model in a paralle in multiple gpus, you can launch the training script with `paddle.distributed.launch`. For example, to train with gpu `0,1,2,3`, you can use the example script below. Note that for parallel training, devices are specified with `--selected_gpus` passed to `paddle.distributed.launch`. In this case, `--device` passed to `train.py`, if specified, is ignored.
|
||||
|
||||
Example script:
|
||||
|
||||
```bash
|
||||
python -m paddle.distributed.launch --selected_gpus=0,1,2,3 \
|
||||
train.py \
|
||||
--config=configs/ljspeech.yaml \
|
||||
--data=./LJSpeech-1.1/ \
|
||||
experiment
|
||||
```
|
||||
|
||||
You can monitor training log via tensorboard, using the script below.
|
||||
|
||||
```bash
|
||||
|
|
|
@ -17,13 +17,16 @@ import os
|
|||
import csv
|
||||
from pathlib import Path
|
||||
import numpy as np
|
||||
from paddle import fluid
|
||||
import pandas as pd
|
||||
import librosa
|
||||
from scipy import signal, io
|
||||
import six
|
||||
from scipy import signal
|
||||
|
||||
import paddle.fluid.dygraph as dg
|
||||
|
||||
from parakeet.data import DatasetMixin, TransformDataset, FilterDataset
|
||||
from parakeet.g2p.en import text_to_sequence, sequence_to_text
|
||||
from parakeet.data import DatasetMixin, TransformDataset, FilterDataset, CacheDataset
|
||||
from parakeet.data import DataCargo, PartialyRandomizedSimilarTimeLengthSampler, SequentialSampler, BucketSampler
|
||||
|
||||
|
||||
class LJSpeechMetaData(DatasetMixin):
|
||||
|
@ -50,7 +53,7 @@ class LJSpeechMetaData(DatasetMixin):
|
|||
|
||||
class Transform(object):
|
||||
def __init__(self,
|
||||
replace_pronounciation_prob=0.,
|
||||
replace_pronunciation_prob=0.,
|
||||
sample_rate=22050,
|
||||
preemphasis=.97,
|
||||
n_fft=1024,
|
||||
|
@ -63,7 +66,7 @@ class Transform(object):
|
|||
ref_level_db=20,
|
||||
max_norm=0.999,
|
||||
clip_norm=True):
|
||||
self.replace_pronounciation_prob = replace_pronounciation_prob
|
||||
self.replace_pronunciation_prob = replace_pronunciation_prob
|
||||
|
||||
self.sample_rate = sample_rate
|
||||
self.preemphasis = preemphasis
|
||||
|
@ -85,7 +88,7 @@ class Transform(object):
|
|||
|
||||
# text processing
|
||||
mix_grapheme_phonemes = text_to_sequence(
|
||||
normalized_text, self.replace_pronounciation_prob)
|
||||
normalized_text, self.replace_pronunciation_prob)
|
||||
text_length = len(mix_grapheme_phonemes)
|
||||
# CAUTION: positions start from 1
|
||||
speaker_id = None
|
||||
|
@ -125,8 +128,8 @@ class Transform(object):
|
|||
|
||||
# num_frames
|
||||
n_frames = S_mel_norm.shape[-1] # CAUTION: original number of frames
|
||||
return (mix_grapheme_phonemes, text_length, speaker_id, S_norm,
|
||||
S_mel_norm, n_frames)
|
||||
return (mix_grapheme_phonemes, text_length, speaker_id, S_norm.T,
|
||||
S_mel_norm.T, n_frames)
|
||||
|
||||
|
||||
class DataCollector(object):
|
||||
|
@ -166,12 +169,12 @@ class DataCollector(object):
|
|||
),
|
||||
mode="constant"))
|
||||
lin_specs.append(
|
||||
np.pad(S_norm, ((0, 0), (self._pad_begin, max_frames -
|
||||
self._pad_begin - num_frames)),
|
||||
np.pad(S_norm, ((self._pad_begin, max_frames - self._pad_begin
|
||||
- num_frames), (0, 0)),
|
||||
mode="constant"))
|
||||
mel_specs.append(
|
||||
np.pad(S_mel_norm, ((0, 0), (self._pad_begin, max_frames -
|
||||
self._pad_begin - num_frames)),
|
||||
np.pad(S_mel_norm, ((self._pad_begin, max_frames -
|
||||
self._pad_begin - num_frames), (0, 0)),
|
||||
mode="constant"))
|
||||
done_flags.append(
|
||||
np.pad(np.zeros((int(np.ceil(num_frames // self._factor)), )),
|
||||
|
@ -180,10 +183,10 @@ class DataCollector(object):
|
|||
mode="constant",
|
||||
constant_values=1))
|
||||
text_sequences = np.array(text_sequences).astype(np.int64)
|
||||
lin_specs = np.transpose(np.array(lin_specs),
|
||||
(0, 2, 1)).astype(np.float32)
|
||||
mel_specs = np.transpose(np.array(mel_specs),
|
||||
(0, 2, 1)).astype(np.float32)
|
||||
lin_specs = np.array(lin_specs).astype(np.float32)
|
||||
mel_specs = np.array(mel_specs).astype(np.float32)
|
||||
|
||||
# downsample here
|
||||
done_flags = np.array(done_flags).astype(np.float32)
|
||||
|
||||
# text positions
|
||||
|
@ -201,3 +204,54 @@ class DataCollector(object):
|
|||
|
||||
return (text_sequences, text_lengths, text_positions, mel_specs,
|
||||
lin_specs, frames, decoder_positions, done_flags)
|
||||
|
||||
|
||||
def make_data_loader(data_root, config):
|
||||
# construct meta data
|
||||
meta = LJSpeechMetaData(data_root)
|
||||
|
||||
# filter it!
|
||||
min_text_length = config["meta_data"]["min_text_length"]
|
||||
meta = FilterDataset(meta, lambda x: len(x[2]) >= min_text_length)
|
||||
|
||||
# transform meta data into meta data
|
||||
c = config["transform"]
|
||||
transform = Transform(
|
||||
replace_pronunciation_prob=c["replace_pronunciation_prob"],
|
||||
sample_rate=c["sample_rate"],
|
||||
preemphasis=c["preemphasis"],
|
||||
n_fft=c["n_fft"],
|
||||
win_length=c["win_length"],
|
||||
hop_length=c["hop_length"],
|
||||
fmin=c["fmin"],
|
||||
fmax=c["fmax"],
|
||||
n_mels=c["n_mels"],
|
||||
min_level_db=c["min_level_db"],
|
||||
ref_level_db=c["ref_level_db"],
|
||||
max_norm=c["max_norm"],
|
||||
clip_norm=c["clip_norm"])
|
||||
ljspeech = CacheDataset(TransformDataset(meta, transform))
|
||||
|
||||
# use meta data's text length as a sort key for the sampler
|
||||
batch_size = config["train"]["batch_size"]
|
||||
text_lengths = [len(example[2]) for example in meta]
|
||||
sampler = PartialyRandomizedSimilarTimeLengthSampler(text_lengths,
|
||||
batch_size)
|
||||
|
||||
env = dg.parallel.ParallelEnv()
|
||||
num_trainers = env.nranks
|
||||
local_rank = env.local_rank
|
||||
sampler = BucketSampler(
|
||||
text_lengths, batch_size, num_trainers=num_trainers, rank=local_rank)
|
||||
|
||||
# some model hyperparameters affect how we process data
|
||||
model_config = config["model"]
|
||||
collector = DataCollector(
|
||||
downsample_factor=model_config["downsample_factor"],
|
||||
r=model_config["outputs_per_step"])
|
||||
ljspeech_loader = DataCargo(
|
||||
ljspeech, batch_fn=collector, batch_size=batch_size, sampler=sampler)
|
||||
loader = fluid.io.DataLoader.from_generator(capacity=10, return_list=True)
|
||||
loader.set_batch_generator(
|
||||
ljspeech_loader, places=fluid.framework._current_expected_place())
|
||||
return loader
|
||||
|
|
|
@ -0,0 +1,164 @@
|
|||
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from paddle import fluid
|
||||
import paddle.fluid.initializer as I
|
||||
import paddle.fluid.dygraph as dg
|
||||
|
||||
from parakeet.g2p import en
|
||||
from parakeet.models.deepvoice3 import Encoder, Decoder, Converter, DeepVoice3, TTSLoss, ConvSpec, WindowRange
|
||||
from parakeet.utils.layer_tools import summary, freeze
|
||||
|
||||
|
||||
def make_model(config):
|
||||
c = config["model"]
|
||||
# speaker embedding
|
||||
n_speakers = c["n_speakers"]
|
||||
speaker_dim = c["speaker_embed_dim"]
|
||||
if n_speakers > 1:
|
||||
speaker_embed = dg.Embedding(
|
||||
(n_speakers, speaker_dim),
|
||||
param_attr=I.Normal(scale=c["speaker_embedding_weight_std"]))
|
||||
else:
|
||||
speaker_embed = None
|
||||
|
||||
# encoder
|
||||
h = c["encoder_channels"]
|
||||
k = c["kernel_size"]
|
||||
encoder_convolutions = (
|
||||
ConvSpec(h, k, 1),
|
||||
ConvSpec(h, k, 3),
|
||||
ConvSpec(h, k, 9),
|
||||
ConvSpec(h, k, 27),
|
||||
ConvSpec(h, k, 1),
|
||||
ConvSpec(h, k, 3),
|
||||
ConvSpec(h, k, 9),
|
||||
ConvSpec(h, k, 27),
|
||||
ConvSpec(h, k, 1),
|
||||
ConvSpec(h, k, 3), )
|
||||
encoder = Encoder(
|
||||
n_vocab=en.n_vocab,
|
||||
embed_dim=c["text_embed_dim"],
|
||||
n_speakers=n_speakers,
|
||||
speaker_dim=speaker_dim,
|
||||
embedding_weight_std=c["embedding_weight_std"],
|
||||
convolutions=encoder_convolutions,
|
||||
dropout=c["dropout"])
|
||||
if c["freeze_embedding"]:
|
||||
freeze(encoder.embed)
|
||||
|
||||
# decoder
|
||||
h = c["decoder_channels"]
|
||||
k = c["kernel_size"]
|
||||
prenet_convolutions = (ConvSpec(h, k, 1), ConvSpec(h, k, 3))
|
||||
attentive_convolutions = (
|
||||
ConvSpec(h, k, 1),
|
||||
ConvSpec(h, k, 3),
|
||||
ConvSpec(h, k, 9),
|
||||
ConvSpec(h, k, 27),
|
||||
ConvSpec(h, k, 1), )
|
||||
attention = [True, False, False, False, True]
|
||||
force_monotonic_attention = [True, False, False, False, True]
|
||||
window = WindowRange(c["window_backward"], c["window_ahead"])
|
||||
decoder = Decoder(
|
||||
n_speakers,
|
||||
speaker_dim,
|
||||
embed_dim=c["text_embed_dim"],
|
||||
mel_dim=config["transform"]["n_mels"],
|
||||
r=c["outputs_per_step"],
|
||||
max_positions=c["max_positions"],
|
||||
preattention=prenet_convolutions,
|
||||
convolutions=attentive_convolutions,
|
||||
attention=attention,
|
||||
dropout=c["dropout"],
|
||||
use_memory_mask=c["use_memory_mask"],
|
||||
force_monotonic_attention=force_monotonic_attention,
|
||||
query_position_rate=c["query_position_rate"],
|
||||
key_position_rate=c["key_position_rate"],
|
||||
window_range=window,
|
||||
key_projection=c["key_projection"],
|
||||
value_projection=c["value_projection"])
|
||||
if not c["trainable_positional_encodings"]:
|
||||
freeze(decoder.embed_keys_positions)
|
||||
freeze(decoder.embed_query_positions)
|
||||
|
||||
# converter(postnet)
|
||||
linear_dim = 1 + config["transform"]["n_fft"] // 2
|
||||
h = c["converter_channels"]
|
||||
k = c["kernel_size"]
|
||||
postnet_convolutions = (
|
||||
ConvSpec(h, k, 1),
|
||||
ConvSpec(h, k, 3),
|
||||
ConvSpec(2 * h, k, 1),
|
||||
ConvSpec(2 * h, k, 3), )
|
||||
use_decoder_states = c["use_decoder_state_for_postnet_input"]
|
||||
converter = Converter(
|
||||
n_speakers,
|
||||
speaker_dim,
|
||||
in_channels=decoder.state_dim
|
||||
if use_decoder_states else config["transform"]["n_mels"],
|
||||
linear_dim=linear_dim,
|
||||
time_upsampling=c["downsample_factor"],
|
||||
convolutions=postnet_convolutions,
|
||||
dropout=c["dropout"])
|
||||
|
||||
model = DeepVoice3(
|
||||
encoder,
|
||||
decoder,
|
||||
converter,
|
||||
speaker_embed,
|
||||
use_decoder_states=use_decoder_states)
|
||||
return model
|
||||
|
||||
|
||||
def make_criterion(config):
|
||||
# =========================loss=========================
|
||||
loss_config = config["loss"]
|
||||
transform_config = config["transform"]
|
||||
model_config = config["model"]
|
||||
|
||||
priority_freq = loss_config["priority_freq"] # Hz
|
||||
sample_rate = transform_config["sample_rate"]
|
||||
linear_dim = 1 + transform_config["n_fft"] // 2
|
||||
priority_bin = int(priority_freq / (0.5 * sample_rate) * linear_dim)
|
||||
|
||||
criterion = TTSLoss(
|
||||
masked_weight=loss_config["masked_loss_weight"],
|
||||
priority_bin=priority_bin,
|
||||
priority_weight=loss_config["priority_freq_weight"],
|
||||
binary_divergence_weight=loss_config["binary_divergence_weight"],
|
||||
guided_attention_sigma=loss_config["guided_attention_sigma"],
|
||||
downsample_factor=model_config["downsample_factor"],
|
||||
r=model_config["outputs_per_step"])
|
||||
return criterion
|
||||
|
||||
|
||||
def make_optimizer(model, config):
|
||||
# =========================lr_scheduler=========================
|
||||
lr_config = config["lr_scheduler"]
|
||||
warmup_steps = lr_config["warmup_steps"]
|
||||
peak_learning_rate = lr_config["peak_learning_rate"]
|
||||
lr_scheduler = dg.NoamDecay(1 / (warmup_steps * (peak_learning_rate)**2),
|
||||
warmup_steps)
|
||||
|
||||
# =========================optimizer=========================
|
||||
optim_config = config["optimizer"]
|
||||
optim = fluid.optimizer.Adam(
|
||||
lr_scheduler,
|
||||
beta1=optim_config["beta1"],
|
||||
beta2=optim_config["beta2"],
|
||||
epsilon=optim_config["epsilon"],
|
||||
parameter_list=model.parameters(),
|
||||
grad_clip=fluid.clip.GradientClipByGlobalNorm(0.1))
|
||||
return optim
|
|
@ -20,6 +20,7 @@ import numpy as np
|
|||
import soundfile as sf
|
||||
|
||||
from paddle import fluid
|
||||
fluid.require_version('1.8.0')
|
||||
import paddle.fluid.layers as F
|
||||
import paddle.fluid.dygraph as dg
|
||||
from tensorboardX import SummaryWriter
|
||||
|
@ -29,7 +30,8 @@ from parakeet.modules.weight_norm import WeightNormWrapper
|
|||
from parakeet.utils.layer_tools import summary
|
||||
from parakeet.utils import io
|
||||
|
||||
from utils import make_model, eval_model, plot_alignment
|
||||
from model import make_model
|
||||
from utils import make_evaluator
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(
|
||||
|
@ -61,101 +63,29 @@ if __name__ == "__main__":
|
|||
else:
|
||||
place = fluid.CUDAPlace(args.device)
|
||||
|
||||
with dg.guard(place):
|
||||
# =========================model=========================
|
||||
transform_config = config["transform"]
|
||||
replace_pronounciation_prob = transform_config[
|
||||
"replace_pronunciation_prob"]
|
||||
sample_rate = transform_config["sample_rate"]
|
||||
preemphasis = transform_config["preemphasis"]
|
||||
n_fft = transform_config["n_fft"]
|
||||
n_mels = transform_config["n_mels"]
|
||||
dg.enable_dygraph(place)
|
||||
|
||||
model_config = config["model"]
|
||||
downsample_factor = model_config["downsample_factor"]
|
||||
r = model_config["outputs_per_step"]
|
||||
n_speakers = model_config["n_speakers"]
|
||||
speaker_dim = model_config["speaker_embed_dim"]
|
||||
speaker_embed_std = model_config["speaker_embedding_weight_std"]
|
||||
n_vocab = en.n_vocab
|
||||
embed_dim = model_config["text_embed_dim"]
|
||||
linear_dim = 1 + n_fft // 2
|
||||
use_decoder_states = model_config[
|
||||
"use_decoder_state_for_postnet_input"]
|
||||
filter_size = model_config["kernel_size"]
|
||||
encoder_channels = model_config["encoder_channels"]
|
||||
decoder_channels = model_config["decoder_channels"]
|
||||
converter_channels = model_config["converter_channels"]
|
||||
dropout = model_config["dropout"]
|
||||
padding_idx = model_config["padding_idx"]
|
||||
embedding_std = model_config["embedding_weight_std"]
|
||||
max_positions = model_config["max_positions"]
|
||||
freeze_embedding = model_config["freeze_embedding"]
|
||||
trainable_positional_encodings = model_config[
|
||||
"trainable_positional_encodings"]
|
||||
use_memory_mask = model_config["use_memory_mask"]
|
||||
query_position_rate = model_config["query_position_rate"]
|
||||
key_position_rate = model_config["key_position_rate"]
|
||||
window_backward = model_config["window_backward"]
|
||||
window_ahead = model_config["window_ahead"]
|
||||
key_projection = model_config["key_projection"]
|
||||
value_projection = model_config["value_projection"]
|
||||
dv3 = make_model(
|
||||
n_speakers, speaker_dim, speaker_embed_std, embed_dim, padding_idx,
|
||||
embedding_std, max_positions, n_vocab, freeze_embedding,
|
||||
filter_size, encoder_channels, n_mels, decoder_channels, r,
|
||||
trainable_positional_encodings, use_memory_mask,
|
||||
query_position_rate, key_position_rate, window_backward,
|
||||
window_ahead, key_projection, value_projection, downsample_factor,
|
||||
linear_dim, use_decoder_states, converter_channels, dropout)
|
||||
model = make_model(config)
|
||||
checkpoint_dir = os.path.join(args.output, "checkpoints")
|
||||
if args.checkpoint is not None:
|
||||
iteration = io.load_parameters(model, checkpoint_path=args.checkpoint)
|
||||
else:
|
||||
iteration = io.load_parameters(
|
||||
model, checkpoint_dir=checkpoint_dir, iteration=args.iteration)
|
||||
|
||||
summary(dv3)
|
||||
# WARNING: don't forget to remove weight norm to re-compute each wrapped layer's weight
|
||||
# removing weight norm also speeds up computation
|
||||
for layer in model.sublayers():
|
||||
if isinstance(layer, WeightNormWrapper):
|
||||
layer.remove_weight_norm()
|
||||
|
||||
checkpoint_dir = os.path.join(args.output, "checkpoints")
|
||||
if args.checkpoint is not None:
|
||||
iteration = io.load_parameters(
|
||||
dv3, checkpoint_path=args.checkpoint)
|
||||
else:
|
||||
iteration = io.load_parameters(
|
||||
dv3, checkpoint_dir=checkpoint_dir, iteration=args.iteration)
|
||||
synthesis_dir = os.path.join(args.output, "synthesis")
|
||||
if not os.path.exists(synthesis_dir):
|
||||
os.makedirs(synthesis_dir)
|
||||
|
||||
# WARNING: don't forget to remove weight norm to re-compute each wrapped layer's weight
|
||||
# removing weight norm also speeds up computation
|
||||
for layer in dv3.sublayers():
|
||||
if isinstance(layer, WeightNormWrapper):
|
||||
layer.remove_weight_norm()
|
||||
with open(args.text, "rt", encoding="utf-8") as f:
|
||||
lines = f.readlines()
|
||||
sentences = [line[:-1] for line in lines]
|
||||
|
||||
transform_config = config["transform"]
|
||||
c = transform_config["replace_pronunciation_prob"]
|
||||
sample_rate = transform_config["sample_rate"]
|
||||
min_level_db = transform_config["min_level_db"]
|
||||
ref_level_db = transform_config["ref_level_db"]
|
||||
preemphasis = transform_config["preemphasis"]
|
||||
win_length = transform_config["win_length"]
|
||||
hop_length = transform_config["hop_length"]
|
||||
|
||||
synthesis_config = config["synthesis"]
|
||||
power = synthesis_config["power"]
|
||||
n_iter = synthesis_config["n_iter"]
|
||||
|
||||
synthesis_dir = os.path.join(args.output, "synthesis")
|
||||
if not os.path.exists(synthesis_dir):
|
||||
os.makedirs(synthesis_dir)
|
||||
|
||||
with open(args.text, "rt", encoding="utf-8") as f:
|
||||
lines = f.readlines()
|
||||
for idx, line in enumerate(lines):
|
||||
text = line[:-1]
|
||||
dv3.eval()
|
||||
wav, attn = eval_model(dv3, text, replace_pronounciation_prob,
|
||||
min_level_db, ref_level_db, power,
|
||||
n_iter, win_length, hop_length,
|
||||
preemphasis)
|
||||
plot_alignment(
|
||||
attn,
|
||||
os.path.join(synthesis_dir,
|
||||
"test_{}_step_{}.png".format(idx, iteration)))
|
||||
sf.write(
|
||||
os.path.join(synthesis_dir,
|
||||
"test_{}_step{}.wav".format(idx, iteration)),
|
||||
wav, sample_rate)
|
||||
evaluator = make_evaluator(config, sentences, synthesis_dir)
|
||||
evaluator(model, iteration)
|
||||
|
|
|
@ -13,57 +13,37 @@
|
|||
# limitations under the License.
|
||||
|
||||
from __future__ import division
|
||||
import time
|
||||
import os
|
||||
import argparse
|
||||
import ruamel.yaml
|
||||
import numpy as np
|
||||
import matplotlib
|
||||
matplotlib.use("agg")
|
||||
from matplotlib import cm
|
||||
import matplotlib.pyplot as plt
|
||||
import tqdm
|
||||
import librosa
|
||||
from librosa import display
|
||||
import soundfile as sf
|
||||
from tensorboardX import SummaryWriter
|
||||
|
||||
from paddle import fluid
|
||||
fluid.require_version('1.8.0')
|
||||
import paddle.fluid.layers as F
|
||||
import paddle.fluid.dygraph as dg
|
||||
from parakeet.utils.io import load_parameters, save_parameters
|
||||
|
||||
from parakeet.g2p import en
|
||||
from parakeet.data import FilterDataset, TransformDataset, FilterDataset
|
||||
from parakeet.data import DataCargo, PartialyRandomizedSimilarTimeLengthSampler, SequentialSampler
|
||||
from parakeet.models.deepvoice3 import Encoder, Decoder, Converter, DeepVoice3, ConvSpec
|
||||
from parakeet.models.deepvoice3.loss import TTSLoss
|
||||
from parakeet.utils.layer_tools import summary
|
||||
from parakeet.utils import io
|
||||
|
||||
from data import LJSpeechMetaData, DataCollector, Transform
|
||||
from utils import make_model, eval_model, save_state, make_output_tree, plot_alignment
|
||||
from data import make_data_loader
|
||||
from model import make_model, make_criterion, make_optimizer
|
||||
from utils import make_output_tree, add_options, get_place, Evaluator, StateSaver, make_evaluator, make_state_saver
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Train a Deep Voice 3 model with LJSpeech dataset.")
|
||||
parser.add_argument("--config", type=str, help="experimrnt config")
|
||||
parser.add_argument(
|
||||
"--data",
|
||||
type=str,
|
||||
default="/workspace/datasets/LJSpeech-1.1/",
|
||||
help="The path of the LJSpeech dataset.")
|
||||
parser.add_argument("--device", type=int, default=-1, help="device to use")
|
||||
|
||||
g = parser.add_mutually_exclusive_group()
|
||||
g.add_argument("--checkpoint", type=str, help="checkpoint to resume from.")
|
||||
g.add_argument(
|
||||
"--iteration",
|
||||
type=int,
|
||||
help="the iteration of the checkpoint to load from output directory")
|
||||
|
||||
parser.add_argument(
|
||||
"output", type=str, default="experiment", help="path to save results")
|
||||
|
||||
add_options(parser)
|
||||
args, _ = parser.parse_known_args()
|
||||
|
||||
# only use args.device when training in single process
|
||||
# when training with distributed.launch, devices are provided by
|
||||
# `--selected_gpus` for distributed.launch
|
||||
env = dg.parallel.ParallelEnv()
|
||||
device_id = env.dev_id if env.nranks > 1 else args.device
|
||||
place = get_place(device_id)
|
||||
# start dygraph
|
||||
dg.enable_dygraph(place)
|
||||
|
||||
with open(args.config, 'rt') as f:
|
||||
config = ruamel.yaml.safe_load(f)
|
||||
|
||||
|
@ -71,267 +51,122 @@ if __name__ == "__main__":
|
|||
for k, v in vars(args).items():
|
||||
print("{}: {}".format(k, v))
|
||||
|
||||
# =========================dataset=========================
|
||||
# construct meta data
|
||||
data_root = args.data
|
||||
meta = LJSpeechMetaData(data_root)
|
||||
data_loader = make_data_loader(args.data, config)
|
||||
model = make_model(config)
|
||||
if env.nranks > 1:
|
||||
strategy = dg.parallel.prepare_context()
|
||||
model = dg.DataParallel(model, strategy)
|
||||
criterion = make_criterion(config)
|
||||
optim = make_optimizer(model, config)
|
||||
|
||||
# filter it!
|
||||
min_text_length = config["meta_data"]["min_text_length"]
|
||||
meta = FilterDataset(meta, lambda x: len(x[2]) >= min_text_length)
|
||||
# generation
|
||||
synthesis_config = config["synthesis"]
|
||||
power = synthesis_config["power"]
|
||||
n_iter = synthesis_config["n_iter"]
|
||||
|
||||
# transform meta data into meta data
|
||||
transform_config = config["transform"]
|
||||
replace_pronounciation_prob = transform_config[
|
||||
"replace_pronunciation_prob"]
|
||||
sample_rate = transform_config["sample_rate"]
|
||||
preemphasis = transform_config["preemphasis"]
|
||||
n_fft = transform_config["n_fft"]
|
||||
win_length = transform_config["win_length"]
|
||||
hop_length = transform_config["hop_length"]
|
||||
fmin = transform_config["fmin"]
|
||||
fmax = transform_config["fmax"]
|
||||
n_mels = transform_config["n_mels"]
|
||||
min_level_db = transform_config["min_level_db"]
|
||||
ref_level_db = transform_config["ref_level_db"]
|
||||
max_norm = transform_config["max_norm"]
|
||||
clip_norm = transform_config["clip_norm"]
|
||||
transform = Transform(replace_pronounciation_prob, sample_rate,
|
||||
preemphasis, n_fft, win_length, hop_length, fmin,
|
||||
fmax, n_mels, min_level_db, ref_level_db, max_norm,
|
||||
clip_norm)
|
||||
ljspeech = TransformDataset(meta, transform)
|
||||
|
||||
# =========================dataiterator=========================
|
||||
# use meta data's text length as a sort key for the sampler
|
||||
train_config = config["train"]
|
||||
batch_size = train_config["batch_size"]
|
||||
text_lengths = [len(example[2]) for example in meta]
|
||||
sampler = PartialyRandomizedSimilarTimeLengthSampler(text_lengths,
|
||||
batch_size)
|
||||
|
||||
# some hyperparameters affect how we process data, so create a data collector!
|
||||
model_config = config["model"]
|
||||
downsample_factor = model_config["downsample_factor"]
|
||||
r = model_config["outputs_per_step"]
|
||||
collector = DataCollector(downsample_factor=downsample_factor, r=r)
|
||||
ljspeech_loader = DataCargo(
|
||||
ljspeech, batch_fn=collector, batch_size=batch_size, sampler=sampler)
|
||||
|
||||
# =========================model=========================
|
||||
if args.device == -1:
|
||||
place = fluid.CPUPlace()
|
||||
else:
|
||||
place = fluid.CUDAPlace(args.device)
|
||||
|
||||
with dg.guard(place):
|
||||
# =========================model=========================
|
||||
n_speakers = model_config["n_speakers"]
|
||||
speaker_dim = model_config["speaker_embed_dim"]
|
||||
speaker_embed_std = model_config["speaker_embedding_weight_std"]
|
||||
n_vocab = en.n_vocab
|
||||
embed_dim = model_config["text_embed_dim"]
|
||||
linear_dim = 1 + n_fft // 2
|
||||
use_decoder_states = model_config[
|
||||
"use_decoder_state_for_postnet_input"]
|
||||
filter_size = model_config["kernel_size"]
|
||||
encoder_channels = model_config["encoder_channels"]
|
||||
decoder_channels = model_config["decoder_channels"]
|
||||
converter_channels = model_config["converter_channels"]
|
||||
dropout = model_config["dropout"]
|
||||
padding_idx = model_config["padding_idx"]
|
||||
embedding_std = model_config["embedding_weight_std"]
|
||||
max_positions = model_config["max_positions"]
|
||||
freeze_embedding = model_config["freeze_embedding"]
|
||||
trainable_positional_encodings = model_config[
|
||||
"trainable_positional_encodings"]
|
||||
use_memory_mask = model_config["use_memory_mask"]
|
||||
query_position_rate = model_config["query_position_rate"]
|
||||
key_position_rate = model_config["key_position_rate"]
|
||||
window_backward = model_config["window_backward"]
|
||||
window_ahead = model_config["window_ahead"]
|
||||
key_projection = model_config["key_projection"]
|
||||
value_projection = model_config["value_projection"]
|
||||
dv3 = make_model(
|
||||
n_speakers, speaker_dim, speaker_embed_std, embed_dim, padding_idx,
|
||||
embedding_std, max_positions, n_vocab, freeze_embedding,
|
||||
filter_size, encoder_channels, n_mels, decoder_channels, r,
|
||||
trainable_positional_encodings, use_memory_mask,
|
||||
query_position_rate, key_position_rate, window_backward,
|
||||
window_ahead, key_projection, value_projection, downsample_factor,
|
||||
linear_dim, use_decoder_states, converter_channels, dropout)
|
||||
summary(dv3)
|
||||
|
||||
# =========================loss=========================
|
||||
loss_config = config["loss"]
|
||||
masked_weight = loss_config["masked_loss_weight"]
|
||||
priority_freq = loss_config["priority_freq"] # Hz
|
||||
priority_bin = int(priority_freq / (0.5 * sample_rate) * linear_dim)
|
||||
priority_freq_weight = loss_config["priority_freq_weight"]
|
||||
binary_divergence_weight = loss_config["binary_divergence_weight"]
|
||||
guided_attention_sigma = loss_config["guided_attention_sigma"]
|
||||
criterion = TTSLoss(
|
||||
masked_weight=masked_weight,
|
||||
priority_bin=priority_bin,
|
||||
priority_weight=priority_freq_weight,
|
||||
binary_divergence_weight=binary_divergence_weight,
|
||||
guided_attention_sigma=guided_attention_sigma,
|
||||
downsample_factor=downsample_factor,
|
||||
r=r)
|
||||
|
||||
# =========================lr_scheduler=========================
|
||||
lr_config = config["lr_scheduler"]
|
||||
warmup_steps = lr_config["warmup_steps"]
|
||||
peak_learning_rate = lr_config["peak_learning_rate"]
|
||||
lr_scheduler = dg.NoamDecay(
|
||||
1 / (warmup_steps * (peak_learning_rate)**2), warmup_steps)
|
||||
|
||||
# =========================optimizer=========================
|
||||
optim_config = config["optimizer"]
|
||||
beta1 = optim_config["beta1"]
|
||||
beta2 = optim_config["beta2"]
|
||||
epsilon = optim_config["epsilon"]
|
||||
optim = fluid.optimizer.Adam(
|
||||
lr_scheduler,
|
||||
beta1,
|
||||
beta2,
|
||||
epsilon=epsilon,
|
||||
parameter_list=dv3.parameters(),
|
||||
grad_clip=fluid.clip.GradientClipByGlobalNorm(0.1))
|
||||
|
||||
# generation
|
||||
synthesis_config = config["synthesis"]
|
||||
power = synthesis_config["power"]
|
||||
n_iter = synthesis_config["n_iter"]
|
||||
|
||||
# =========================link(dataloader, paddle)=========================
|
||||
loader = fluid.io.DataLoader.from_generator(
|
||||
capacity=10, return_list=True)
|
||||
loader.set_batch_generator(ljspeech_loader, places=place)
|
||||
|
||||
# tensorboard & checkpoint preparation
|
||||
output_dir = args.output
|
||||
ckpt_dir = os.path.join(output_dir, "checkpoints")
|
||||
log_dir = os.path.join(output_dir, "log")
|
||||
state_dir = os.path.join(output_dir, "states")
|
||||
# tensorboard & checkpoint preparation
|
||||
output_dir = args.output
|
||||
ckpt_dir = os.path.join(output_dir, "checkpoints")
|
||||
log_dir = os.path.join(output_dir, "log")
|
||||
state_dir = os.path.join(output_dir, "states")
|
||||
eval_dir = os.path.join(output_dir, "eval")
|
||||
if env.local_rank == 0:
|
||||
make_output_tree(output_dir)
|
||||
writer = SummaryWriter(logdir=log_dir)
|
||||
else:
|
||||
writer = None
|
||||
sentences = [
|
||||
"Scientists at the CERN laboratory say they have discovered a new particle.",
|
||||
"There's a way to measure the acute emotional intelligence that has never gone out of style.",
|
||||
"President Trump met with other leaders at the Group of 20 conference.",
|
||||
"Generative adversarial network or variational auto-encoder.",
|
||||
"Please call Stella.",
|
||||
"Some have accepted this as a miracle without any physical explanation.",
|
||||
]
|
||||
evaluator = make_evaluator(config, sentences, eval_dir, writer)
|
||||
state_saver = make_state_saver(config, state_dir, writer)
|
||||
|
||||
# load parameters and optimizer, and opdate iterations done sofar
|
||||
if args.checkpoint is not None:
|
||||
iteration = io.load_parameters(
|
||||
dv3, optim, checkpoint_path=args.checkpoint)
|
||||
# load parameters and optimizer, and opdate iterations done sofar
|
||||
if args.checkpoint is not None:
|
||||
iteration = load_parameters(
|
||||
model, optim, checkpoint_path=args.checkpoint)
|
||||
else:
|
||||
iteration = load_parameters(
|
||||
model, optim, checkpoint_dir=ckpt_dir, iteration=args.iteration)
|
||||
|
||||
# =========================train=========================
|
||||
train_config = config["train"]
|
||||
max_iter = train_config["max_iteration"]
|
||||
snap_interval = train_config["snap_interval"]
|
||||
save_interval = train_config["save_interval"]
|
||||
eval_interval = train_config["eval_interval"]
|
||||
|
||||
global_step = iteration + 1
|
||||
iterator = iter(tqdm.tqdm(data_loader))
|
||||
downsample_factor = config["model"]["downsample_factor"]
|
||||
while global_step <= max_iter:
|
||||
try:
|
||||
batch = next(iterator)
|
||||
except StopIteration as e:
|
||||
iterator = iter(tqdm.tqdm(data_loader))
|
||||
batch = next(iterator)
|
||||
|
||||
model.train()
|
||||
(text_sequences, text_lengths, text_positions, mel_specs, lin_specs,
|
||||
frames, decoder_positions, done_flags) = batch
|
||||
downsampled_mel_specs = F.strided_slice(
|
||||
mel_specs,
|
||||
axes=[1],
|
||||
starts=[0],
|
||||
ends=[mel_specs.shape[1]],
|
||||
strides=[downsample_factor])
|
||||
outputs = model(
|
||||
text_sequences,
|
||||
text_positions,
|
||||
text_lengths,
|
||||
None,
|
||||
downsampled_mel_specs,
|
||||
decoder_positions, )
|
||||
# mel_outputs, linear_outputs, alignments, done
|
||||
inputs = (downsampled_mel_specs, lin_specs, done_flags, text_lengths,
|
||||
frames)
|
||||
losses = criterion(outputs, inputs)
|
||||
|
||||
l = losses["loss"]
|
||||
if env.nranks > 1:
|
||||
l = model.scale_loss(l)
|
||||
l.backward()
|
||||
model.apply_collective_grads()
|
||||
else:
|
||||
iteration = io.load_parameters(
|
||||
dv3, optim, checkpoint_dir=ckpt_dir, iteration=args.iteration)
|
||||
|
||||
# =========================train=========================
|
||||
max_iter = train_config["max_iteration"]
|
||||
snap_interval = train_config["snap_interval"]
|
||||
save_interval = train_config["save_interval"]
|
||||
eval_interval = train_config["eval_interval"]
|
||||
|
||||
global_step = iteration + 1
|
||||
iterator = iter(tqdm.tqdm(loader))
|
||||
while global_step <= max_iter:
|
||||
try:
|
||||
batch = next(iterator)
|
||||
except StopIteration as e:
|
||||
iterator = iter(tqdm.tqdm(loader))
|
||||
batch = next(iterator)
|
||||
|
||||
dv3.train()
|
||||
(text_sequences, text_lengths, text_positions, mel_specs,
|
||||
lin_specs, frames, decoder_positions, done_flags) = batch
|
||||
downsampled_mel_specs = F.strided_slice(
|
||||
mel_specs,
|
||||
axes=[1],
|
||||
starts=[0],
|
||||
ends=[mel_specs.shape[1]],
|
||||
strides=[downsample_factor])
|
||||
mel_outputs, linear_outputs, alignments, done = dv3(
|
||||
text_sequences, text_positions, text_lengths, None,
|
||||
downsampled_mel_specs, decoder_positions)
|
||||
|
||||
losses = criterion(mel_outputs, linear_outputs, done, alignments,
|
||||
downsampled_mel_specs, lin_specs, done_flags,
|
||||
text_lengths, frames)
|
||||
l = losses["loss"]
|
||||
l.backward()
|
||||
|
||||
# record learning rate before updating
|
||||
# record learning rate before updating
|
||||
if env.local_rank == 0:
|
||||
writer.add_scalar("learning_rate",
|
||||
optim._learning_rate.step().numpy(), global_step)
|
||||
optim.minimize(l)
|
||||
optim.clear_gradients()
|
||||
optim.minimize(l)
|
||||
optim.clear_gradients()
|
||||
|
||||
# ==================all kinds of tedious things=================
|
||||
# record step loss into tensorboard
|
||||
step_loss = {
|
||||
k: v.numpy()[0]
|
||||
for k, v in losses.items() if v is not None
|
||||
}
|
||||
tqdm.tqdm.write("global_step: {}\tloss: {}".format(
|
||||
# record step losses
|
||||
step_loss = {k: v.numpy()[0] for k, v in losses.items()}
|
||||
|
||||
if env.local_rank == 0:
|
||||
tqdm.tqdm.write("[Train] global_step: {}\tloss: {}".format(
|
||||
global_step, step_loss["loss"]))
|
||||
for k, v in step_loss.items():
|
||||
writer.add_scalar(k, v, global_step)
|
||||
|
||||
# train state saving, the first sentence in the batch
|
||||
if global_step % snap_interval == 0:
|
||||
save_state(
|
||||
state_dir,
|
||||
writer,
|
||||
global_step,
|
||||
mel_input=downsampled_mel_specs,
|
||||
mel_output=mel_outputs,
|
||||
lin_input=lin_specs,
|
||||
lin_output=linear_outputs,
|
||||
alignments=alignments,
|
||||
win_length=win_length,
|
||||
hop_length=hop_length,
|
||||
min_level_db=min_level_db,
|
||||
ref_level_db=ref_level_db,
|
||||
power=power,
|
||||
n_iter=n_iter,
|
||||
preemphasis=preemphasis,
|
||||
sample_rate=sample_rate)
|
||||
# train state saving, the first sentence in the batch
|
||||
if env.local_rank == 0 and global_step % snap_interval == 0:
|
||||
input_specs = (mel_specs, lin_specs)
|
||||
state_saver(outputs, input_specs, global_step)
|
||||
|
||||
# evaluation
|
||||
if global_step % eval_interval == 0:
|
||||
sentences = [
|
||||
"Scientists at the CERN laboratory say they have discovered a new particle.",
|
||||
"There's a way to measure the acute emotional intelligence that has never gone out of style.",
|
||||
"President Trump met with other leaders at the Group of 20 conference.",
|
||||
"Generative adversarial network or variational auto-encoder.",
|
||||
"Please call Stella.",
|
||||
"Some have accepted this as a miracle without any physical explanation.",
|
||||
]
|
||||
for idx, sent in enumerate(sentences):
|
||||
wav, attn = eval_model(
|
||||
dv3, sent, replace_pronounciation_prob, min_level_db,
|
||||
ref_level_db, power, n_iter, win_length, hop_length,
|
||||
preemphasis)
|
||||
wav_path = os.path.join(
|
||||
state_dir, "waveform",
|
||||
"eval_sample_{:09d}.wav".format(global_step))
|
||||
sf.write(wav_path, wav, sample_rate)
|
||||
writer.add_audio(
|
||||
"eval_sample_{}".format(idx),
|
||||
wav,
|
||||
global_step,
|
||||
sample_rate=sample_rate)
|
||||
attn_path = os.path.join(
|
||||
state_dir, "alignments",
|
||||
"eval_sample_attn_{:09d}.png".format(global_step))
|
||||
plot_alignment(attn, attn_path)
|
||||
writer.add_image(
|
||||
"eval_sample_attn{}".format(idx),
|
||||
cm.viridis(attn),
|
||||
global_step,
|
||||
dataformats="HWC")
|
||||
# evaluation
|
||||
if env.local_rank == 0 and global_step % eval_interval == 0:
|
||||
evaluator(model, global_step)
|
||||
|
||||
# save checkpoint
|
||||
if global_step % save_interval == 0:
|
||||
io.save_parameters(ckpt_dir, global_step, dv3, optim)
|
||||
# save checkpoint
|
||||
if env.local_rank == 0 and global_step % save_interval == 0:
|
||||
save_parameters(ckpt_dir, global_step, model, optim)
|
||||
|
||||
global_step += 1
|
||||
global_step += 1
|
||||
|
|
|
@ -15,6 +15,8 @@
|
|||
from __future__ import division
|
||||
import os
|
||||
import numpy as np
|
||||
import matplotlib
|
||||
matplotlib.use("agg")
|
||||
from matplotlib import cm
|
||||
import matplotlib.pyplot as plt
|
||||
import librosa
|
||||
|
@ -24,132 +26,302 @@ import soundfile as sf
|
|||
|
||||
from paddle import fluid
|
||||
import paddle.fluid.dygraph as dg
|
||||
import paddle.fluid.initializer as I
|
||||
|
||||
from parakeet.g2p import en
|
||||
from parakeet.models.deepvoice3.encoder import ConvSpec
|
||||
from parakeet.models.deepvoice3 import Encoder, Decoder, Converter, DeepVoice3, WindowRange
|
||||
from parakeet.utils.layer_tools import freeze
|
||||
|
||||
|
||||
@fluid.framework.dygraph_only
|
||||
def make_model(n_speakers, speaker_dim, speaker_embed_std, embed_dim,
|
||||
padding_idx, embedding_std, max_positions, n_vocab,
|
||||
freeze_embedding, filter_size, encoder_channels, mel_dim,
|
||||
decoder_channels, r, trainable_positional_encodings,
|
||||
use_memory_mask, query_position_rate, key_position_rate,
|
||||
window_behind, window_ahead, key_projection, value_projection,
|
||||
downsample_factor, linear_dim, use_decoder_states,
|
||||
converter_channels, dropout):
|
||||
"""just a simple function to create a deepvoice 3 model"""
|
||||
if n_speakers > 1:
|
||||
spe = dg.Embedding(
|
||||
(n_speakers, speaker_dim),
|
||||
param_attr=I.Normal(scale=speaker_embed_std))
|
||||
def get_place(device_id):
|
||||
"""get place from device_id, -1 stands for CPU"""
|
||||
if device_id == -1:
|
||||
place = fluid.CPUPlace()
|
||||
else:
|
||||
spe = None
|
||||
|
||||
h = encoder_channels
|
||||
k = filter_size
|
||||
encoder_convolutions = (
|
||||
ConvSpec(h, k, 1),
|
||||
ConvSpec(h, k, 3),
|
||||
ConvSpec(h, k, 9),
|
||||
ConvSpec(h, k, 27),
|
||||
ConvSpec(h, k, 1),
|
||||
ConvSpec(h, k, 3),
|
||||
ConvSpec(h, k, 9),
|
||||
ConvSpec(h, k, 27),
|
||||
ConvSpec(h, k, 1),
|
||||
ConvSpec(h, k, 3), )
|
||||
enc = Encoder(
|
||||
n_vocab,
|
||||
embed_dim,
|
||||
n_speakers,
|
||||
speaker_dim,
|
||||
padding_idx=None,
|
||||
embedding_weight_std=embedding_std,
|
||||
convolutions=encoder_convolutions,
|
||||
dropout=dropout)
|
||||
if freeze_embedding:
|
||||
freeze(enc.embed)
|
||||
|
||||
h = decoder_channels
|
||||
prenet_convolutions = (ConvSpec(h, k, 1), ConvSpec(h, k, 3))
|
||||
attentive_convolutions = (
|
||||
ConvSpec(h, k, 1),
|
||||
ConvSpec(h, k, 3),
|
||||
ConvSpec(h, k, 9),
|
||||
ConvSpec(h, k, 27),
|
||||
ConvSpec(h, k, 1), )
|
||||
attention = [True, False, False, False, True]
|
||||
force_monotonic_attention = [True, False, False, False, True]
|
||||
dec = Decoder(
|
||||
n_speakers,
|
||||
speaker_dim,
|
||||
embed_dim,
|
||||
mel_dim,
|
||||
r=r,
|
||||
max_positions=max_positions,
|
||||
preattention=prenet_convolutions,
|
||||
convolutions=attentive_convolutions,
|
||||
attention=attention,
|
||||
dropout=dropout,
|
||||
use_memory_mask=use_memory_mask,
|
||||
force_monotonic_attention=force_monotonic_attention,
|
||||
query_position_rate=query_position_rate,
|
||||
key_position_rate=key_position_rate,
|
||||
window_range=WindowRange(window_behind, window_ahead),
|
||||
key_projection=key_projection,
|
||||
value_projection=value_projection)
|
||||
if not trainable_positional_encodings:
|
||||
freeze(dec.embed_keys_positions)
|
||||
freeze(dec.embed_query_positions)
|
||||
|
||||
h = converter_channels
|
||||
postnet_convolutions = (
|
||||
ConvSpec(h, k, 1),
|
||||
ConvSpec(h, k, 3),
|
||||
ConvSpec(2 * h, k, 1),
|
||||
ConvSpec(2 * h, k, 3), )
|
||||
cvt = Converter(
|
||||
n_speakers,
|
||||
speaker_dim,
|
||||
dec.state_dim if use_decoder_states else mel_dim,
|
||||
linear_dim,
|
||||
time_upsampling=downsample_factor,
|
||||
convolutions=postnet_convolutions,
|
||||
dropout=dropout)
|
||||
dv3 = DeepVoice3(enc, dec, cvt, spe, use_decoder_states)
|
||||
return dv3
|
||||
place = fluid.CUDAPlace(device_id)
|
||||
return place
|
||||
|
||||
|
||||
@fluid.framework.dygraph_only
|
||||
def eval_model(model, text, replace_pronounciation_prob, min_level_db,
|
||||
ref_level_db, power, n_iter, win_length, hop_length,
|
||||
preemphasis):
|
||||
"""generate waveform from text using a deepvoice 3 model"""
|
||||
text = np.array(
|
||||
en.text_to_sequence(
|
||||
text, p=replace_pronounciation_prob),
|
||||
dtype=np.int64)
|
||||
length = len(text)
|
||||
print("text sequence's length: {}".format(length))
|
||||
text_positions = np.arange(1, 1 + length)
|
||||
def add_options(parser):
|
||||
parser.add_argument("--config", type=str, help="experimrnt config")
|
||||
parser.add_argument(
|
||||
"--data",
|
||||
type=str,
|
||||
default="/workspace/datasets/LJSpeech-1.1/",
|
||||
help="The path of the LJSpeech dataset.")
|
||||
parser.add_argument("--device", type=int, default=-1, help="device to use")
|
||||
|
||||
text = np.expand_dims(text, 0)
|
||||
text_positions = np.expand_dims(text_positions, 0)
|
||||
model.eval()
|
||||
mel_outputs, linear_outputs, alignments, done = model.transduce(
|
||||
dg.to_variable(text), dg.to_variable(text_positions))
|
||||
g = parser.add_mutually_exclusive_group()
|
||||
g.add_argument("--checkpoint", type=str, help="checkpoint to resume from.")
|
||||
g.add_argument(
|
||||
"--iteration",
|
||||
type=int,
|
||||
help="the iteration of the checkpoint to load from output directory")
|
||||
|
||||
linear_outputs_np = linear_outputs.numpy()[0].T # (C, T)
|
||||
wav = spec_to_waveform(linear_outputs_np, min_level_db, ref_level_db,
|
||||
power, n_iter, win_length, hop_length, preemphasis)
|
||||
alignments_np = alignments.numpy()[0] # batch_size = 1
|
||||
print("linear_outputs's shape: ", linear_outputs_np.shape)
|
||||
print("alignmnets' shape:", alignments.shape)
|
||||
return wav, alignments_np
|
||||
parser.add_argument(
|
||||
"output", type=str, default="experiment", help="path to save results")
|
||||
|
||||
|
||||
def make_evaluator(config, text_sequences, output_dir, writer=None):
|
||||
c = config["transform"]
|
||||
p_replace = c["replace_pronunciation_prob"]
|
||||
sample_rate = c["sample_rate"]
|
||||
preemphasis = c["preemphasis"]
|
||||
win_length = c["win_length"]
|
||||
hop_length = c["hop_length"]
|
||||
min_level_db = c["min_level_db"]
|
||||
ref_level_db = c["ref_level_db"]
|
||||
|
||||
synthesis_config = config["synthesis"]
|
||||
power = synthesis_config["power"]
|
||||
n_iter = synthesis_config["n_iter"]
|
||||
|
||||
return Evaluator(
|
||||
text_sequences,
|
||||
p_replace,
|
||||
sample_rate,
|
||||
preemphasis,
|
||||
win_length,
|
||||
hop_length,
|
||||
min_level_db,
|
||||
ref_level_db,
|
||||
power,
|
||||
n_iter,
|
||||
output_dir=output_dir,
|
||||
writer=writer)
|
||||
|
||||
|
||||
class Evaluator(object):
|
||||
def __init__(self,
|
||||
text_sequences,
|
||||
p_replace,
|
||||
sample_rate,
|
||||
preemphasis,
|
||||
win_length,
|
||||
hop_length,
|
||||
min_level_db,
|
||||
ref_level_db,
|
||||
power,
|
||||
n_iter,
|
||||
output_dir,
|
||||
writer=None):
|
||||
self.text_sequences = text_sequences
|
||||
self.output_dir = output_dir
|
||||
self.writer = writer
|
||||
|
||||
self.p_replace = p_replace
|
||||
self.sample_rate = sample_rate
|
||||
self.preemphasis = preemphasis
|
||||
self.win_length = win_length
|
||||
self.hop_length = hop_length
|
||||
self.min_level_db = min_level_db
|
||||
self.ref_level_db = ref_level_db
|
||||
|
||||
self.power = power
|
||||
self.n_iter = n_iter
|
||||
|
||||
def process_a_sentence(self, model, text):
|
||||
text = np.array(
|
||||
en.text_to_sequence(
|
||||
text, p=self.p_replace), dtype=np.int64)
|
||||
length = len(text)
|
||||
text_positions = np.arange(1, 1 + length)
|
||||
text = np.expand_dims(text, 0)
|
||||
text_positions = np.expand_dims(text_positions, 0)
|
||||
|
||||
model.eval()
|
||||
if isinstance(model, dg.DataParallel):
|
||||
_model = model._layers
|
||||
else:
|
||||
_model = model
|
||||
mel_outputs, linear_outputs, alignments, done = _model.transduce(
|
||||
dg.to_variable(text), dg.to_variable(text_positions))
|
||||
|
||||
linear_outputs_np = linear_outputs.numpy()[0].T # (C, T)
|
||||
|
||||
wav = spec_to_waveform(linear_outputs_np, self.min_level_db,
|
||||
self.ref_level_db, self.power, self.n_iter,
|
||||
self.win_length, self.hop_length,
|
||||
self.preemphasis)
|
||||
alignments_np = alignments.numpy()[0] # batch_size = 1
|
||||
return wav, alignments_np
|
||||
|
||||
def __call__(self, model, iteration):
|
||||
writer = self.writer
|
||||
for i, seq in enumerate(self.text_sequences):
|
||||
print("[Eval] synthesizing sentence {}".format(i))
|
||||
wav, alignments_np = self.process_a_sentence(model, seq)
|
||||
|
||||
wav_path = os.path.join(
|
||||
self.output_dir,
|
||||
"eval_sample_{}_step_{:09d}.wav".format(i, iteration))
|
||||
sf.write(wav_path, wav, self.sample_rate)
|
||||
if writer is not None:
|
||||
writer.add_audio(
|
||||
"eval_sample_{}".format(i),
|
||||
wav,
|
||||
iteration,
|
||||
sample_rate=self.sample_rate)
|
||||
attn_path = os.path.join(
|
||||
self.output_dir,
|
||||
"eval_sample_{}_step_{:09d}.png".format(i, iteration))
|
||||
plot_alignment(alignments_np, attn_path)
|
||||
if writer is not None:
|
||||
writer.add_image(
|
||||
"eval_sample_attn_{}".format(i),
|
||||
cm.viridis(alignments_np),
|
||||
iteration,
|
||||
dataformats="HWC")
|
||||
|
||||
|
||||
def make_state_saver(config, output_dir, writer=None):
|
||||
c = config["transform"]
|
||||
p_replace = c["replace_pronunciation_prob"]
|
||||
sample_rate = c["sample_rate"]
|
||||
preemphasis = c["preemphasis"]
|
||||
win_length = c["win_length"]
|
||||
hop_length = c["hop_length"]
|
||||
min_level_db = c["min_level_db"]
|
||||
ref_level_db = c["ref_level_db"]
|
||||
|
||||
synthesis_config = config["synthesis"]
|
||||
power = synthesis_config["power"]
|
||||
n_iter = synthesis_config["n_iter"]
|
||||
|
||||
return StateSaver(p_replace, sample_rate, preemphasis, win_length,
|
||||
hop_length, min_level_db, ref_level_db, power, n_iter,
|
||||
output_dir, writer)
|
||||
|
||||
|
||||
class StateSaver(object):
|
||||
def __init__(self,
|
||||
p_replace,
|
||||
sample_rate,
|
||||
preemphasis,
|
||||
win_length,
|
||||
hop_length,
|
||||
min_level_db,
|
||||
ref_level_db,
|
||||
power,
|
||||
n_iter,
|
||||
output_dir,
|
||||
writer=None):
|
||||
self.output_dir = output_dir
|
||||
self.writer = writer
|
||||
|
||||
self.p_replace = p_replace
|
||||
self.sample_rate = sample_rate
|
||||
self.preemphasis = preemphasis
|
||||
self.win_length = win_length
|
||||
self.hop_length = hop_length
|
||||
self.min_level_db = min_level_db
|
||||
self.ref_level_db = ref_level_db
|
||||
|
||||
self.power = power
|
||||
self.n_iter = n_iter
|
||||
|
||||
def __call__(self, outputs, inputs, iteration):
|
||||
mel_output, lin_output, alignments, done_output = outputs
|
||||
mel_input, lin_input = inputs
|
||||
writer = self.writer
|
||||
|
||||
# mel spectrogram
|
||||
mel_input = mel_input[0].numpy().T
|
||||
mel_output = mel_output[0].numpy().T
|
||||
|
||||
path = os.path.join(self.output_dir, "mel_spec")
|
||||
plt.figure(figsize=(10, 3))
|
||||
display.specshow(mel_input)
|
||||
plt.colorbar()
|
||||
plt.title("mel_input")
|
||||
plt.savefig(
|
||||
os.path.join(path, "target_mel_spec_step_{:09d}.png".format(
|
||||
iteration)))
|
||||
plt.close()
|
||||
|
||||
if writer is not None:
|
||||
writer.add_image(
|
||||
"target/mel_spec",
|
||||
cm.viridis(mel_input),
|
||||
iteration,
|
||||
dataformats="HWC")
|
||||
|
||||
plt.figure(figsize=(10, 3))
|
||||
display.specshow(mel_output)
|
||||
plt.colorbar()
|
||||
plt.title("mel_output")
|
||||
plt.savefig(
|
||||
os.path.join(path, "predicted_mel_spec_step_{:09d}.png".format(
|
||||
iteration)))
|
||||
plt.close()
|
||||
|
||||
if writer is not None:
|
||||
writer.add_image(
|
||||
"predicted/mel_spec",
|
||||
cm.viridis(mel_output),
|
||||
iteration,
|
||||
dataformats="HWC")
|
||||
|
||||
# linear spectrogram
|
||||
lin_input = lin_input[0].numpy().T
|
||||
lin_output = lin_output[0].numpy().T
|
||||
path = os.path.join(self.output_dir, "lin_spec")
|
||||
|
||||
plt.figure(figsize=(10, 3))
|
||||
display.specshow(lin_input)
|
||||
plt.colorbar()
|
||||
plt.title("mel_input")
|
||||
plt.savefig(
|
||||
os.path.join(path, "target_lin_spec_step_{:09d}.png".format(
|
||||
iteration)))
|
||||
plt.close()
|
||||
|
||||
if writer is not None:
|
||||
writer.add_image(
|
||||
"target/lin_spec",
|
||||
cm.viridis(lin_input),
|
||||
iteration,
|
||||
dataformats="HWC")
|
||||
|
||||
plt.figure(figsize=(10, 3))
|
||||
display.specshow(lin_output)
|
||||
plt.colorbar()
|
||||
plt.title("mel_input")
|
||||
plt.savefig(
|
||||
os.path.join(path, "predicted_lin_spec_step_{:09d}.png".format(
|
||||
iteration)))
|
||||
plt.close()
|
||||
|
||||
if writer is not None:
|
||||
writer.add_image(
|
||||
"predicted/lin_spec",
|
||||
cm.viridis(lin_output),
|
||||
iteration,
|
||||
dataformats="HWC")
|
||||
|
||||
# alignment
|
||||
path = os.path.join(self.output_dir, "alignments")
|
||||
alignments = alignments[:, 0, :, :].numpy()
|
||||
for idx, attn_layer in enumerate(alignments):
|
||||
save_path = os.path.join(
|
||||
path, "train_attn_layer_{}_step_{}.png".format(idx, iteration))
|
||||
plot_alignment(attn_layer, save_path)
|
||||
|
||||
if writer is not None:
|
||||
writer.add_image(
|
||||
"train_attn/layer_{}".format(idx),
|
||||
cm.viridis(attn_layer),
|
||||
iteration,
|
||||
dataformats="HWC")
|
||||
|
||||
# synthesize waveform
|
||||
wav = spec_to_waveform(
|
||||
lin_output, self.min_level_db, self.ref_level_db, self.power,
|
||||
self.n_iter, self.win_length, self.hop_length, self.preemphasis)
|
||||
path = os.path.join(self.output_dir, "waveform")
|
||||
save_path = os.path.join(
|
||||
path, "train_sample_step_{:09d}.wav".format(iteration))
|
||||
sf.write(save_path, wav, self.sample_rate)
|
||||
|
||||
if writer is not None:
|
||||
writer.add_audio(
|
||||
"train_sample", wav, iteration, sample_rate=self.sample_rate)
|
||||
|
||||
|
||||
def spec_to_waveform(spec, min_level_db, ref_level_db, power, n_iter,
|
||||
|
@ -168,6 +340,7 @@ def spec_to_waveform(spec, min_level_db, ref_level_db, power, n_iter,
|
|||
win_length=win_length)
|
||||
if preemphasis > 0:
|
||||
wav = signal.lfilter([1.], [1., -preemphasis], wav)
|
||||
wav = np.clip(wav, -1.0, 1.0)
|
||||
return wav
|
||||
|
||||
|
||||
|
@ -175,9 +348,9 @@ def make_output_tree(output_dir):
|
|||
print("creating output tree: {}".format(output_dir))
|
||||
ckpt_dir = os.path.join(output_dir, "checkpoints")
|
||||
state_dir = os.path.join(output_dir, "states")
|
||||
log_dir = os.path.join(output_dir, "log")
|
||||
eval_dir = os.path.join(output_dir, "eval")
|
||||
|
||||
for x in [ckpt_dir, state_dir]:
|
||||
for x in [ckpt_dir, state_dir, eval_dir]:
|
||||
if not os.path.exists(x):
|
||||
os.makedirs(x)
|
||||
for x in ["alignments", "waveform", "lin_spec", "mel_spec"]:
|
||||
|
@ -199,130 +372,3 @@ def plot_alignment(alignment, path):
|
|||
plt.ylabel('Decoder timestep')
|
||||
plt.savefig(path)
|
||||
plt.close()
|
||||
|
||||
|
||||
def save_state(save_dir,
|
||||
writer,
|
||||
global_step,
|
||||
mel_input=None,
|
||||
mel_output=None,
|
||||
lin_input=None,
|
||||
lin_output=None,
|
||||
alignments=None,
|
||||
win_length=1024,
|
||||
hop_length=256,
|
||||
min_level_db=-100,
|
||||
ref_level_db=20,
|
||||
power=1.4,
|
||||
n_iter=32,
|
||||
preemphasis=0.97,
|
||||
sample_rate=22050):
|
||||
"""Save training intermediate results. Save states for the first sentence in the batch, including
|
||||
mel_spec(predicted, target), lin_spec(predicted, target), attn, waveform.
|
||||
|
||||
Args:
|
||||
save_dir (str): directory to save results.
|
||||
writer (SummaryWriter): tensorboardX summary writer
|
||||
global_step (int): global step.
|
||||
mel_input (Variable, optional): Defaults to None. Shape(B, T_mel, C_mel)
|
||||
mel_output (Variable, optional): Defaults to None. Shape(B, T_mel, C_mel)
|
||||
lin_input (Variable, optional): Defaults to None. Shape(B, T_lin, C_lin)
|
||||
lin_output (Variable, optional): Defaults to None. Shape(B, T_lin, C_lin)
|
||||
alignments (Variable, optional): Defaults to None. Shape(N, B, T_dec, C_enc)
|
||||
wav ([type], optional): Defaults to None. [description]
|
||||
"""
|
||||
|
||||
if mel_input is not None and mel_output is not None:
|
||||
mel_input = mel_input[0].numpy().T
|
||||
mel_output = mel_output[0].numpy().T
|
||||
|
||||
path = os.path.join(save_dir, "mel_spec")
|
||||
plt.figure(figsize=(10, 3))
|
||||
display.specshow(mel_input)
|
||||
plt.colorbar()
|
||||
plt.title("mel_input")
|
||||
plt.savefig(
|
||||
os.path.join(path, "target_mel_spec_step{:09d}.png".format(
|
||||
global_step)))
|
||||
plt.close()
|
||||
|
||||
writer.add_image(
|
||||
"target/mel_spec",
|
||||
cm.viridis(mel_input),
|
||||
global_step,
|
||||
dataformats="HWC")
|
||||
|
||||
plt.figure(figsize=(10, 3))
|
||||
display.specshow(mel_output)
|
||||
plt.colorbar()
|
||||
plt.title("mel_output")
|
||||
plt.savefig(
|
||||
os.path.join(path, "predicted_mel_spec_step{:09d}.png".format(
|
||||
global_step)))
|
||||
plt.close()
|
||||
|
||||
writer.add_image(
|
||||
"predicted/mel_spec",
|
||||
cm.viridis(mel_output),
|
||||
global_step,
|
||||
dataformats="HWC")
|
||||
|
||||
if lin_input is not None and lin_output is not None:
|
||||
lin_input = lin_input[0].numpy().T
|
||||
lin_output = lin_output[0].numpy().T
|
||||
path = os.path.join(save_dir, "lin_spec")
|
||||
|
||||
plt.figure(figsize=(10, 3))
|
||||
display.specshow(lin_input)
|
||||
plt.colorbar()
|
||||
plt.title("mel_input")
|
||||
plt.savefig(
|
||||
os.path.join(path, "target_lin_spec_step{:09d}.png".format(
|
||||
global_step)))
|
||||
plt.close()
|
||||
|
||||
writer.add_image(
|
||||
"target/lin_spec",
|
||||
cm.viridis(lin_input),
|
||||
global_step,
|
||||
dataformats="HWC")
|
||||
|
||||
plt.figure(figsize=(10, 3))
|
||||
display.specshow(lin_output)
|
||||
plt.colorbar()
|
||||
plt.title("mel_input")
|
||||
plt.savefig(
|
||||
os.path.join(path, "predicted_lin_spec_step{:09d}.png".format(
|
||||
global_step)))
|
||||
plt.close()
|
||||
|
||||
writer.add_image(
|
||||
"predicted/lin_spec",
|
||||
cm.viridis(lin_output),
|
||||
global_step,
|
||||
dataformats="HWC")
|
||||
|
||||
if alignments is not None and len(alignments.shape) == 4:
|
||||
path = os.path.join(save_dir, "alignments")
|
||||
alignments = alignments[:, 0, :, :].numpy()
|
||||
for idx, attn_layer in enumerate(alignments):
|
||||
save_path = os.path.join(
|
||||
path,
|
||||
"train_attn_layer_{}_step_{}.png".format(idx, global_step))
|
||||
plot_alignment(attn_layer, save_path)
|
||||
|
||||
writer.add_image(
|
||||
"train_attn/layer_{}".format(idx),
|
||||
cm.viridis(attn_layer),
|
||||
global_step,
|
||||
dataformats="HWC")
|
||||
|
||||
if lin_output is not None:
|
||||
wav = spec_to_waveform(lin_output, min_level_db, ref_level_db, power,
|
||||
n_iter, win_length, hop_length, preemphasis)
|
||||
path = os.path.join(save_dir, "waveform")
|
||||
save_path = os.path.join(
|
||||
path, "train_sample_step_{:09d}.wav".format(global_step))
|
||||
sf.write(save_path, wav, sample_rate)
|
||||
writer.add_audio(
|
||||
"train_sample", wav, global_step, sample_rate=sample_rate)
|
||||
|
|
|
@ -19,6 +19,7 @@ import argparse
|
|||
from tqdm import tqdm
|
||||
from tensorboardX import SummaryWriter
|
||||
from paddle import fluid
|
||||
fluid.require_version('1.8.0')
|
||||
import paddle.fluid.dygraph as dg
|
||||
|
||||
from parakeet.modules.weight_norm import WeightNormWrapper
|
||||
|
@ -55,6 +56,13 @@ if __name__ == "__main__":
|
|||
with open(args.config, 'rt') as f:
|
||||
config = ruamel.yaml.safe_load(f)
|
||||
|
||||
if args.device == -1:
|
||||
place = fluid.CPUPlace()
|
||||
else:
|
||||
place = fluid.CUDAPlace(args.device)
|
||||
|
||||
dg.enable_dygraph(place)
|
||||
|
||||
ljspeech_meta = LJSpeechMetaData(args.data)
|
||||
|
||||
data_config = config["data"]
|
||||
|
@ -99,54 +107,47 @@ if __name__ == "__main__":
|
|||
if not os.path.exists(args.output):
|
||||
os.makedirs(args.output)
|
||||
|
||||
if args.device == -1:
|
||||
place = fluid.CPUPlace()
|
||||
model_config = config["model"]
|
||||
upsampling_factors = model_config["upsampling_factors"]
|
||||
encoder = UpsampleNet(upsampling_factors)
|
||||
|
||||
n_loop = model_config["n_loop"]
|
||||
n_layer = model_config["n_layer"]
|
||||
residual_channels = model_config["residual_channels"]
|
||||
output_dim = model_config["output_dim"]
|
||||
loss_type = model_config["loss_type"]
|
||||
log_scale_min = model_config["log_scale_min"]
|
||||
decoder = WaveNet(n_loop, n_layer, residual_channels, output_dim, n_mels,
|
||||
filter_size, loss_type, log_scale_min)
|
||||
|
||||
model = ConditionalWavenet(encoder, decoder)
|
||||
summary(model)
|
||||
|
||||
# load model parameters
|
||||
checkpoint_dir = os.path.join(args.output, "checkpoints")
|
||||
if args.checkpoint:
|
||||
iteration = io.load_parameters(model, checkpoint_path=args.checkpoint)
|
||||
else:
|
||||
place = fluid.CUDAPlace(args.device)
|
||||
iteration = io.load_parameters(
|
||||
model, checkpoint_dir=checkpoint_dir, iteration=args.iteration)
|
||||
assert iteration > 0, "A trained model is needed."
|
||||
|
||||
with dg.guard(place):
|
||||
model_config = config["model"]
|
||||
upsampling_factors = model_config["upsampling_factors"]
|
||||
encoder = UpsampleNet(upsampling_factors)
|
||||
# WARNING: don't forget to remove weight norm to re-compute each wrapped layer's weight
|
||||
# removing weight norm also speeds up computation
|
||||
for layer in model.sublayers():
|
||||
if isinstance(layer, WeightNormWrapper):
|
||||
layer.remove_weight_norm()
|
||||
|
||||
n_loop = model_config["n_loop"]
|
||||
n_layer = model_config["n_layer"]
|
||||
residual_channels = model_config["residual_channels"]
|
||||
output_dim = model_config["output_dim"]
|
||||
loss_type = model_config["loss_type"]
|
||||
log_scale_min = model_config["log_scale_min"]
|
||||
decoder = WaveNet(n_loop, n_layer, residual_channels, output_dim,
|
||||
n_mels, filter_size, loss_type, log_scale_min)
|
||||
train_loader = fluid.io.DataLoader.from_generator(
|
||||
capacity=10, return_list=True)
|
||||
train_loader.set_batch_generator(train_cargo, place)
|
||||
|
||||
model = ConditionalWavenet(encoder, decoder)
|
||||
summary(model)
|
||||
valid_loader = fluid.io.DataLoader.from_generator(
|
||||
capacity=10, return_list=True)
|
||||
valid_loader.set_batch_generator(valid_cargo, place)
|
||||
|
||||
# load model parameters
|
||||
checkpoint_dir = os.path.join(args.output, "checkpoints")
|
||||
if args.checkpoint:
|
||||
iteration = io.load_parameters(
|
||||
model, checkpoint_path=args.checkpoint)
|
||||
else:
|
||||
iteration = io.load_parameters(
|
||||
model, checkpoint_dir=checkpoint_dir, iteration=args.iteration)
|
||||
assert iteration > 0, "A trained model is needed."
|
||||
synthesis_dir = os.path.join(args.output, "synthesis")
|
||||
if not os.path.exists(synthesis_dir):
|
||||
os.makedirs(synthesis_dir)
|
||||
|
||||
# WARNING: don't forget to remove weight norm to re-compute each wrapped layer's weight
|
||||
# removing weight norm also speeds up computation
|
||||
for layer in model.sublayers():
|
||||
if isinstance(layer, WeightNormWrapper):
|
||||
layer.remove_weight_norm()
|
||||
|
||||
train_loader = fluid.io.DataLoader.from_generator(
|
||||
capacity=10, return_list=True)
|
||||
train_loader.set_batch_generator(train_cargo, place)
|
||||
|
||||
valid_loader = fluid.io.DataLoader.from_generator(
|
||||
capacity=10, return_list=True)
|
||||
valid_loader.set_batch_generator(valid_cargo, place)
|
||||
|
||||
synthesis_dir = os.path.join(args.output, "synthesis")
|
||||
if not os.path.exists(synthesis_dir):
|
||||
os.makedirs(synthesis_dir)
|
||||
|
||||
eval_model(model, valid_loader, synthesis_dir, iteration, sample_rate)
|
||||
eval_model(model, valid_loader, synthesis_dir, iteration, sample_rate)
|
||||
|
|
|
@ -19,9 +19,10 @@ import argparse
|
|||
import tqdm
|
||||
from tensorboardX import SummaryWriter
|
||||
from paddle import fluid
|
||||
fluid.require_version('1.8.0')
|
||||
import paddle.fluid.dygraph as dg
|
||||
|
||||
from parakeet.data import SliceDataset, TransformDataset, DataCargo, SequentialSampler, RandomSampler
|
||||
from parakeet.data import SliceDataset, TransformDataset, CacheDataset, DataCargo, SequentialSampler, RandomSampler
|
||||
from parakeet.models.wavenet import UpsampleNet, WaveNet, ConditionalWavenet
|
||||
from parakeet.utils.layer_tools import summary
|
||||
from parakeet.utils import io
|
||||
|
@ -51,6 +52,13 @@ if __name__ == "__main__":
|
|||
with open(args.config, 'rt') as f:
|
||||
config = ruamel.yaml.safe_load(f)
|
||||
|
||||
if args.device == -1:
|
||||
place = fluid.CPUPlace()
|
||||
else:
|
||||
place = fluid.CUDAPlace(args.device)
|
||||
|
||||
dg.enable_dygraph(place)
|
||||
|
||||
print("Command Line Args: ")
|
||||
for k, v in vars(args).items():
|
||||
print("{}: {}".format(k, v))
|
||||
|
@ -68,8 +76,9 @@ if __name__ == "__main__":
|
|||
ljspeech = TransformDataset(ljspeech_meta, transform)
|
||||
|
||||
valid_size = data_config["valid_size"]
|
||||
ljspeech_valid = SliceDataset(ljspeech, 0, valid_size)
|
||||
ljspeech_train = SliceDataset(ljspeech, valid_size, len(ljspeech))
|
||||
ljspeech_valid = CacheDataset(SliceDataset(ljspeech, 0, valid_size))
|
||||
ljspeech_train = CacheDataset(
|
||||
SliceDataset(ljspeech, valid_size, len(ljspeech)))
|
||||
|
||||
model_config = config["model"]
|
||||
n_loop = model_config["n_loop"]
|
||||
|
@ -103,93 +112,90 @@ if __name__ == "__main__":
|
|||
else:
|
||||
place = fluid.CUDAPlace(args.device)
|
||||
|
||||
with dg.guard(place):
|
||||
model_config = config["model"]
|
||||
upsampling_factors = model_config["upsampling_factors"]
|
||||
encoder = UpsampleNet(upsampling_factors)
|
||||
model_config = config["model"]
|
||||
upsampling_factors = model_config["upsampling_factors"]
|
||||
encoder = UpsampleNet(upsampling_factors)
|
||||
|
||||
n_loop = model_config["n_loop"]
|
||||
n_layer = model_config["n_layer"]
|
||||
residual_channels = model_config["residual_channels"]
|
||||
output_dim = model_config["output_dim"]
|
||||
loss_type = model_config["loss_type"]
|
||||
log_scale_min = model_config["log_scale_min"]
|
||||
decoder = WaveNet(n_loop, n_layer, residual_channels, output_dim,
|
||||
n_mels, filter_size, loss_type, log_scale_min)
|
||||
n_loop = model_config["n_loop"]
|
||||
n_layer = model_config["n_layer"]
|
||||
residual_channels = model_config["residual_channels"]
|
||||
output_dim = model_config["output_dim"]
|
||||
loss_type = model_config["loss_type"]
|
||||
log_scale_min = model_config["log_scale_min"]
|
||||
decoder = WaveNet(n_loop, n_layer, residual_channels, output_dim, n_mels,
|
||||
filter_size, loss_type, log_scale_min)
|
||||
|
||||
model = ConditionalWavenet(encoder, decoder)
|
||||
summary(model)
|
||||
model = ConditionalWavenet(encoder, decoder)
|
||||
summary(model)
|
||||
|
||||
train_config = config["train"]
|
||||
learning_rate = train_config["learning_rate"]
|
||||
anneal_rate = train_config["anneal_rate"]
|
||||
anneal_interval = train_config["anneal_interval"]
|
||||
lr_scheduler = dg.ExponentialDecay(
|
||||
learning_rate, anneal_interval, anneal_rate, staircase=True)
|
||||
gradiant_max_norm = train_config["gradient_max_norm"]
|
||||
optim = fluid.optimizer.Adam(
|
||||
lr_scheduler,
|
||||
parameter_list=model.parameters(),
|
||||
grad_clip=fluid.clip.ClipByGlobalNorm(gradiant_max_norm))
|
||||
train_config = config["train"]
|
||||
learning_rate = train_config["learning_rate"]
|
||||
anneal_rate = train_config["anneal_rate"]
|
||||
anneal_interval = train_config["anneal_interval"]
|
||||
lr_scheduler = dg.ExponentialDecay(
|
||||
learning_rate, anneal_interval, anneal_rate, staircase=True)
|
||||
gradiant_max_norm = train_config["gradient_max_norm"]
|
||||
optim = fluid.optimizer.Adam(
|
||||
lr_scheduler,
|
||||
parameter_list=model.parameters(),
|
||||
grad_clip=fluid.clip.ClipByGlobalNorm(gradiant_max_norm))
|
||||
|
||||
train_loader = fluid.io.DataLoader.from_generator(
|
||||
capacity=10, return_list=True)
|
||||
train_loader.set_batch_generator(train_cargo, place)
|
||||
train_loader = fluid.io.DataLoader.from_generator(
|
||||
capacity=10, return_list=True)
|
||||
train_loader.set_batch_generator(train_cargo, place)
|
||||
|
||||
valid_loader = fluid.io.DataLoader.from_generator(
|
||||
capacity=10, return_list=True)
|
||||
valid_loader.set_batch_generator(valid_cargo, place)
|
||||
valid_loader = fluid.io.DataLoader.from_generator(
|
||||
capacity=10, return_list=True)
|
||||
valid_loader.set_batch_generator(valid_cargo, place)
|
||||
|
||||
max_iterations = train_config["max_iterations"]
|
||||
checkpoint_interval = train_config["checkpoint_interval"]
|
||||
snap_interval = train_config["snap_interval"]
|
||||
eval_interval = train_config["eval_interval"]
|
||||
checkpoint_dir = os.path.join(args.output, "checkpoints")
|
||||
log_dir = os.path.join(args.output, "log")
|
||||
writer = SummaryWriter(log_dir)
|
||||
max_iterations = train_config["max_iterations"]
|
||||
checkpoint_interval = train_config["checkpoint_interval"]
|
||||
snap_interval = train_config["snap_interval"]
|
||||
eval_interval = train_config["eval_interval"]
|
||||
checkpoint_dir = os.path.join(args.output, "checkpoints")
|
||||
log_dir = os.path.join(args.output, "log")
|
||||
writer = SummaryWriter(log_dir)
|
||||
|
||||
# load parameters and optimizer, and update iterations done so far
|
||||
if args.checkpoint is not None:
|
||||
iteration = io.load_parameters(
|
||||
model, optim, checkpoint_path=args.checkpoint)
|
||||
else:
|
||||
iteration = io.load_parameters(
|
||||
model,
|
||||
optim,
|
||||
checkpoint_dir=checkpoint_dir,
|
||||
iteration=args.iteration)
|
||||
# load parameters and optimizer, and update iterations done so far
|
||||
if args.checkpoint is not None:
|
||||
iteration = io.load_parameters(
|
||||
model, optim, checkpoint_path=args.checkpoint)
|
||||
else:
|
||||
iteration = io.load_parameters(
|
||||
model,
|
||||
optim,
|
||||
checkpoint_dir=checkpoint_dir,
|
||||
iteration=args.iteration)
|
||||
|
||||
global_step = iteration + 1
|
||||
iterator = iter(tqdm.tqdm(train_loader))
|
||||
while global_step <= max_iterations:
|
||||
try:
|
||||
batch = next(iterator)
|
||||
except StopIteration as e:
|
||||
iterator = iter(tqdm.tqdm(train_loader))
|
||||
batch = next(iterator)
|
||||
global_step = iteration + 1
|
||||
iterator = iter(tqdm.tqdm(train_loader))
|
||||
while global_step <= max_iterations:
|
||||
try:
|
||||
batch = next(iterator)
|
||||
except StopIteration as e:
|
||||
iterator = iter(tqdm.tqdm(train_loader))
|
||||
batch = next(iterator)
|
||||
|
||||
audio_clips, mel_specs, audio_starts = batch
|
||||
audio_clips, mel_specs, audio_starts = batch
|
||||
|
||||
model.train()
|
||||
y_var = model(audio_clips, mel_specs, audio_starts)
|
||||
loss_var = model.loss(y_var, audio_clips)
|
||||
loss_var.backward()
|
||||
loss_np = loss_var.numpy()
|
||||
model.train()
|
||||
y_var = model(audio_clips, mel_specs, audio_starts)
|
||||
loss_var = model.loss(y_var, audio_clips)
|
||||
loss_var.backward()
|
||||
loss_np = loss_var.numpy()
|
||||
|
||||
writer.add_scalar("loss", loss_np[0], global_step)
|
||||
writer.add_scalar("learning_rate",
|
||||
optim._learning_rate.step().numpy()[0],
|
||||
global_step)
|
||||
optim.minimize(loss_var)
|
||||
optim.clear_gradients()
|
||||
print("global_step: {}\tloss: {:<8.6f}".format(global_step,
|
||||
loss_np[0]))
|
||||
writer.add_scalar("loss", loss_np[0], global_step)
|
||||
writer.add_scalar("learning_rate",
|
||||
optim._learning_rate.step().numpy()[0], global_step)
|
||||
optim.minimize(loss_var)
|
||||
optim.clear_gradients()
|
||||
print("global_step: {}\tloss: {:<8.6f}".format(global_step, loss_np[
|
||||
0]))
|
||||
|
||||
if global_step % snap_interval == 0:
|
||||
valid_model(model, valid_loader, writer, global_step,
|
||||
sample_rate)
|
||||
if global_step % snap_interval == 0:
|
||||
valid_model(model, valid_loader, writer, global_step, sample_rate)
|
||||
|
||||
if global_step % checkpoint_interval == 0:
|
||||
io.save_parameters(checkpoint_dir, global_step, model, optim)
|
||||
if global_step % checkpoint_interval == 0:
|
||||
io.save_parameters(checkpoint_dir, global_step, model, optim)
|
||||
|
||||
global_step += 1
|
||||
global_step += 1
|
||||
|
|
|
@ -176,6 +176,79 @@ class PartialyRandomizedSimilarTimeLengthSampler(Sampler):
|
|||
return len(self.sorted_indices)
|
||||
|
||||
|
||||
class BucketSampler(Sampler):
|
||||
def __init__(self,
|
||||
lengths,
|
||||
batch_size=4,
|
||||
batch_group_size=None,
|
||||
permutate=True,
|
||||
num_trainers=1,
|
||||
rank=0):
|
||||
# maybe better implement length as a sort key
|
||||
_lengths = np.array(lengths, dtype=np.int64)
|
||||
self.lengths = np.sort(_lengths)
|
||||
self.sorted_indices = np.argsort(_lengths)
|
||||
self.num_trainers = num_trainers
|
||||
self.rank = rank
|
||||
|
||||
self.dataset_size = len(_lengths)
|
||||
self.num_samples = int(np.ceil(self.dataset_size / num_trainers))
|
||||
self.total_size = self.num_samples * num_trainers
|
||||
assert self.total_size >= self.dataset_size
|
||||
|
||||
self.batch_size = batch_size
|
||||
total_batch_size = num_trainers * batch_size
|
||||
self.total_batch_size = total_batch_size
|
||||
|
||||
if batch_group_size is None:
|
||||
batch_group_size = min(total_batch_size * 32, len(self.lengths))
|
||||
if batch_group_size % total_batch_size != 0:
|
||||
batch_group_size -= batch_group_size % total_batch_size
|
||||
|
||||
self.batch_group_size = batch_group_size
|
||||
assert batch_group_size % total_batch_size == 0
|
||||
self.permutate = permutate
|
||||
|
||||
def __iter__(self):
|
||||
indices = self.sorted_indices
|
||||
|
||||
# Append extra samples to make it evenly distributed on all trainers.
|
||||
num_extras = self.total_size - self.dataset_size
|
||||
extra_indices = np.random.choice(
|
||||
indices, size=(num_extras, ), replace=False)
|
||||
indices = np.concatenate((indices, extra_indices))
|
||||
assert len(indices) == self.total_size
|
||||
|
||||
batch_group_size = self.batch_group_size
|
||||
s, e = 0, 0
|
||||
for i in range(len(indices) // batch_group_size):
|
||||
s = i * batch_group_size
|
||||
e = s + batch_group_size
|
||||
random.shuffle(indices[s:e]) # inplace
|
||||
|
||||
# Permutate batches
|
||||
total_batch_size = self.total_batch_size
|
||||
if self.permutate:
|
||||
perm = np.arange(len(indices[:e]) // total_batch_size)
|
||||
random.shuffle(perm)
|
||||
indices[:e] = indices[:e].reshape(
|
||||
-1, total_batch_size)[perm, :].reshape(-1)
|
||||
|
||||
# Handle last elements
|
||||
s += batch_group_size
|
||||
#print(indices)
|
||||
if s < len(indices):
|
||||
random.shuffle(indices[s:])
|
||||
|
||||
# Subset samples for each trainer.
|
||||
indices = indices[self.rank:self.total_size:self.num_trainers]
|
||||
assert len(indices) == self.num_samples
|
||||
return iter(indices)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.sorted_indices)
|
||||
|
||||
|
||||
class WeightedRandomSampler(Sampler):
|
||||
"""Samples elements from ``[0,..,len(weights)-1]`` with given probabilities (weights).
|
||||
Args:
|
||||
|
|
|
@ -15,4 +15,5 @@
|
|||
from parakeet.models.deepvoice3.encoder import Encoder, ConvSpec
|
||||
from parakeet.models.deepvoice3.decoder import Decoder, WindowRange
|
||||
from parakeet.models.deepvoice3.converter import Converter
|
||||
from parakeet.models.deepvoice3.loss import TTSLoss
|
||||
from parakeet.models.deepvoice3.model import DeepVoice3
|
||||
|
|
|
@ -210,98 +210,82 @@ class TTSLoss(object):
|
|||
loss = fluid.layers.reduce_mean(predicted_attention * soft_mask_)
|
||||
return loss
|
||||
|
||||
def __call__(self,
|
||||
mel_hyp,
|
||||
lin_hyp,
|
||||
done_hyp,
|
||||
attn_hyp,
|
||||
mel_ref,
|
||||
lin_ref,
|
||||
done_ref,
|
||||
input_lengths,
|
||||
n_frames,
|
||||
compute_lin_loss=True,
|
||||
compute_mel_loss=True,
|
||||
compute_done_loss=True,
|
||||
compute_attn_loss=True):
|
||||
def __call__(self, outputs, inputs):
|
||||
"""Total loss
|
||||
|
||||
Args:
|
||||
outpus is a tuple of (mel_hyp, lin_hyp, attn_hyp, done_hyp).
|
||||
mel_hyp (Variable): shape(B, T, C_mel), dtype float32, predicted mel spectrogram.
|
||||
lin_hyp (Variable): shape(B, T, C_lin), dtype float32, predicted linear spectrogram.
|
||||
done_hyp (Variable): shape(B, T), dtype float32, predicted done probability.
|
||||
attn_hyp (Variable): shape(N, B, T_dec, T_enc), dtype float32, predicted attention.
|
||||
|
||||
inputs is a tuple of (mel_ref, lin_ref, done_ref, input_lengths, n_frames)
|
||||
mel_ref (Variable): shape(B, T, C_mel), dtype float32, ground truth mel spectrogram.
|
||||
lin_ref (Variable): shape(B, T, C_lin), dtype float32, ground truth linear spectrogram.
|
||||
done_ref (Variable): shape(B, T), dtype float32, ground truth done flag.
|
||||
input_lengths (Variable): shape(B, ), dtype: int, encoder valid lengths.
|
||||
n_frames (Variable): shape(B, ), dtype: int, decoder valid lengths.
|
||||
compute_lin_loss (bool, optional): whether to compute linear loss. Defaults to True.
|
||||
compute_mel_loss (bool, optional): whether to compute mel loss. Defaults to True.
|
||||
compute_done_loss (bool, optional): whether to compute done loss. Defaults to True.
|
||||
compute_attn_loss (bool, optional): whether to compute atention loss. Defaults to True.
|
||||
|
||||
Returns:
|
||||
Dict(str, Variable): details of loss.
|
||||
"""
|
||||
total_loss = 0.
|
||||
|
||||
mel_hyp, lin_hyp, attn_hyp, done_hyp = outputs
|
||||
mel_ref, lin_ref, done_ref, input_lengths, n_frames = inputs
|
||||
|
||||
# n_frames # mel_lengths # decoder_lengths
|
||||
max_frames = lin_hyp.shape[1]
|
||||
max_mel_steps = max_frames // self.downsample_factor
|
||||
max_decoder_steps = max_mel_steps // self.r
|
||||
|
||||
decoder_mask = F.sequence_mask(
|
||||
n_frames // self.downsample_factor // self.r,
|
||||
max_decoder_steps,
|
||||
dtype="float32")
|
||||
# max_decoder_steps = max_mel_steps // self.r
|
||||
# decoder_mask = F.sequence_mask(n_frames // self.downsample_factor //
|
||||
# self.r,
|
||||
# max_decoder_steps,
|
||||
# dtype="float32")
|
||||
mel_mask = F.sequence_mask(
|
||||
n_frames // self.downsample_factor, max_mel_steps, dtype="float32")
|
||||
lin_mask = F.sequence_mask(n_frames, max_frames, dtype="float32")
|
||||
|
||||
if compute_lin_loss:
|
||||
lin_hyp = lin_hyp[:, :-self.time_shift, :]
|
||||
lin_ref = lin_ref[:, self.time_shift:, :]
|
||||
lin_mask = lin_mask[:, self.time_shift:]
|
||||
lin_l1_loss = self.l1_loss(
|
||||
lin_hyp, lin_ref, lin_mask, priority_bin=self.priority_bin)
|
||||
lin_bce_loss = self.binary_divergence(lin_hyp, lin_ref, lin_mask)
|
||||
lin_loss = self.binary_divergence_weight * lin_bce_loss \
|
||||
+ (1 - self.binary_divergence_weight) * lin_l1_loss
|
||||
total_loss += lin_loss
|
||||
lin_hyp = lin_hyp[:, :-self.time_shift, :]
|
||||
lin_ref = lin_ref[:, self.time_shift:, :]
|
||||
lin_mask = lin_mask[:, self.time_shift:]
|
||||
lin_l1_loss = self.l1_loss(
|
||||
lin_hyp, lin_ref, lin_mask, priority_bin=self.priority_bin)
|
||||
lin_bce_loss = self.binary_divergence(lin_hyp, lin_ref, lin_mask)
|
||||
lin_loss = self.binary_divergence_weight * lin_bce_loss \
|
||||
+ (1 - self.binary_divergence_weight) * lin_l1_loss
|
||||
total_loss += lin_loss
|
||||
|
||||
if compute_mel_loss:
|
||||
mel_hyp = mel_hyp[:, :-self.time_shift, :]
|
||||
mel_ref = mel_ref[:, self.time_shift:, :]
|
||||
mel_mask = mel_mask[:, self.time_shift:]
|
||||
mel_l1_loss = self.l1_loss(mel_hyp, mel_ref, mel_mask)
|
||||
mel_bce_loss = self.binary_divergence(mel_hyp, mel_ref, mel_mask)
|
||||
# print("=====>", mel_l1_loss.numpy()[0], mel_bce_loss.numpy()[0])
|
||||
mel_loss = self.binary_divergence_weight * mel_bce_loss \
|
||||
+ (1 - self.binary_divergence_weight) * mel_l1_loss
|
||||
total_loss += mel_loss
|
||||
mel_hyp = mel_hyp[:, :-self.time_shift, :]
|
||||
mel_ref = mel_ref[:, self.time_shift:, :]
|
||||
mel_mask = mel_mask[:, self.time_shift:]
|
||||
mel_l1_loss = self.l1_loss(mel_hyp, mel_ref, mel_mask)
|
||||
mel_bce_loss = self.binary_divergence(mel_hyp, mel_ref, mel_mask)
|
||||
# print("=====>", mel_l1_loss.numpy()[0], mel_bce_loss.numpy()[0])
|
||||
mel_loss = self.binary_divergence_weight * mel_bce_loss \
|
||||
+ (1 - self.binary_divergence_weight) * mel_l1_loss
|
||||
total_loss += mel_loss
|
||||
|
||||
if compute_attn_loss:
|
||||
attn_loss = self.attention_loss(attn_hyp,
|
||||
input_lengths.numpy(),
|
||||
n_frames.numpy() //
|
||||
(self.downsample_factor * self.r))
|
||||
total_loss += attn_loss
|
||||
attn_loss = self.attention_loss(attn_hyp,
|
||||
input_lengths.numpy(),
|
||||
n_frames.numpy() //
|
||||
(self.downsample_factor * self.r))
|
||||
total_loss += attn_loss
|
||||
|
||||
if compute_done_loss:
|
||||
done_loss = self.done_loss(done_hyp, done_ref)
|
||||
total_loss += done_loss
|
||||
done_loss = self.done_loss(done_hyp, done_ref)
|
||||
total_loss += done_loss
|
||||
|
||||
result = {
|
||||
losses = {
|
||||
"loss": total_loss,
|
||||
"mel/mel_loss": mel_loss if compute_mel_loss else None,
|
||||
"mel/l1_loss": mel_l1_loss if compute_mel_loss else None,
|
||||
"mel/bce_loss": mel_bce_loss if compute_mel_loss else None,
|
||||
"lin/lin_loss": lin_loss if compute_lin_loss else None,
|
||||
"lin/l1_loss": lin_l1_loss if compute_lin_loss else None,
|
||||
"lin/bce_loss": lin_bce_loss if compute_lin_loss else None,
|
||||
"done": done_loss if compute_done_loss else None,
|
||||
"attn": attn_loss if compute_attn_loss else None,
|
||||
"mel/mel_loss": mel_loss,
|
||||
"mel/l1_loss": mel_l1_loss,
|
||||
"mel/bce_loss": mel_bce_loss,
|
||||
"lin/lin_loss": lin_loss,
|
||||
"lin/l1_loss": lin_l1_loss,
|
||||
"lin/bce_loss": lin_bce_loss,
|
||||
"done": done_loss,
|
||||
"attn": attn_loss,
|
||||
}
|
||||
|
||||
return result
|
||||
return losses
|
||||
|
|
|
@ -19,6 +19,34 @@ import paddle.fluid.layers as F
|
|||
import paddle.fluid.dygraph as dg
|
||||
|
||||
|
||||
def lookup(weight, indices, padding_idx):
|
||||
out = fluid.core.ops.lookup_table_v2(
|
||||
weight, indices, 'is_sparse', False, 'is_distributed', False,
|
||||
'remote_prefetch', False, 'padding_idx', padding_idx)
|
||||
return out
|
||||
|
||||
|
||||
def compute_position_embedding_single_speaker(radians, speaker_position_rate):
|
||||
"""Compute sin/cos interleaved matrix from the radians.
|
||||
|
||||
Arg:
|
||||
radians (Variable): shape(n_vocab, embed_dim), dtype float32, the radians matrix.
|
||||
speaker_position_rate (float or Variable): float or Variable of shape(1, ), speaker positioning rate.
|
||||
|
||||
Returns:
|
||||
Variable: shape(n_vocab, embed_dim), the sin, cos interleaved matrix.
|
||||
"""
|
||||
_, embed_dim = radians.shape
|
||||
scaled_radians = radians * speaker_position_rate
|
||||
|
||||
odd_mask = (np.arange(embed_dim) % 2).astype(np.float32)
|
||||
odd_mask = dg.to_variable(odd_mask)
|
||||
|
||||
out = odd_mask * F.cos(scaled_radians) \
|
||||
+ (1 - odd_mask) * F.sin(scaled_radians)
|
||||
return out
|
||||
|
||||
|
||||
def compute_position_embedding(radians, speaker_position_rate):
|
||||
"""Compute sin/cos interleaved matrix from the radians.
|
||||
|
||||
|
@ -106,16 +134,14 @@ class PositionEmbedding(dg.Layer):
|
|||
"""
|
||||
batch_size, time_steps = indices.shape
|
||||
|
||||
# convert speaker_position_rate to a Variable with shape(B, )
|
||||
if isinstance(speaker_position_rate, float):
|
||||
speaker_position_rate = dg.to_variable(
|
||||
np.array([speaker_position_rate]).astype("float32"))
|
||||
speaker_position_rate = F.expand(speaker_position_rate,
|
||||
[batch_size])
|
||||
elif isinstance(speaker_position_rate, fluid.framework.Variable) \
|
||||
and list(speaker_position_rate.shape) == [1]:
|
||||
speaker_position_rate = F.expand(speaker_position_rate,
|
||||
[batch_size])
|
||||
if isinstance(speaker_position_rate, float) or \
|
||||
(isinstance(speaker_position_rate, fluid.framework.Variable)
|
||||
and list(speaker_position_rate.shape) == [1]):
|
||||
temp_weight = compute_position_embedding_single_speaker(
|
||||
self.weight, speaker_position_rate)
|
||||
out = lookup(temp_weight, indices, 0)
|
||||
return out
|
||||
|
||||
assert len(speaker_position_rate.shape) == 1 and \
|
||||
list(speaker_position_rate.shape) == [batch_size]
|
||||
|
||||
|
@ -128,6 +154,5 @@ class PositionEmbedding(dg.Layer):
|
|||
0, batch_size, 1, dtype="int64"), [1]), [1, time_steps])
|
||||
# (B, T, 2)
|
||||
gather_nd_id = F.stack([batch_id, indices], -1)
|
||||
|
||||
out = F.gather_nd(weight, gather_nd_id)
|
||||
return out
|
||||
|
|
|
@ -57,14 +57,44 @@ def norm_except(param, dim, power):
|
|||
return norm_except(transposed_param, dim=0, power=power)
|
||||
|
||||
|
||||
def compute_weight(v, g, dim, power):
|
||||
assert len(g.shape) == 1, "magnitude should be a vector"
|
||||
v_normalized = F.elementwise_div(
|
||||
v, (norm_except(v, dim, power) + 1e-12), axis=dim)
|
||||
def compute_l2_normalized_weight(v, g, dim):
|
||||
shape = v.shape
|
||||
ndim = len(shape)
|
||||
|
||||
if dim is None:
|
||||
v_normalized = v / (F.reduce_sum(F.square(v)) + 1e-12)
|
||||
elif dim == 0:
|
||||
param_matrix = F.reshape(v, (shape[0], np.prod(shape[1:])))
|
||||
v_normalized = F.l2_normalize(param_matrix, axis=1)
|
||||
elif dim == -1 or dim == ndim - 1:
|
||||
param_matrix = F.reshape(v, (np.prod(shape[:-1]), shape[-1]))
|
||||
v_normalized = F.l2_normalize(param_matrix, axis=0)
|
||||
else:
|
||||
perm = list(range(ndim))
|
||||
perm[0] = dim
|
||||
perm[dim] = 0
|
||||
transposed_param = F.transpose(v, perm)
|
||||
param_matrix = F.reshape(
|
||||
transposed_param,
|
||||
(transposed_param.shape[0], np.prod(transposed_param.shape[1:])))
|
||||
v_normalized = F.l2_normalize(param_matrix, axis=1)
|
||||
v_normalized = F.transpose(v_normalized, perm)
|
||||
v_normalized = F.reshape(v_normalized, shape)
|
||||
weight = F.elementwise_mul(v_normalized, g, axis=dim)
|
||||
return weight
|
||||
|
||||
|
||||
def compute_weight(v, g, dim, power):
|
||||
assert len(g.shape) == 1, "magnitude should be a vector"
|
||||
if power == 2:
|
||||
return compute_l2_normalized_weight(v, g, dim)
|
||||
else:
|
||||
v_normalized = F.elementwise_div(
|
||||
v, (norm_except(v, dim, power) + 1e-12), axis=dim)
|
||||
weight = F.elementwise_mul(v_normalized, g, axis=dim)
|
||||
return weight
|
||||
|
||||
|
||||
class WeightNormWrapper(dg.Layer):
|
||||
def __init__(self, layer, param_name="weight", dim=0, power=2):
|
||||
super(WeightNormWrapper, self).__init__()
|
||||
|
|
Loading…
Reference in New Issue