Merge branch 'update_waveflow' into 'master'
Update waveflow See merge request !21
This commit is contained in:
commit
25883dcd3e
|
@ -0,0 +1,27 @@
|
|||
- repo: https://github.com/PaddlePaddle/mirrors-yapf.git
|
||||
sha: 0d79c0c469bab64f7229c9aca2b1186ef47f0e37
|
||||
hooks:
|
||||
- id: yapf
|
||||
files: \.py$
|
||||
- repo: https://github.com/pre-commit/pre-commit-hooks
|
||||
sha: a11d9314b22d8f8c7556443875b731ef05965464
|
||||
hooks:
|
||||
- id: check-merge-conflict
|
||||
- id: check-symlinks
|
||||
- id: detect-private-key
|
||||
files: (?!.*paddle)^.*$
|
||||
- id: end-of-file-fixer
|
||||
files: \.md$
|
||||
- id: trailing-whitespace
|
||||
files: \.md$
|
||||
- repo: https://github.com/Lucas-C/pre-commit-hooks
|
||||
sha: v1.0.1
|
||||
hooks:
|
||||
- id: forbid-crlf
|
||||
files: \.md$
|
||||
- id: remove-crlf
|
||||
files: \.md$
|
||||
- id: forbid-tabs
|
||||
files: \.md$
|
||||
- id: remove-tabs
|
||||
files: \.md$
|
|
@ -28,22 +28,21 @@ if __name__ == "__main__":
|
|||
parser = argparse.ArgumentParser(
|
||||
description="Train a deepvoice 3 model with LJSpeech dataset.")
|
||||
parser.add_argument("-c", "--config", type=str, help="experimrnt config")
|
||||
parser.add_argument("-s",
|
||||
parser.add_argument(
|
||||
"-s",
|
||||
"--data",
|
||||
type=str,
|
||||
default="/workspace/datasets/LJSpeech-1.1/",
|
||||
help="The path of the LJSpeech dataset.")
|
||||
parser.add_argument("-r", "--resume", type=str, help="checkpoint to load")
|
||||
parser.add_argument("-o",
|
||||
parser.add_argument(
|
||||
"-o",
|
||||
"--output",
|
||||
type=str,
|
||||
default="result",
|
||||
help="The directory to save result.")
|
||||
parser.add_argument("-g",
|
||||
"--device",
|
||||
type=int,
|
||||
default=-1,
|
||||
help="device to use")
|
||||
parser.add_argument(
|
||||
"-g", "--device", type=int, default=-1, help="device to use")
|
||||
args, _ = parser.parse_known_args()
|
||||
with open(args.config, 'rt') as f:
|
||||
config = ruamel.yaml.safe_load(f)
|
||||
|
@ -84,18 +83,16 @@ if __name__ == "__main__":
|
|||
train_config = config["train"]
|
||||
batch_size = train_config["batch_size"]
|
||||
text_lengths = [len(example[2]) for example in meta]
|
||||
sampler = PartialyRandomizedSimilarTimeLengthSampler(
|
||||
text_lengths, batch_size)
|
||||
sampler = PartialyRandomizedSimilarTimeLengthSampler(text_lengths,
|
||||
batch_size)
|
||||
|
||||
# some hyperparameters affect how we process data, so create a data collector!
|
||||
model_config = config["model"]
|
||||
downsample_factor = model_config["downsample_factor"]
|
||||
r = model_config["outputs_per_step"]
|
||||
collector = DataCollector(downsample_factor=downsample_factor, r=r)
|
||||
ljspeech_loader = DataCargo(ljspeech,
|
||||
batch_fn=collector,
|
||||
batch_size=batch_size,
|
||||
sampler=sampler)
|
||||
ljspeech_loader = DataCargo(
|
||||
ljspeech, batch_fn=collector, batch_size=batch_size, sampler=sampler)
|
||||
|
||||
# =========================model=========================
|
||||
if args.device == -1:
|
||||
|
@ -131,15 +128,14 @@ if __name__ == "__main__":
|
|||
window_ahead = model_config["window_ahead"]
|
||||
key_projection = model_config["key_projection"]
|
||||
value_projection = model_config["value_projection"]
|
||||
dv3 = make_model(n_speakers, speaker_dim, speaker_embed_std, embed_dim,
|
||||
padding_idx, embedding_std, max_positions, n_vocab,
|
||||
freeze_embedding, filter_size, encoder_channels,
|
||||
n_mels, decoder_channels, r,
|
||||
dv3 = make_model(
|
||||
n_speakers, speaker_dim, speaker_embed_std, embed_dim, padding_idx,
|
||||
embedding_std, max_positions, n_vocab, freeze_embedding,
|
||||
filter_size, encoder_channels, n_mels, decoder_channels, r,
|
||||
trainable_positional_encodings, use_memory_mask,
|
||||
query_position_rate, key_position_rate,
|
||||
window_backward, window_ahead, key_projection,
|
||||
value_projection, downsample_factor, linear_dim,
|
||||
use_decoder_states, converter_channels, dropout)
|
||||
query_position_rate, key_position_rate, window_backward,
|
||||
window_ahead, key_projection, value_projection, downsample_factor,
|
||||
linear_dim, use_decoder_states, converter_channels, dropout)
|
||||
|
||||
# =========================loss=========================
|
||||
loss_config = config["loss"]
|
||||
|
@ -149,7 +145,8 @@ if __name__ == "__main__":
|
|||
priority_freq_weight = loss_config["priority_freq_weight"]
|
||||
binary_divergence_weight = loss_config["binary_divergence_weight"]
|
||||
guided_attention_sigma = loss_config["guided_attention_sigma"]
|
||||
criterion = TTSLoss(masked_weight=masked_weight,
|
||||
criterion = TTSLoss(
|
||||
masked_weight=masked_weight,
|
||||
priority_bin=priority_bin,
|
||||
priority_weight=priority_freq_weight,
|
||||
binary_divergence_weight=binary_divergence_weight,
|
||||
|
@ -169,7 +166,8 @@ if __name__ == "__main__":
|
|||
beta1 = optim_config["beta1"]
|
||||
beta2 = optim_config["beta2"]
|
||||
epsilon = optim_config["epsilon"]
|
||||
optim = fluid.optimizer.Adam(lr_scheduler,
|
||||
optim = fluid.optimizer.Adam(
|
||||
lr_scheduler,
|
||||
beta1,
|
||||
beta2,
|
||||
epsilon=epsilon,
|
||||
|
@ -183,8 +181,8 @@ if __name__ == "__main__":
|
|||
|
||||
# =========================link(dataloader, paddle)=========================
|
||||
# CAUTION: it does not return a DataLoader
|
||||
loader = fluid.io.DataLoader.from_generator(capacity=10,
|
||||
return_list=True)
|
||||
loader = fluid.io.DataLoader.from_generator(
|
||||
capacity=10, return_list=True)
|
||||
loader.set_batch_generator(ljspeech_loader, places=place)
|
||||
|
||||
# tensorboard & checkpoint preparation
|
||||
|
@ -247,7 +245,8 @@ if __name__ == "__main__":
|
|||
# TODO: clean code
|
||||
# train state saving, the first sentence in the batch
|
||||
if global_step % snap_interval == 0:
|
||||
save_state(state_dir,
|
||||
save_state(
|
||||
state_dir,
|
||||
writer,
|
||||
global_step,
|
||||
mel_input=downsampled_mel_specs,
|
||||
|
@ -275,16 +274,16 @@ if __name__ == "__main__":
|
|||
"Some have accepted this as a miracle without any physical explanation.",
|
||||
]
|
||||
for idx, sent in enumerate(sentences):
|
||||
wav, attn = eval_model(dv3, sent,
|
||||
replace_pronounciation_prob,
|
||||
min_level_db, ref_level_db,
|
||||
power, n_iter, win_length,
|
||||
hop_length, preemphasis)
|
||||
wav, attn = eval_model(
|
||||
dv3, sent, replace_pronounciation_prob,
|
||||
min_level_db, ref_level_db, power, n_iter,
|
||||
win_length, hop_length, preemphasis)
|
||||
wav_path = os.path.join(
|
||||
state_dir, "waveform",
|
||||
"eval_sample_{:09d}.wav".format(global_step))
|
||||
sf.write(wav_path, wav, sample_rate)
|
||||
writer.add_audio("eval_sample_{}".format(idx),
|
||||
writer.add_audio(
|
||||
"eval_sample_{}".format(idx),
|
||||
wav,
|
||||
global_step,
|
||||
sample_rate=sample_rate)
|
||||
|
@ -292,7 +291,8 @@ if __name__ == "__main__":
|
|||
state_dir, "alignments",
|
||||
"eval_sample_attn_{:09d}.png".format(global_step))
|
||||
plot_alignment(attn, attn_path)
|
||||
writer.add_image("eval_sample_attn{}".format(idx),
|
||||
writer.add_image(
|
||||
"eval_sample_attn{}".format(idx),
|
||||
cm.viridis(attn),
|
||||
global_step,
|
||||
dataformats="HWC")
|
||||
|
|
|
@ -2,35 +2,47 @@ import os
|
|||
import random
|
||||
from pprint import pprint
|
||||
|
||||
import jsonargparse
|
||||
import argparse
|
||||
import numpy as np
|
||||
import paddle.fluid.dygraph as dg
|
||||
from paddle import fluid
|
||||
|
||||
import utils
|
||||
from waveflow import WaveFlow
|
||||
from parakeet.models.waveflow import WaveFlow
|
||||
|
||||
|
||||
def add_options_to_parser(parser):
|
||||
parser.add_argument('--model', type=str, default='waveflow',
|
||||
parser.add_argument(
|
||||
'--model',
|
||||
type=str,
|
||||
default='waveflow',
|
||||
help="general name of the model")
|
||||
parser.add_argument('--name', type=str,
|
||||
help="specific name of the training model")
|
||||
parser.add_argument('--root', type=str,
|
||||
help="root path of the LJSpeech dataset")
|
||||
parser.add_argument(
|
||||
'--name', type=str, help="specific name of the training model")
|
||||
parser.add_argument(
|
||||
'--root', type=str, help="root path of the LJSpeech dataset")
|
||||
|
||||
parser.add_argument('--use_gpu', type=bool, default=True,
|
||||
parser.add_argument(
|
||||
'--use_gpu',
|
||||
type=bool,
|
||||
default=True,
|
||||
help="option to use gpu training")
|
||||
|
||||
parser.add_argument('--iteration', type=int, default=None,
|
||||
parser.add_argument(
|
||||
'--iteration',
|
||||
type=int,
|
||||
default=None,
|
||||
help=("which iteration of checkpoint to load, "
|
||||
"default to load the latest checkpoint"))
|
||||
parser.add_argument('--checkpoint', type=str, default=None,
|
||||
parser.add_argument(
|
||||
'--checkpoint',
|
||||
type=str,
|
||||
default=None,
|
||||
help="path of the checkpoint to load")
|
||||
|
||||
|
||||
def benchmark(config):
|
||||
pprint(jsonargparse.namespace_to_dict(config))
|
||||
pprint(vars(config))
|
||||
|
||||
# Get checkpoint directory path.
|
||||
run_dir = os.path.join("runs", config.model, config.name)
|
||||
|
@ -58,9 +70,8 @@ def benchmark(config):
|
|||
|
||||
if __name__ == "__main__":
|
||||
# Create parser.
|
||||
parser = jsonargparse.ArgumentParser(
|
||||
description="Synthesize audio using WaveNet model",
|
||||
formatter_class='default_argparse')
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Synthesize audio using WaveNet model")
|
||||
add_options_to_parser(parser)
|
||||
utils.add_config_options_to_parser(parser)
|
||||
|
||||
|
@ -68,4 +79,5 @@ if __name__ == "__main__":
|
|||
# For conflicting updates to the same field,
|
||||
# the preceding update will be overwritten by the following one.
|
||||
config = parser.parse_args()
|
||||
config = utils.add_yaml_config(config)
|
||||
benchmark(config)
|
|
@ -2,40 +2,58 @@ import os
|
|||
import random
|
||||
from pprint import pprint
|
||||
|
||||
import jsonargparse
|
||||
import argparse
|
||||
import numpy as np
|
||||
import paddle.fluid.dygraph as dg
|
||||
from paddle import fluid
|
||||
|
||||
import utils
|
||||
from waveflow import WaveFlow
|
||||
from parakeet.models.waveflow import WaveFlow
|
||||
|
||||
|
||||
def add_options_to_parser(parser):
|
||||
parser.add_argument('--model', type=str, default='waveflow',
|
||||
parser.add_argument(
|
||||
'--model',
|
||||
type=str,
|
||||
default='waveflow',
|
||||
help="general name of the model")
|
||||
parser.add_argument('--name', type=str,
|
||||
help="specific name of the training model")
|
||||
parser.add_argument('--root', type=str,
|
||||
help="root path of the LJSpeech dataset")
|
||||
parser.add_argument(
|
||||
'--name', type=str, help="specific name of the training model")
|
||||
parser.add_argument(
|
||||
'--root', type=str, help="root path of the LJSpeech dataset")
|
||||
|
||||
parser.add_argument('--use_gpu', type=bool, default=True,
|
||||
parser.add_argument(
|
||||
'--use_gpu',
|
||||
type=bool,
|
||||
default=True,
|
||||
help="option to use gpu training")
|
||||
|
||||
parser.add_argument('--iteration', type=int, default=None,
|
||||
parser.add_argument(
|
||||
'--iteration',
|
||||
type=int,
|
||||
default=None,
|
||||
help=("which iteration of checkpoint to load, "
|
||||
"default to load the latest checkpoint"))
|
||||
parser.add_argument('--checkpoint', type=str, default=None,
|
||||
parser.add_argument(
|
||||
'--checkpoint',
|
||||
type=str,
|
||||
default=None,
|
||||
help="path of the checkpoint to load")
|
||||
|
||||
parser.add_argument('--output', type=str, default="./syn_audios",
|
||||
parser.add_argument(
|
||||
'--output',
|
||||
type=str,
|
||||
default="./syn_audios",
|
||||
help="path to write synthesized audio files")
|
||||
parser.add_argument('--sample', type=int, default=None,
|
||||
parser.add_argument(
|
||||
'--sample',
|
||||
type=int,
|
||||
default=None,
|
||||
help="which of the valid samples to synthesize audio")
|
||||
|
||||
|
||||
def synthesize(config):
|
||||
pprint(jsonargparse.namespace_to_dict(config))
|
||||
pprint(vars(config))
|
||||
|
||||
# Get checkpoint directory path.
|
||||
run_dir = os.path.join("runs", config.model, config.name)
|
||||
|
@ -72,9 +90,8 @@ def synthesize(config):
|
|||
|
||||
if __name__ == "__main__":
|
||||
# Create parser.
|
||||
parser = jsonargparse.ArgumentParser(
|
||||
description="Synthesize audio using WaveNet model",
|
||||
formatter_class='default_argparse')
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Synthesize audio using WaveNet model")
|
||||
add_options_to_parser(parser)
|
||||
utils.add_config_options_to_parser(parser)
|
||||
|
||||
|
@ -82,4 +99,5 @@ if __name__ == "__main__":
|
|||
# For conflicting updates to the same field,
|
||||
# the preceding update will be overwritten by the following one.
|
||||
config = parser.parse_args()
|
||||
config = utils.add_yaml_config(config)
|
||||
synthesize(config)
|
|
@ -4,34 +4,48 @@ import subprocess
|
|||
import time
|
||||
from pprint import pprint
|
||||
|
||||
import jsonargparse
|
||||
import argparse
|
||||
import numpy as np
|
||||
import paddle.fluid.dygraph as dg
|
||||
from paddle import fluid
|
||||
from tensorboardX import SummaryWriter
|
||||
|
||||
import slurm
|
||||
import utils
|
||||
from waveflow import WaveFlow
|
||||
from parakeet.models.waveflow import WaveFlow
|
||||
|
||||
|
||||
def add_options_to_parser(parser):
|
||||
parser.add_argument('--model', type=str, default='waveflow',
|
||||
parser.add_argument(
|
||||
'--model',
|
||||
type=str,
|
||||
default='waveflow',
|
||||
help="general name of the model")
|
||||
parser.add_argument('--name', type=str,
|
||||
help="specific name of the training model")
|
||||
parser.add_argument('--root', type=str,
|
||||
help="root path of the LJSpeech dataset")
|
||||
parser.add_argument(
|
||||
'--name', type=str, help="specific name of the training model")
|
||||
parser.add_argument(
|
||||
'--root', type=str, help="root path of the LJSpeech dataset")
|
||||
|
||||
parser.add_argument('--parallel', type=bool, default=True,
|
||||
parser.add_argument(
|
||||
'--parallel',
|
||||
type=utils.str2bool,
|
||||
default=True,
|
||||
help="option to use data parallel training")
|
||||
parser.add_argument('--use_gpu', type=bool, default=True,
|
||||
parser.add_argument(
|
||||
'--use_gpu',
|
||||
type=utils.str2bool,
|
||||
default=True,
|
||||
help="option to use gpu training")
|
||||
|
||||
parser.add_argument('--iteration', type=int, default=None,
|
||||
parser.add_argument(
|
||||
'--iteration',
|
||||
type=int,
|
||||
default=None,
|
||||
help=("which iteration of checkpoint to load, "
|
||||
"default to load the latest checkpoint"))
|
||||
parser.add_argument('--checkpoint', type=str, default=None,
|
||||
parser.add_argument(
|
||||
'--checkpoint',
|
||||
type=str,
|
||||
default=None,
|
||||
help="path of the checkpoint to load")
|
||||
|
||||
|
||||
|
@ -45,12 +59,13 @@ def train(config):
|
|||
|
||||
if rank == 0:
|
||||
# Print the whole config setting.
|
||||
pprint(jsonargparse.namespace_to_dict(config))
|
||||
pprint(vars(config))
|
||||
|
||||
# Make checkpoint directory.
|
||||
run_dir = os.path.join("runs", config.model, config.name)
|
||||
checkpoint_dir = os.path.join(run_dir, "checkpoint")
|
||||
os.makedirs(checkpoint_dir, exist_ok=True)
|
||||
if not os.path.exists(checkpoint_dir):
|
||||
os.makedirs(checkpoint_dir)
|
||||
|
||||
# Create tensorboard logger.
|
||||
tb = SummaryWriter(os.path.join(run_dir, "logs")) \
|
||||
|
@ -102,8 +117,8 @@ def train(config):
|
|||
|
||||
if __name__ == "__main__":
|
||||
# Create parser.
|
||||
parser = jsonargparse.ArgumentParser(description="Train WaveFlow model",
|
||||
formatter_class='default_argparse')
|
||||
parser = argparse.ArgumentParser(description="Train WaveFlow model")
|
||||
#formatter_class='default_argparse')
|
||||
add_options_to_parser(parser)
|
||||
utils.add_config_options_to_parser(parser)
|
||||
|
||||
|
@ -111,4 +126,5 @@ if __name__ == "__main__":
|
|||
# For conflicting updates to the same field,
|
||||
# the preceding update will be overwritten by the following one.
|
||||
config = parser.parse_args()
|
||||
config = utils.add_yaml_config(config)
|
||||
train(config)
|
|
@ -2,59 +2,96 @@ import itertools
|
|||
import os
|
||||
import time
|
||||
|
||||
import jsonargparse
|
||||
import argparse
|
||||
import ruamel.yaml
|
||||
import numpy as np
|
||||
import paddle.fluid.dygraph as dg
|
||||
|
||||
|
||||
def str2bool(v):
|
||||
return v.lower() in ("true", "t", "1")
|
||||
|
||||
|
||||
def add_config_options_to_parser(parser):
|
||||
parser.add_argument('--valid_size', type=int,
|
||||
help="size of the valid dataset")
|
||||
parser.add_argument('--segment_length', type=int,
|
||||
parser.add_argument(
|
||||
'--valid_size', type=int, help="size of the valid dataset")
|
||||
parser.add_argument(
|
||||
'--segment_length',
|
||||
type=int,
|
||||
help="the length of audio clip for training")
|
||||
parser.add_argument('--sample_rate', type=int,
|
||||
help="sampling rate of audio data file")
|
||||
parser.add_argument('--fft_window_shift', type=int,
|
||||
parser.add_argument(
|
||||
'--sample_rate', type=int, help="sampling rate of audio data file")
|
||||
parser.add_argument(
|
||||
'--fft_window_shift',
|
||||
type=int,
|
||||
help="the shift of fft window for each frame")
|
||||
parser.add_argument('--fft_window_size', type=int,
|
||||
parser.add_argument(
|
||||
'--fft_window_size',
|
||||
type=int,
|
||||
help="the size of fft window for each frame")
|
||||
parser.add_argument('--fft_size', type=int,
|
||||
help="the size of fft filter on each frame")
|
||||
parser.add_argument('--mel_bands', type=int,
|
||||
parser.add_argument(
|
||||
'--fft_size', type=int, help="the size of fft filter on each frame")
|
||||
parser.add_argument(
|
||||
'--mel_bands',
|
||||
type=int,
|
||||
help="the number of mel bands when calculating mel spectrograms")
|
||||
parser.add_argument('--mel_fmin', type=float,
|
||||
parser.add_argument(
|
||||
'--mel_fmin',
|
||||
type=float,
|
||||
help="lowest frequency in calculating mel spectrograms")
|
||||
parser.add_argument('--mel_fmax', type=float,
|
||||
parser.add_argument(
|
||||
'--mel_fmax',
|
||||
type=float,
|
||||
help="highest frequency in calculating mel spectrograms")
|
||||
|
||||
parser.add_argument('--seed', type=int,
|
||||
help="seed of random initialization for the model")
|
||||
parser.add_argument(
|
||||
'--seed', type=int, help="seed of random initialization for the model")
|
||||
parser.add_argument('--learning_rate', type=float)
|
||||
parser.add_argument('--batch_size', type=int,
|
||||
help="batch size for training")
|
||||
parser.add_argument('--test_every', type=int,
|
||||
help="test interval during training")
|
||||
parser.add_argument('--save_every', type=int,
|
||||
parser.add_argument(
|
||||
'--batch_size', type=int, help="batch size for training")
|
||||
parser.add_argument(
|
||||
'--test_every', type=int, help="test interval during training")
|
||||
parser.add_argument(
|
||||
'--save_every',
|
||||
type=int,
|
||||
help="checkpointing interval during training")
|
||||
parser.add_argument('--max_iterations', type=int,
|
||||
help="maximum training iterations")
|
||||
parser.add_argument(
|
||||
'--max_iterations', type=int, help="maximum training iterations")
|
||||
|
||||
parser.add_argument('--sigma', type=float,
|
||||
parser.add_argument(
|
||||
'--sigma',
|
||||
type=float,
|
||||
help="standard deviation of the latent Gaussian variable")
|
||||
parser.add_argument('--n_flows', type=int,
|
||||
help="number of flows")
|
||||
parser.add_argument('--n_group', type=int,
|
||||
parser.add_argument('--n_flows', type=int, help="number of flows")
|
||||
parser.add_argument(
|
||||
'--n_group',
|
||||
type=int,
|
||||
help="number of adjacent audio samples to squeeze into one column")
|
||||
parser.add_argument('--n_layers', type=int,
|
||||
parser.add_argument(
|
||||
'--n_layers',
|
||||
type=int,
|
||||
help="number of conv2d layer in one wavenet-like flow architecture")
|
||||
parser.add_argument('--n_channels', type=int,
|
||||
help="number of residual channels in flow")
|
||||
parser.add_argument('--kernel_h', type=int,
|
||||
parser.add_argument(
|
||||
'--n_channels', type=int, help="number of residual channels in flow")
|
||||
parser.add_argument(
|
||||
'--kernel_h',
|
||||
type=int,
|
||||
help="height of the kernel in the conv2d layer")
|
||||
parser.add_argument('--kernel_w', type=int,
|
||||
help="width of the kernel in the conv2d layer")
|
||||
parser.add_argument(
|
||||
'--kernel_w', type=int, help="width of the kernel in the conv2d layer")
|
||||
|
||||
parser.add_argument('--config', action=jsonargparse.ActionConfigFile)
|
||||
parser.add_argument('--config', type=str, help="Path to the config file.")
|
||||
|
||||
|
||||
def add_yaml_config(config):
|
||||
with open(config.config, 'rt') as f:
|
||||
yaml_cfg = ruamel.yaml.safe_load(f)
|
||||
cfg_vars = vars(config)
|
||||
for k, v in yaml_cfg.items():
|
||||
if k in cfg_vars and cfg_vars[k] is not None:
|
||||
continue
|
||||
cfg_vars[k] = v
|
||||
return config
|
||||
|
||||
|
||||
def load_latest_checkpoint(checkpoint_dir, rank=0):
|
||||
|
@ -84,8 +121,12 @@ def save_latest_checkpoint(checkpoint_dir, iteration):
|
|||
handle.write("model_checkpoint_path: step-{}".format(iteration))
|
||||
|
||||
|
||||
def load_parameters(checkpoint_dir, rank, model, optimizer=None,
|
||||
iteration=None, file_path=None):
|
||||
def load_parameters(checkpoint_dir,
|
||||
rank,
|
||||
model,
|
||||
optimizer=None,
|
||||
iteration=None,
|
||||
file_path=None):
|
||||
if file_path is None:
|
||||
if iteration is None:
|
||||
iteration = load_latest_checkpoint(checkpoint_dir, rank)
|
|
@ -5,21 +5,26 @@ import librosa
|
|||
from .. import g2p
|
||||
|
||||
from ..data.sampler import SequentialSampler, RandomSampler, BatchSampler
|
||||
from ..data.dataset import Dataset
|
||||
from ..data.dataset import DatasetMixin
|
||||
from ..data.datacargo import DataCargo
|
||||
from ..data.batch import TextIDBatcher, SpecBatcher
|
||||
|
||||
|
||||
class LJSpeech(Dataset):
|
||||
class LJSpeech(DatasetMixin):
|
||||
def __init__(self, root):
|
||||
super(LJSpeech, self).__init__()
|
||||
assert isinstance(root, (str, Path)), "root should be a string or Path object"
|
||||
assert isinstance(root, (
|
||||
str, Path)), "root should be a string or Path object"
|
||||
self.root = root if isinstance(root, Path) else Path(root)
|
||||
self.metadata = self._prepare_metadata()
|
||||
|
||||
def _prepare_metadata(self):
|
||||
csv_path = self.root.joinpath("metadata.csv")
|
||||
metadata = pd.read_csv(csv_path, sep="|", header=None, quoting=3,
|
||||
metadata = pd.read_csv(
|
||||
csv_path,
|
||||
sep="|",
|
||||
header=None,
|
||||
quoting=3,
|
||||
names=["fname", "raw_text", "normalized_text"])
|
||||
return metadata
|
||||
|
||||
|
@ -35,7 +40,9 @@ class LJSpeech(Dataset):
|
|||
wav_path = self.root.joinpath("wavs", fname + ".wav")
|
||||
|
||||
# load -> trim -> preemphasis -> stft -> magnitude -> mel_scale -> logscale -> normalize
|
||||
wav, sample_rate = librosa.load(wav_path, sr=None) # we would rather use functor to hold its parameters
|
||||
wav, sample_rate = librosa.load(
|
||||
wav_path,
|
||||
sr=None) # we would rather use functor to hold its parameters
|
||||
trimed, _ = librosa.effects.trim(wav)
|
||||
preemphasized = librosa.effects.preemphasis(trimed)
|
||||
D = librosa.stft(preemphasized)
|
||||
|
@ -50,8 +57,10 @@ class LJSpeech(Dataset):
|
|||
mel = np.clip((mel - ref_db + max_db) / max_db, 1e-8, 1)
|
||||
mel = np.clip((mag - ref_db + max_db) / max_db, 1e-8, 1)
|
||||
|
||||
phonemes = np.array(g2p.en.text_to_sequence(normalized_text), dtype=np.int64)
|
||||
return (mag, mel, phonemes) # maybe we need to implement it as a map in the future
|
||||
phonemes = np.array(
|
||||
g2p.en.text_to_sequence(normalized_text), dtype=np.int64)
|
||||
return (mag, mel, phonemes
|
||||
) # maybe we need to implement it as a map in the future
|
||||
|
||||
def _batch_examples(self, minibatch):
|
||||
mag_batch = []
|
||||
|
@ -78,5 +87,3 @@ class LJSpeech(Dataset):
|
|||
|
||||
def __len__(self):
|
||||
return len(self.metadata)
|
||||
|
||||
|
||||
|
|
|
@ -0,0 +1 @@
|
|||
from parakeet.models.waveflow.waveflow import WaveFlow
|
|
@ -5,10 +5,9 @@ import numpy as np
|
|||
from paddle import fluid
|
||||
|
||||
from parakeet.datasets import ljspeech
|
||||
from parakeet.data import dataset
|
||||
from parakeet.data.batch import SpecBatcher, WavBatcher
|
||||
from parakeet.data.datacargo import DataCargo
|
||||
from parakeet.data.sampler import DistributedSampler, BatchSampler
|
||||
from parakeet.data import SpecBatcher, WavBatcher
|
||||
from parakeet.data import DataCargo, DatasetMixin
|
||||
from parakeet.data import DistributedSampler, BatchSampler
|
||||
from scipy.io.wavfile import read
|
||||
|
||||
|
||||
|
@ -27,7 +26,7 @@ class Dataset(ljspeech.LJSpeech):
|
|||
return audio
|
||||
|
||||
|
||||
class Subset(dataset.Dataset):
|
||||
class Subset(DatasetMixin):
|
||||
def __init__(self, dataset, indices, valid):
|
||||
self.dataset = dataset
|
||||
self.indices = indices
|
||||
|
@ -36,14 +35,14 @@ class Subset(dataset.Dataset):
|
|||
|
||||
def get_mel(self, audio):
|
||||
spectrogram = librosa.core.stft(
|
||||
audio, n_fft=self.config.fft_size,
|
||||
audio,
|
||||
n_fft=self.config.fft_size,
|
||||
hop_length=self.config.fft_window_shift,
|
||||
win_length=self.config.fft_window_size)
|
||||
spectrogram_magnitude = np.abs(spectrogram)
|
||||
|
||||
# mel_filter_bank shape: [n_mels, 1 + n_fft/2]
|
||||
mel_filter_bank = librosa.filters.mel(
|
||||
sr=self.config.sample_rate,
|
||||
mel_filter_bank = librosa.filters.mel(sr=self.config.sample_rate,
|
||||
n_fft=self.config.fft_size,
|
||||
n_mels=self.config.mel_bands,
|
||||
fmin=self.config.mel_fmin,
|
||||
|
@ -73,7 +72,8 @@ class Subset(dataset.Dataset):
|
|||
audio = audio[audio_start:(audio_start + segment_length)]
|
||||
else:
|
||||
audio = np.pad(audio, (0, segment_length - audio.shape[0]),
|
||||
mode='constant', constant_values=0)
|
||||
mode='constant',
|
||||
constant_values=0)
|
||||
|
||||
# Normalize audio to the [-1, 1] range.
|
||||
audio = audio.astype(np.float32) / 32768.0
|
||||
|
@ -112,8 +112,8 @@ class LJSpeech:
|
|||
sampler = DistributedSampler(len(trainset), nranks, rank)
|
||||
total_bs = config.batch_size
|
||||
assert total_bs % nranks == 0
|
||||
train_sampler = BatchSampler(sampler, total_bs // nranks,
|
||||
drop_last=True)
|
||||
train_sampler = BatchSampler(
|
||||
sampler, total_bs // nranks, drop_last=True)
|
||||
trainloader = DataCargo(trainset, batch_sampler=train_sampler)
|
||||
|
||||
trainreader = fluid.io.PyReader(capacity=50, return_list=True)
|
||||
|
|
|
@ -8,13 +8,18 @@ from paddle import fluid
|
|||
from scipy.io.wavfile import write
|
||||
|
||||
import utils
|
||||
from data import LJSpeech
|
||||
from waveflow_modules import WaveFlowLoss, WaveFlowModule
|
||||
from .data import LJSpeech
|
||||
from .waveflow_modules import WaveFlowLoss, WaveFlowModule
|
||||
|
||||
|
||||
class WaveFlow():
|
||||
def __init__(self, config, checkpoint_dir, parallel=False, rank=0,
|
||||
nranks=1, tb_logger=None):
|
||||
def __init__(self,
|
||||
config,
|
||||
checkpoint_dir,
|
||||
parallel=False,
|
||||
rank=0,
|
||||
nranks=1,
|
||||
tb_logger=None):
|
||||
self.config = config
|
||||
self.checkpoint_dir = checkpoint_dir
|
||||
self.parallel = parallel
|
||||
|
@ -28,7 +33,7 @@ class WaveFlow():
|
|||
self.trainloader = dataset.trainloader
|
||||
self.validloader = dataset.validloader
|
||||
|
||||
waveflow = WaveFlowModule("waveflow", config)
|
||||
waveflow = WaveFlowModule(config)
|
||||
|
||||
# Dry run once to create and initalize all necessary parameters.
|
||||
audio = dg.to_variable(np.random.randn(1, 16000).astype(np.float32))
|
||||
|
@ -38,11 +43,15 @@ class WaveFlow():
|
|||
|
||||
if training:
|
||||
optimizer = fluid.optimizer.AdamOptimizer(
|
||||
learning_rate=config.learning_rate)
|
||||
learning_rate=config.learning_rate,
|
||||
parameter_list=waveflow.parameters())
|
||||
|
||||
# Load parameters.
|
||||
utils.load_parameters(self.checkpoint_dir, self.rank,
|
||||
waveflow, optimizer,
|
||||
utils.load_parameters(
|
||||
self.checkpoint_dir,
|
||||
self.rank,
|
||||
waveflow,
|
||||
optimizer,
|
||||
iteration=config.iteration,
|
||||
file_path=config.checkpoint)
|
||||
print("Rank {}: checkpoint loaded.".format(self.rank))
|
||||
|
@ -58,7 +67,10 @@ class WaveFlow():
|
|||
|
||||
else:
|
||||
# Load parameters.
|
||||
utils.load_parameters(self.checkpoint_dir, self.rank, waveflow,
|
||||
utils.load_parameters(
|
||||
self.checkpoint_dir,
|
||||
self.rank,
|
||||
waveflow,
|
||||
iteration=config.iteration,
|
||||
file_path=config.checkpoint)
|
||||
print("Rank {}: checkpoint loaded.".format(self.rank))
|
||||
|
@ -83,7 +95,8 @@ class WaveFlow():
|
|||
else:
|
||||
loss.backward()
|
||||
|
||||
self.optimizer.minimize(loss, parameter_list=self.waveflow.parameters())
|
||||
self.optimizer.minimize(
|
||||
loss, parameter_list=self.waveflow.parameters())
|
||||
self.waveflow.clear_gradients()
|
||||
|
||||
graph_time = time.time()
|
||||
|
@ -139,7 +152,8 @@ class WaveFlow():
|
|||
sample = config.sample
|
||||
|
||||
output = "{}/{}/iter-{}".format(config.output, config.name, iteration)
|
||||
os.makedirs(output, exist_ok=True)
|
||||
if not os.path.exists(output):
|
||||
os.makedirs(output)
|
||||
|
||||
mels_list = [mels for _, mels in self.validloader()]
|
||||
if sample is not None:
|
||||
|
@ -155,8 +169,8 @@ class WaveFlow():
|
|||
|
||||
audio = audio[0]
|
||||
audio_time = audio.shape[0] / self.config.sample_rate
|
||||
print("audio time {:.4f}, synthesis time {:.4f}".format(
|
||||
audio_time, syn_time))
|
||||
print("audio time {:.4f}, synthesis time {:.4f}".format(audio_time,
|
||||
syn_time))
|
||||
|
||||
# Denormalize audio from [-1, 1] to [-32768, 32768] int16 range.
|
||||
audio = audio.numpy() * 32768.0
|
||||
|
@ -180,8 +194,8 @@ class WaveFlow():
|
|||
syn_time = time.time() - start_time
|
||||
|
||||
audio_time = audio.shape[1] * batch_size / self.config.sample_rate
|
||||
print("audio time {:.4f}, synthesis time {:.4f}".format(
|
||||
audio_time, syn_time))
|
||||
print("audio time {:.4f}, synthesis time {:.4f}".format(audio_time,
|
||||
syn_time))
|
||||
print("{} X real-time".format(audio_time / syn_time))
|
||||
|
||||
def save(self, iteration):
|
||||
|
|
|
@ -3,22 +3,23 @@ import itertools
|
|||
import numpy as np
|
||||
import paddle.fluid.dygraph as dg
|
||||
from paddle import fluid
|
||||
from parakeet.modules import conv, modules, weight_norm
|
||||
from parakeet.modules import weight_norm
|
||||
|
||||
|
||||
def set_param_attr(layer, c_in=1):
|
||||
if isinstance(layer, (weight_norm.Conv2DTranspose, weight_norm.Conv2D)):
|
||||
k = np.sqrt(1.0 / (c_in * np.prod(layer._filter_size)))
|
||||
def get_param_attr(layer_type, filter_size, c_in=1):
|
||||
if layer_type == "weight_norm":
|
||||
k = np.sqrt(1.0 / (c_in * np.prod(filter_size)))
|
||||
weight_init = fluid.initializer.UniformInitializer(low=-k, high=k)
|
||||
bias_init = fluid.initializer.UniformInitializer(low=-k, high=k)
|
||||
elif isinstance(layer, dg.Conv2D):
|
||||
elif layer_type == "common":
|
||||
weight_init = fluid.initializer.ConstantInitializer(0.0)
|
||||
bias_init = fluid.initializer.ConstantInitializer(0.0)
|
||||
else:
|
||||
raise TypeError("Unsupported layer type.")
|
||||
|
||||
layer._param_attr = fluid.ParamAttr(initializer=weight_init)
|
||||
layer._bias_attr = fluid.ParamAttr(initializer=bias_init)
|
||||
param_attr = fluid.ParamAttr(initializer=weight_init)
|
||||
bias_attr = fluid.ParamAttr(initializer=bias_init)
|
||||
return param_attr, bias_attr
|
||||
|
||||
|
||||
def unfold(x, n_group):
|
||||
|
@ -48,20 +49,23 @@ class WaveFlowLoss:
|
|||
|
||||
|
||||
class Conditioner(dg.Layer):
|
||||
def __init__(self, name_scope):
|
||||
super(Conditioner, self).__init__(name_scope)
|
||||
def __init__(self):
|
||||
super(Conditioner, self).__init__()
|
||||
upsample_factors = [16, 16]
|
||||
|
||||
self.upsample_conv2d = []
|
||||
for s in upsample_factors:
|
||||
in_channel = 1
|
||||
conv_trans2d = modules.Conv2DTranspose(
|
||||
self.full_name(),
|
||||
param_attr, bias_attr = get_param_attr(
|
||||
"weight_norm", (3, 2 * s), c_in=in_channel)
|
||||
conv_trans2d = weight_norm.Conv2DTranspose(
|
||||
num_channels=in_channel,
|
||||
num_filters=1,
|
||||
filter_size=(3, 2 * s),
|
||||
padding=(1, s // 2),
|
||||
stride=(1, s))
|
||||
set_param_attr(conv_trans2d, c_in=in_channel)
|
||||
stride=(1, s),
|
||||
param_attr=param_attr,
|
||||
bias_attr=bias_attr)
|
||||
self.upsample_conv2d.append(conv_trans2d)
|
||||
|
||||
for i, layer in enumerate(self.upsample_conv2d):
|
||||
|
@ -86,8 +90,8 @@ class Conditioner(dg.Layer):
|
|||
|
||||
|
||||
class Flow(dg.Layer):
|
||||
def __init__(self, name_scope, config):
|
||||
super(Flow, self).__init__(name_scope)
|
||||
def __init__(self, config):
|
||||
super(Flow, self).__init__()
|
||||
self.n_layers = config.n_layers
|
||||
self.n_channels = config.n_channels
|
||||
self.kernel_h = config.kernel_h
|
||||
|
@ -95,27 +99,34 @@ class Flow(dg.Layer):
|
|||
|
||||
# Transform audio: [batch, 1, n_group, time/n_group]
|
||||
# => [batch, n_channels, n_group, time/n_group]
|
||||
param_attr, bias_attr = get_param_attr("weight_norm", (1, 1), c_in=1)
|
||||
self.start = weight_norm.Conv2D(
|
||||
self.full_name(),
|
||||
num_channels=1,
|
||||
num_filters=self.n_channels,
|
||||
filter_size=(1, 1))
|
||||
set_param_attr(self.start, c_in=1)
|
||||
filter_size=(1, 1),
|
||||
param_attr=param_attr,
|
||||
bias_attr=bias_attr)
|
||||
|
||||
# Initializing last layer to 0 makes the affine coupling layers
|
||||
# do nothing at first. This helps with training stability
|
||||
# output shape: [batch, 2, n_group, time/n_group]
|
||||
param_attr, bias_attr = get_param_attr(
|
||||
"common", (1, 1), c_in=self.n_channels)
|
||||
self.end = dg.Conv2D(
|
||||
self.full_name(),
|
||||
num_channels=self.n_channels,
|
||||
num_filters=2,
|
||||
filter_size=(1, 1))
|
||||
set_param_attr(self.end)
|
||||
filter_size=(1, 1),
|
||||
param_attr=param_attr,
|
||||
bias_attr=bias_attr)
|
||||
|
||||
# receiptive fileds: (kernel - 1) * sum(dilations) + 1 >= squeeze
|
||||
dilation_dict = {8: [1, 1, 1, 1, 1, 1, 1, 1],
|
||||
dilation_dict = {
|
||||
8: [1, 1, 1, 1, 1, 1, 1, 1],
|
||||
16: [1, 1, 1, 1, 1, 1, 1, 1],
|
||||
32: [1, 2, 4, 1, 2, 4, 1, 2],
|
||||
64: [1, 2, 4, 8, 16, 1, 2, 4],
|
||||
128: [1, 2, 4, 8, 16, 32, 64, 1]}
|
||||
128: [1, 2, 4, 8, 16, 32, 64, 1]
|
||||
}
|
||||
self.dilation_h_list = dilation_dict[config.n_group]
|
||||
|
||||
self.in_layers = []
|
||||
|
@ -125,30 +136,40 @@ class Flow(dg.Layer):
|
|||
dilation_h = self.dilation_h_list[i]
|
||||
dilation_w = 2**i
|
||||
|
||||
param_attr, bias_attr = get_param_attr(
|
||||
"weight_norm", (self.kernel_h, self.kernel_w),
|
||||
c_in=self.n_channels)
|
||||
in_layer = weight_norm.Conv2D(
|
||||
self.full_name(),
|
||||
num_channels=self.n_channels,
|
||||
num_filters=2 * self.n_channels,
|
||||
filter_size=(self.kernel_h, self.kernel_w),
|
||||
dilation=(dilation_h, dilation_w))
|
||||
set_param_attr(in_layer, c_in=self.n_channels)
|
||||
dilation=(dilation_h, dilation_w),
|
||||
param_attr=param_attr,
|
||||
bias_attr=bias_attr)
|
||||
self.in_layers.append(in_layer)
|
||||
|
||||
param_attr, bias_attr = get_param_attr(
|
||||
"weight_norm", (1, 1), c_in=config.mel_bands)
|
||||
cond_layer = weight_norm.Conv2D(
|
||||
self.full_name(),
|
||||
num_channels=config.mel_bands,
|
||||
num_filters=2 * self.n_channels,
|
||||
filter_size=(1, 1))
|
||||
set_param_attr(cond_layer, c_in=config.mel_bands)
|
||||
filter_size=(1, 1),
|
||||
param_attr=param_attr,
|
||||
bias_attr=bias_attr)
|
||||
self.cond_layers.append(cond_layer)
|
||||
|
||||
if i < self.n_layers - 1:
|
||||
res_skip_channels = 2 * self.n_channels
|
||||
else:
|
||||
res_skip_channels = self.n_channels
|
||||
param_attr, bias_attr = get_param_attr(
|
||||
"weight_norm", (1, 1), c_in=self.n_channels)
|
||||
res_skip_layer = weight_norm.Conv2D(
|
||||
self.full_name(),
|
||||
num_channels=self.n_channels,
|
||||
num_filters=res_skip_channels,
|
||||
filter_size=(1, 1))
|
||||
set_param_attr(res_skip_layer, c_in=self.n_channels)
|
||||
filter_size=(1, 1),
|
||||
param_attr=param_attr,
|
||||
bias_attr=bias_attr)
|
||||
self.res_skip_layers.append(res_skip_layer)
|
||||
|
||||
self.add_sublayer("in_layer_{}".format(i), in_layer)
|
||||
|
@ -168,8 +189,8 @@ class Flow(dg.Layer):
|
|||
# Pad width dim (time): dialated non-causal convolution
|
||||
pad_top, pad_bottom = (self.kernel_h - 1) * dilation_h, 0
|
||||
pad_left = pad_right = int((self.kernel_w - 1) * dilation_w / 2)
|
||||
audio_pad = fluid.layers.pad2d(audio,
|
||||
paddings=[pad_top, pad_bottom, pad_left, pad_right])
|
||||
audio_pad = fluid.layers.pad2d(
|
||||
audio, paddings=[pad_top, pad_bottom, pad_left, pad_right])
|
||||
|
||||
hidden = self.in_layers[i](audio_pad)
|
||||
cond_hidden = self.cond_layers[i](mel)
|
||||
|
@ -206,7 +227,7 @@ class Flow(dg.Layer):
|
|||
queue.append(fluid.layers.zeros_like(audio))
|
||||
|
||||
state = queue[0:state_size]
|
||||
state = fluid.layers.concat([*state, audio], axis=2)
|
||||
state = fluid.layers.concat(state + [audio], axis=2)
|
||||
|
||||
queue.pop(0)
|
||||
queue.append(audio)
|
||||
|
@ -216,8 +237,8 @@ class Flow(dg.Layer):
|
|||
pad_top, pad_bottom = 0, 0
|
||||
pad_left = int((self.kernel_w - 1) * dilation_w / 2)
|
||||
pad_right = int((self.kernel_w - 1) * dilation_w / 2)
|
||||
state = fluid.layers.pad2d(state,
|
||||
paddings=[pad_top, pad_bottom, pad_left, pad_right])
|
||||
state = fluid.layers.pad2d(
|
||||
state, paddings=[pad_top, pad_bottom, pad_left, pad_right])
|
||||
|
||||
hidden = self.in_layers[i](state)
|
||||
cond_hidden = self.cond_layers[i](mel)
|
||||
|
@ -241,18 +262,18 @@ class Flow(dg.Layer):
|
|||
|
||||
|
||||
class WaveFlowModule(dg.Layer):
|
||||
def __init__(self, name_scope, config):
|
||||
super(WaveFlowModule, self).__init__(name_scope)
|
||||
def __init__(self, config):
|
||||
super(WaveFlowModule, self).__init__()
|
||||
self.n_flows = config.n_flows
|
||||
self.n_group = config.n_group
|
||||
self.n_layers = config.n_layers
|
||||
assert self.n_group % 2 == 0
|
||||
assert self.n_flows % 2 == 0
|
||||
|
||||
self.conditioner = Conditioner(self.full_name())
|
||||
self.conditioner = Conditioner()
|
||||
self.flows = []
|
||||
for i in range(self.n_flows):
|
||||
flow = Flow(self.full_name(), config)
|
||||
flow = Flow(config)
|
||||
self.flows.append(flow)
|
||||
self.add_sublayer("flow_{}".format(i), flow)
|
||||
|
||||
|
@ -284,7 +305,6 @@ class WaveFlowModule(dg.Layer):
|
|||
audio = fluid.layers.transpose(unfold(audio, self.n_group), [0, 2, 1])
|
||||
# [bs, 1, n_group, time/n_group]
|
||||
audio = fluid.layers.unsqueeze(audio, 1)
|
||||
|
||||
log_s_list = []
|
||||
for i in range(self.n_flows):
|
||||
inputs = audio[:, :, :-1, :]
|
||||
|
@ -305,7 +325,6 @@ class WaveFlowModule(dg.Layer):
|
|||
mel = fluid.layers.stack(mel_slices, axis=2)
|
||||
|
||||
z = fluid.layers.squeeze(audio, [1])
|
||||
|
||||
return z, log_s_list
|
||||
|
||||
def synthesize(self, mel, sigma=1.0):
|
||||
|
|
|
@ -40,8 +40,8 @@ def norm_except(param, dim, power):
|
|||
|
||||
def compute_weight(v, g, dim, power):
|
||||
assert len(g.shape) == 1, "magnitude should be a vector"
|
||||
v_normalized = F.elementwise_div(v, (norm_except(v, dim, power) + 1e-12),
|
||||
axis=dim)
|
||||
v_normalized = F.elementwise_div(
|
||||
v, (norm_except(v, dim, power) + 1e-12), axis=dim)
|
||||
weight = F.elementwise_mul(v_normalized, g, axis=dim)
|
||||
return weight
|
||||
|
||||
|
@ -63,20 +63,21 @@ class WeightNormWrapper(dg.Layer):
|
|||
original_weight = getattr(layer, param_name)
|
||||
self.add_parameter(
|
||||
w_v,
|
||||
self.create_parameter(shape=original_weight.shape,
|
||||
dtype=original_weight.dtype))
|
||||
self.create_parameter(
|
||||
shape=original_weight.shape, dtype=original_weight.dtype))
|
||||
F.assign(original_weight, getattr(self, w_v))
|
||||
delattr(layer, param_name)
|
||||
temp = norm_except(getattr(self, w_v), self.dim, self.power)
|
||||
self.add_parameter(
|
||||
w_g, self.create_parameter(shape=temp.shape, dtype=temp.dtype))
|
||||
w_g, self.create_parameter(
|
||||
shape=temp.shape, dtype=temp.dtype))
|
||||
F.assign(temp, getattr(self, w_g))
|
||||
|
||||
# also set this when setting up
|
||||
setattr(
|
||||
self.layer, self.param_name,
|
||||
compute_weight(getattr(self, w_v), getattr(self, w_g), self.dim,
|
||||
self.power))
|
||||
setattr(self.layer, self.param_name,
|
||||
compute_weight(
|
||||
getattr(self, w_v),
|
||||
getattr(self, w_g), self.dim, self.power))
|
||||
|
||||
self.weigth_norm_applied = True
|
||||
|
||||
|
@ -84,10 +85,10 @@ class WeightNormWrapper(dg.Layer):
|
|||
def hook(self):
|
||||
w_v = self.param_name + "_v"
|
||||
w_g = self.param_name + "_g"
|
||||
setattr(
|
||||
self.layer, self.param_name,
|
||||
compute_weight(getattr(self, w_v), getattr(self, w_g), self.dim,
|
||||
self.power))
|
||||
setattr(self.layer, self.param_name,
|
||||
compute_weight(
|
||||
getattr(self, w_v),
|
||||
getattr(self, w_g), self.dim, self.power))
|
||||
|
||||
def remove_weight_norm(self):
|
||||
self.hook()
|
||||
|
|
Loading…
Reference in New Issue