Merge branch 'update_waveflow' into 'master'

Update waveflow

See merge request !21
This commit is contained in:
liuyibing01 2020-02-24 11:07:13 +08:00
commit 25883dcd3e
14 changed files with 458 additions and 302 deletions

27
.pre-commit-config.yaml Normal file
View File

@ -0,0 +1,27 @@
- repo: https://github.com/PaddlePaddle/mirrors-yapf.git
sha: 0d79c0c469bab64f7229c9aca2b1186ef47f0e37
hooks:
- id: yapf
files: \.py$
- repo: https://github.com/pre-commit/pre-commit-hooks
sha: a11d9314b22d8f8c7556443875b731ef05965464
hooks:
- id: check-merge-conflict
- id: check-symlinks
- id: detect-private-key
files: (?!.*paddle)^.*$
- id: end-of-file-fixer
files: \.md$
- id: trailing-whitespace
files: \.md$
- repo: https://github.com/Lucas-C/pre-commit-hooks
sha: v1.0.1
hooks:
- id: forbid-crlf
files: \.md$
- id: remove-crlf
files: \.md$
- id: forbid-tabs
files: \.md$
- id: remove-tabs
files: \.md$

View File

@ -28,22 +28,21 @@ if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Train a deepvoice 3 model with LJSpeech dataset.")
parser.add_argument("-c", "--config", type=str, help="experimrnt config")
parser.add_argument("-s",
"--data",
type=str,
default="/workspace/datasets/LJSpeech-1.1/",
help="The path of the LJSpeech dataset.")
parser.add_argument(
"-s",
"--data",
type=str,
default="/workspace/datasets/LJSpeech-1.1/",
help="The path of the LJSpeech dataset.")
parser.add_argument("-r", "--resume", type=str, help="checkpoint to load")
parser.add_argument("-o",
"--output",
type=str,
default="result",
help="The directory to save result.")
parser.add_argument("-g",
"--device",
type=int,
default=-1,
help="device to use")
parser.add_argument(
"-o",
"--output",
type=str,
default="result",
help="The directory to save result.")
parser.add_argument(
"-g", "--device", type=int, default=-1, help="device to use")
args, _ = parser.parse_known_args()
with open(args.config, 'rt') as f:
config = ruamel.yaml.safe_load(f)
@ -84,18 +83,16 @@ if __name__ == "__main__":
train_config = config["train"]
batch_size = train_config["batch_size"]
text_lengths = [len(example[2]) for example in meta]
sampler = PartialyRandomizedSimilarTimeLengthSampler(
text_lengths, batch_size)
sampler = PartialyRandomizedSimilarTimeLengthSampler(text_lengths,
batch_size)
# some hyperparameters affect how we process data, so create a data collector!
model_config = config["model"]
downsample_factor = model_config["downsample_factor"]
r = model_config["outputs_per_step"]
collector = DataCollector(downsample_factor=downsample_factor, r=r)
ljspeech_loader = DataCargo(ljspeech,
batch_fn=collector,
batch_size=batch_size,
sampler=sampler)
ljspeech_loader = DataCargo(
ljspeech, batch_fn=collector, batch_size=batch_size, sampler=sampler)
# =========================model=========================
if args.device == -1:
@ -131,15 +128,14 @@ if __name__ == "__main__":
window_ahead = model_config["window_ahead"]
key_projection = model_config["key_projection"]
value_projection = model_config["value_projection"]
dv3 = make_model(n_speakers, speaker_dim, speaker_embed_std, embed_dim,
padding_idx, embedding_std, max_positions, n_vocab,
freeze_embedding, filter_size, encoder_channels,
n_mels, decoder_channels, r,
trainable_positional_encodings, use_memory_mask,
query_position_rate, key_position_rate,
window_backward, window_ahead, key_projection,
value_projection, downsample_factor, linear_dim,
use_decoder_states, converter_channels, dropout)
dv3 = make_model(
n_speakers, speaker_dim, speaker_embed_std, embed_dim, padding_idx,
embedding_std, max_positions, n_vocab, freeze_embedding,
filter_size, encoder_channels, n_mels, decoder_channels, r,
trainable_positional_encodings, use_memory_mask,
query_position_rate, key_position_rate, window_backward,
window_ahead, key_projection, value_projection, downsample_factor,
linear_dim, use_decoder_states, converter_channels, dropout)
# =========================loss=========================
loss_config = config["loss"]
@ -149,13 +145,14 @@ if __name__ == "__main__":
priority_freq_weight = loss_config["priority_freq_weight"]
binary_divergence_weight = loss_config["binary_divergence_weight"]
guided_attention_sigma = loss_config["guided_attention_sigma"]
criterion = TTSLoss(masked_weight=masked_weight,
priority_bin=priority_bin,
priority_weight=priority_freq_weight,
binary_divergence_weight=binary_divergence_weight,
guided_attention_sigma=guided_attention_sigma,
downsample_factor=downsample_factor,
r=r)
criterion = TTSLoss(
masked_weight=masked_weight,
priority_bin=priority_bin,
priority_weight=priority_freq_weight,
binary_divergence_weight=binary_divergence_weight,
guided_attention_sigma=guided_attention_sigma,
downsample_factor=downsample_factor,
r=r)
# =========================lr_scheduler=========================
lr_config = config["lr_scheduler"]
@ -169,11 +166,12 @@ if __name__ == "__main__":
beta1 = optim_config["beta1"]
beta2 = optim_config["beta2"]
epsilon = optim_config["epsilon"]
optim = fluid.optimizer.Adam(lr_scheduler,
beta1,
beta2,
epsilon=epsilon,
parameter_list=dv3.parameters())
optim = fluid.optimizer.Adam(
lr_scheduler,
beta1,
beta2,
epsilon=epsilon,
parameter_list=dv3.parameters())
gradient_clipper = fluid.dygraph_grad_clip.GradClipByGlobalNorm(0.1)
# generation
@ -183,8 +181,8 @@ if __name__ == "__main__":
# =========================link(dataloader, paddle)=========================
# CAUTION: it does not return a DataLoader
loader = fluid.io.DataLoader.from_generator(capacity=10,
return_list=True)
loader = fluid.io.DataLoader.from_generator(
capacity=10, return_list=True)
loader.set_batch_generator(ljspeech_loader, places=place)
# tensorboard & checkpoint preparation
@ -247,22 +245,23 @@ if __name__ == "__main__":
# TODO: clean code
# train state saving, the first sentence in the batch
if global_step % snap_interval == 0:
save_state(state_dir,
writer,
global_step,
mel_input=downsampled_mel_specs,
mel_output=mel_outputs,
lin_input=lin_specs,
lin_output=linear_outputs,
alignments=alignments,
win_length=win_length,
hop_length=hop_length,
min_level_db=min_level_db,
ref_level_db=ref_level_db,
power=power,
n_iter=n_iter,
preemphasis=preemphasis,
sample_rate=sample_rate)
save_state(
state_dir,
writer,
global_step,
mel_input=downsampled_mel_specs,
mel_output=mel_outputs,
lin_input=lin_specs,
lin_output=linear_outputs,
alignments=alignments,
win_length=win_length,
hop_length=hop_length,
min_level_db=min_level_db,
ref_level_db=ref_level_db,
power=power,
n_iter=n_iter,
preemphasis=preemphasis,
sample_rate=sample_rate)
# evaluation
if global_step % eval_interval == 0:
@ -275,27 +274,28 @@ if __name__ == "__main__":
"Some have accepted this as a miracle without any physical explanation.",
]
for idx, sent in enumerate(sentences):
wav, attn = eval_model(dv3, sent,
replace_pronounciation_prob,
min_level_db, ref_level_db,
power, n_iter, win_length,
hop_length, preemphasis)
wav, attn = eval_model(
dv3, sent, replace_pronounciation_prob,
min_level_db, ref_level_db, power, n_iter,
win_length, hop_length, preemphasis)
wav_path = os.path.join(
state_dir, "waveform",
"eval_sample_{:09d}.wav".format(global_step))
sf.write(wav_path, wav, sample_rate)
writer.add_audio("eval_sample_{}".format(idx),
wav,
global_step,
sample_rate=sample_rate)
writer.add_audio(
"eval_sample_{}".format(idx),
wav,
global_step,
sample_rate=sample_rate)
attn_path = os.path.join(
state_dir, "alignments",
"eval_sample_attn_{:09d}.png".format(global_step))
plot_alignment(attn, attn_path)
writer.add_image("eval_sample_attn{}".format(idx),
cm.viridis(attn),
global_step,
dataformats="HWC")
writer.add_image(
"eval_sample_attn{}".format(idx),
cm.viridis(attn),
global_step,
dataformats="HWC")
# save checkpoint
if global_step % save_interval == 0:

View File

@ -2,35 +2,47 @@ import os
import random
from pprint import pprint
import jsonargparse
import argparse
import numpy as np
import paddle.fluid.dygraph as dg
from paddle import fluid
import utils
from waveflow import WaveFlow
from parakeet.models.waveflow import WaveFlow
def add_options_to_parser(parser):
parser.add_argument('--model', type=str, default='waveflow',
parser.add_argument(
'--model',
type=str,
default='waveflow',
help="general name of the model")
parser.add_argument('--name', type=str,
help="specific name of the training model")
parser.add_argument('--root', type=str,
help="root path of the LJSpeech dataset")
parser.add_argument(
'--name', type=str, help="specific name of the training model")
parser.add_argument(
'--root', type=str, help="root path of the LJSpeech dataset")
parser.add_argument('--use_gpu', type=bool, default=True,
parser.add_argument(
'--use_gpu',
type=bool,
default=True,
help="option to use gpu training")
parser.add_argument('--iteration', type=int, default=None,
parser.add_argument(
'--iteration',
type=int,
default=None,
help=("which iteration of checkpoint to load, "
"default to load the latest checkpoint"))
parser.add_argument('--checkpoint', type=str, default=None,
parser.add_argument(
'--checkpoint',
type=str,
default=None,
help="path of the checkpoint to load")
def benchmark(config):
pprint(jsonargparse.namespace_to_dict(config))
pprint(vars(config))
# Get checkpoint directory path.
run_dir = os.path.join("runs", config.model, config.name)
@ -58,9 +70,8 @@ def benchmark(config):
if __name__ == "__main__":
# Create parser.
parser = jsonargparse.ArgumentParser(
description="Synthesize audio using WaveNet model",
formatter_class='default_argparse')
parser = argparse.ArgumentParser(
description="Synthesize audio using WaveNet model")
add_options_to_parser(parser)
utils.add_config_options_to_parser(parser)
@ -68,4 +79,5 @@ if __name__ == "__main__":
# For conflicting updates to the same field,
# the preceding update will be overwritten by the following one.
config = parser.parse_args()
config = utils.add_yaml_config(config)
benchmark(config)

View File

@ -2,40 +2,58 @@ import os
import random
from pprint import pprint
import jsonargparse
import argparse
import numpy as np
import paddle.fluid.dygraph as dg
from paddle import fluid
import utils
from waveflow import WaveFlow
from parakeet.models.waveflow import WaveFlow
def add_options_to_parser(parser):
parser.add_argument('--model', type=str, default='waveflow',
parser.add_argument(
'--model',
type=str,
default='waveflow',
help="general name of the model")
parser.add_argument('--name', type=str,
help="specific name of the training model")
parser.add_argument('--root', type=str,
help="root path of the LJSpeech dataset")
parser.add_argument(
'--name', type=str, help="specific name of the training model")
parser.add_argument(
'--root', type=str, help="root path of the LJSpeech dataset")
parser.add_argument('--use_gpu', type=bool, default=True,
parser.add_argument(
'--use_gpu',
type=bool,
default=True,
help="option to use gpu training")
parser.add_argument('--iteration', type=int, default=None,
parser.add_argument(
'--iteration',
type=int,
default=None,
help=("which iteration of checkpoint to load, "
"default to load the latest checkpoint"))
parser.add_argument('--checkpoint', type=str, default=None,
parser.add_argument(
'--checkpoint',
type=str,
default=None,
help="path of the checkpoint to load")
parser.add_argument('--output', type=str, default="./syn_audios",
parser.add_argument(
'--output',
type=str,
default="./syn_audios",
help="path to write synthesized audio files")
parser.add_argument('--sample', type=int, default=None,
parser.add_argument(
'--sample',
type=int,
default=None,
help="which of the valid samples to synthesize audio")
def synthesize(config):
pprint(jsonargparse.namespace_to_dict(config))
pprint(vars(config))
# Get checkpoint directory path.
run_dir = os.path.join("runs", config.model, config.name)
@ -72,9 +90,8 @@ def synthesize(config):
if __name__ == "__main__":
# Create parser.
parser = jsonargparse.ArgumentParser(
description="Synthesize audio using WaveNet model",
formatter_class='default_argparse')
parser = argparse.ArgumentParser(
description="Synthesize audio using WaveNet model")
add_options_to_parser(parser)
utils.add_config_options_to_parser(parser)
@ -82,4 +99,5 @@ if __name__ == "__main__":
# For conflicting updates to the same field,
# the preceding update will be overwritten by the following one.
config = parser.parse_args()
config = utils.add_yaml_config(config)
synthesize(config)

View File

@ -4,34 +4,48 @@ import subprocess
import time
from pprint import pprint
import jsonargparse
import argparse
import numpy as np
import paddle.fluid.dygraph as dg
from paddle import fluid
from tensorboardX import SummaryWriter
import slurm
import utils
from waveflow import WaveFlow
from parakeet.models.waveflow import WaveFlow
def add_options_to_parser(parser):
parser.add_argument('--model', type=str, default='waveflow',
parser.add_argument(
'--model',
type=str,
default='waveflow',
help="general name of the model")
parser.add_argument('--name', type=str,
help="specific name of the training model")
parser.add_argument('--root', type=str,
help="root path of the LJSpeech dataset")
parser.add_argument(
'--name', type=str, help="specific name of the training model")
parser.add_argument(
'--root', type=str, help="root path of the LJSpeech dataset")
parser.add_argument('--parallel', type=bool, default=True,
parser.add_argument(
'--parallel',
type=utils.str2bool,
default=True,
help="option to use data parallel training")
parser.add_argument('--use_gpu', type=bool, default=True,
parser.add_argument(
'--use_gpu',
type=utils.str2bool,
default=True,
help="option to use gpu training")
parser.add_argument('--iteration', type=int, default=None,
parser.add_argument(
'--iteration',
type=int,
default=None,
help=("which iteration of checkpoint to load, "
"default to load the latest checkpoint"))
parser.add_argument('--checkpoint', type=str, default=None,
parser.add_argument(
'--checkpoint',
type=str,
default=None,
help="path of the checkpoint to load")
@ -45,12 +59,13 @@ def train(config):
if rank == 0:
# Print the whole config setting.
pprint(jsonargparse.namespace_to_dict(config))
pprint(vars(config))
# Make checkpoint directory.
run_dir = os.path.join("runs", config.model, config.name)
checkpoint_dir = os.path.join(run_dir, "checkpoint")
os.makedirs(checkpoint_dir, exist_ok=True)
if not os.path.exists(checkpoint_dir):
os.makedirs(checkpoint_dir)
# Create tensorboard logger.
tb = SummaryWriter(os.path.join(run_dir, "logs")) \
@ -102,8 +117,8 @@ def train(config):
if __name__ == "__main__":
# Create parser.
parser = jsonargparse.ArgumentParser(description="Train WaveFlow model",
formatter_class='default_argparse')
parser = argparse.ArgumentParser(description="Train WaveFlow model")
#formatter_class='default_argparse')
add_options_to_parser(parser)
utils.add_config_options_to_parser(parser)
@ -111,4 +126,5 @@ if __name__ == "__main__":
# For conflicting updates to the same field,
# the preceding update will be overwritten by the following one.
config = parser.parse_args()
config = utils.add_yaml_config(config)
train(config)

View File

@ -2,59 +2,96 @@ import itertools
import os
import time
import jsonargparse
import argparse
import ruamel.yaml
import numpy as np
import paddle.fluid.dygraph as dg
def str2bool(v):
return v.lower() in ("true", "t", "1")
def add_config_options_to_parser(parser):
parser.add_argument('--valid_size', type=int,
help="size of the valid dataset")
parser.add_argument('--segment_length', type=int,
parser.add_argument(
'--valid_size', type=int, help="size of the valid dataset")
parser.add_argument(
'--segment_length',
type=int,
help="the length of audio clip for training")
parser.add_argument('--sample_rate', type=int,
help="sampling rate of audio data file")
parser.add_argument('--fft_window_shift', type=int,
parser.add_argument(
'--sample_rate', type=int, help="sampling rate of audio data file")
parser.add_argument(
'--fft_window_shift',
type=int,
help="the shift of fft window for each frame")
parser.add_argument('--fft_window_size', type=int,
parser.add_argument(
'--fft_window_size',
type=int,
help="the size of fft window for each frame")
parser.add_argument('--fft_size', type=int,
help="the size of fft filter on each frame")
parser.add_argument('--mel_bands', type=int,
parser.add_argument(
'--fft_size', type=int, help="the size of fft filter on each frame")
parser.add_argument(
'--mel_bands',
type=int,
help="the number of mel bands when calculating mel spectrograms")
parser.add_argument('--mel_fmin', type=float,
parser.add_argument(
'--mel_fmin',
type=float,
help="lowest frequency in calculating mel spectrograms")
parser.add_argument('--mel_fmax', type=float,
parser.add_argument(
'--mel_fmax',
type=float,
help="highest frequency in calculating mel spectrograms")
parser.add_argument('--seed', type=int,
help="seed of random initialization for the model")
parser.add_argument(
'--seed', type=int, help="seed of random initialization for the model")
parser.add_argument('--learning_rate', type=float)
parser.add_argument('--batch_size', type=int,
help="batch size for training")
parser.add_argument('--test_every', type=int,
help="test interval during training")
parser.add_argument('--save_every', type=int,
parser.add_argument(
'--batch_size', type=int, help="batch size for training")
parser.add_argument(
'--test_every', type=int, help="test interval during training")
parser.add_argument(
'--save_every',
type=int,
help="checkpointing interval during training")
parser.add_argument('--max_iterations', type=int,
help="maximum training iterations")
parser.add_argument(
'--max_iterations', type=int, help="maximum training iterations")
parser.add_argument('--sigma', type=float,
parser.add_argument(
'--sigma',
type=float,
help="standard deviation of the latent Gaussian variable")
parser.add_argument('--n_flows', type=int,
help="number of flows")
parser.add_argument('--n_group', type=int,
parser.add_argument('--n_flows', type=int, help="number of flows")
parser.add_argument(
'--n_group',
type=int,
help="number of adjacent audio samples to squeeze into one column")
parser.add_argument('--n_layers', type=int,
parser.add_argument(
'--n_layers',
type=int,
help="number of conv2d layer in one wavenet-like flow architecture")
parser.add_argument('--n_channels', type=int,
help="number of residual channels in flow")
parser.add_argument('--kernel_h', type=int,
parser.add_argument(
'--n_channels', type=int, help="number of residual channels in flow")
parser.add_argument(
'--kernel_h',
type=int,
help="height of the kernel in the conv2d layer")
parser.add_argument('--kernel_w', type=int,
help="width of the kernel in the conv2d layer")
parser.add_argument(
'--kernel_w', type=int, help="width of the kernel in the conv2d layer")
parser.add_argument('--config', action=jsonargparse.ActionConfigFile)
parser.add_argument('--config', type=str, help="Path to the config file.")
def add_yaml_config(config):
with open(config.config, 'rt') as f:
yaml_cfg = ruamel.yaml.safe_load(f)
cfg_vars = vars(config)
for k, v in yaml_cfg.items():
if k in cfg_vars and cfg_vars[k] is not None:
continue
cfg_vars[k] = v
return config
def load_latest_checkpoint(checkpoint_dir, rank=0):
@ -84,8 +121,12 @@ def save_latest_checkpoint(checkpoint_dir, iteration):
handle.write("model_checkpoint_path: step-{}".format(iteration))
def load_parameters(checkpoint_dir, rank, model, optimizer=None,
iteration=None, file_path=None):
def load_parameters(checkpoint_dir,
rank,
model,
optimizer=None,
iteration=None,
file_path=None):
if file_path is None:
if iteration is None:
iteration = load_latest_checkpoint(checkpoint_dir, rank)
@ -99,7 +140,7 @@ def load_parameters(checkpoint_dir, rank, model, optimizer=None,
if optimizer and optimizer_dict:
optimizer.set_dict(optimizer_dict)
print("[checkpoint] Rank {}: loaded optimizer state from {}".format(
rank, file_path))
rank, file_path))
def save_latest_parameters(checkpoint_dir, iteration, model, optimizer=None):

View File

@ -5,22 +5,27 @@ import librosa
from .. import g2p
from ..data.sampler import SequentialSampler, RandomSampler, BatchSampler
from ..data.dataset import Dataset
from ..data.dataset import DatasetMixin
from ..data.datacargo import DataCargo
from ..data.batch import TextIDBatcher, SpecBatcher
class LJSpeech(Dataset):
class LJSpeech(DatasetMixin):
def __init__(self, root):
super(LJSpeech, self).__init__()
assert isinstance(root, (str, Path)), "root should be a string or Path object"
assert isinstance(root, (
str, Path)), "root should be a string or Path object"
self.root = root if isinstance(root, Path) else Path(root)
self.metadata = self._prepare_metadata()
def _prepare_metadata(self):
csv_path = self.root.joinpath("metadata.csv")
metadata = pd.read_csv(csv_path, sep="|", header=None, quoting=3,
names=["fname", "raw_text", "normalized_text"])
metadata = pd.read_csv(
csv_path,
sep="|",
header=None,
quoting=3,
names=["fname", "raw_text", "normalized_text"])
return metadata
def _get_example(self, metadatum):
@ -35,7 +40,9 @@ class LJSpeech(Dataset):
wav_path = self.root.joinpath("wavs", fname + ".wav")
# load -> trim -> preemphasis -> stft -> magnitude -> mel_scale -> logscale -> normalize
wav, sample_rate = librosa.load(wav_path, sr=None) # we would rather use functor to hold its parameters
wav, sample_rate = librosa.load(
wav_path,
sr=None) # we would rather use functor to hold its parameters
trimed, _ = librosa.effects.trim(wav)
preemphasized = librosa.effects.preemphasis(trimed)
D = librosa.stft(preemphasized)
@ -50,8 +57,10 @@ class LJSpeech(Dataset):
mel = np.clip((mel - ref_db + max_db) / max_db, 1e-8, 1)
mel = np.clip((mag - ref_db + max_db) / max_db, 1e-8, 1)
phonemes = np.array(g2p.en.text_to_sequence(normalized_text), dtype=np.int64)
return (mag, mel, phonemes) # maybe we need to implement it as a map in the future
phonemes = np.array(
g2p.en.text_to_sequence(normalized_text), dtype=np.int64)
return (mag, mel, phonemes
) # maybe we need to implement it as a map in the future
def _batch_examples(self, minibatch):
mag_batch = []
@ -78,5 +87,3 @@ class LJSpeech(Dataset):
def __len__(self):
return len(self.metadata)

View File

@ -0,0 +1 @@
from parakeet.models.waveflow.waveflow import WaveFlow

View File

@ -5,10 +5,9 @@ import numpy as np
from paddle import fluid
from parakeet.datasets import ljspeech
from parakeet.data import dataset
from parakeet.data.batch import SpecBatcher, WavBatcher
from parakeet.data.datacargo import DataCargo
from parakeet.data.sampler import DistributedSampler, BatchSampler
from parakeet.data import SpecBatcher, WavBatcher
from parakeet.data import DataCargo, DatasetMixin
from parakeet.data import DistributedSampler, BatchSampler
from scipy.io.wavfile import read
@ -27,7 +26,7 @@ class Dataset(ljspeech.LJSpeech):
return audio
class Subset(dataset.Dataset):
class Subset(DatasetMixin):
def __init__(self, dataset, indices, valid):
self.dataset = dataset
self.indices = indices
@ -36,18 +35,18 @@ class Subset(dataset.Dataset):
def get_mel(self, audio):
spectrogram = librosa.core.stft(
audio, n_fft=self.config.fft_size,
audio,
n_fft=self.config.fft_size,
hop_length=self.config.fft_window_shift,
win_length=self.config.fft_window_size)
spectrogram_magnitude = np.abs(spectrogram)
# mel_filter_bank shape: [n_mels, 1 + n_fft/2]
mel_filter_bank = librosa.filters.mel(
sr=self.config.sample_rate,
n_fft=self.config.fft_size,
n_mels=self.config.mel_bands,
fmin=self.config.mel_fmin,
fmax=self.config.mel_fmax)
mel_filter_bank = librosa.filters.mel(sr=self.config.sample_rate,
n_fft=self.config.fft_size,
n_mels=self.config.mel_bands,
fmin=self.config.mel_fmin,
fmax=self.config.mel_fmax)
# mel shape: [n_mels, num_frames]
mel = np.dot(mel_filter_bank, spectrogram_magnitude)
@ -70,10 +69,11 @@ class Subset(dataset.Dataset):
if audio.shape[0] >= segment_length:
max_audio_start = audio.shape[0] - segment_length
audio_start = random.randint(0, max_audio_start)
audio = audio[audio_start : (audio_start + segment_length)]
audio = audio[audio_start:(audio_start + segment_length)]
else:
audio = np.pad(audio, (0, segment_length - audio.shape[0]),
mode='constant', constant_values=0)
mode='constant',
constant_values=0)
# Normalize audio to the [-1, 1] range.
audio = audio.astype(np.float32) / 32768.0
@ -112,14 +112,14 @@ class LJSpeech:
sampler = DistributedSampler(len(trainset), nranks, rank)
total_bs = config.batch_size
assert total_bs % nranks == 0
train_sampler = BatchSampler(sampler, total_bs // nranks,
drop_last=True)
train_sampler = BatchSampler(
sampler, total_bs // nranks, drop_last=True)
trainloader = DataCargo(trainset, batch_sampler=train_sampler)
trainreader = fluid.io.PyReader(capacity=50, return_list=True)
trainreader.decorate_batch_generator(trainloader, place)
self.trainloader = (data for _ in iter(int, 1)
for data in trainreader())
for data in trainreader())
# Valid dataset.
validset = Subset(ds, valid_indices, valid=True)

View File

@ -8,13 +8,18 @@ from paddle import fluid
from scipy.io.wavfile import write
import utils
from data import LJSpeech
from waveflow_modules import WaveFlowLoss, WaveFlowModule
from .data import LJSpeech
from .waveflow_modules import WaveFlowLoss, WaveFlowModule
class WaveFlow():
def __init__(self, config, checkpoint_dir, parallel=False, rank=0,
nranks=1, tb_logger=None):
def __init__(self,
config,
checkpoint_dir,
parallel=False,
rank=0,
nranks=1,
tb_logger=None):
self.config = config
self.checkpoint_dir = checkpoint_dir
self.parallel = parallel
@ -28,7 +33,7 @@ class WaveFlow():
self.trainloader = dataset.trainloader
self.validloader = dataset.validloader
waveflow = WaveFlowModule("waveflow", config)
waveflow = WaveFlowModule(config)
# Dry run once to create and initalize all necessary parameters.
audio = dg.to_variable(np.random.randn(1, 16000).astype(np.float32))
@ -38,13 +43,17 @@ class WaveFlow():
if training:
optimizer = fluid.optimizer.AdamOptimizer(
learning_rate=config.learning_rate)
learning_rate=config.learning_rate,
parameter_list=waveflow.parameters())
# Load parameters.
utils.load_parameters(self.checkpoint_dir, self.rank,
waveflow, optimizer,
iteration=config.iteration,
file_path=config.checkpoint)
utils.load_parameters(
self.checkpoint_dir,
self.rank,
waveflow,
optimizer,
iteration=config.iteration,
file_path=config.checkpoint)
print("Rank {}: checkpoint loaded.".format(self.rank))
# Data parallelism.
@ -58,9 +67,12 @@ class WaveFlow():
else:
# Load parameters.
utils.load_parameters(self.checkpoint_dir, self.rank, waveflow,
iteration=config.iteration,
file_path=config.checkpoint)
utils.load_parameters(
self.checkpoint_dir,
self.rank,
waveflow,
iteration=config.iteration,
file_path=config.checkpoint)
print("Rank {}: checkpoint loaded.".format(self.rank))
self.waveflow = waveflow
@ -83,7 +95,8 @@ class WaveFlow():
else:
loss.backward()
self.optimizer.minimize(loss, parameter_list=self.waveflow.parameters())
self.optimizer.minimize(
loss, parameter_list=self.waveflow.parameters())
self.waveflow.clear_gradients()
graph_time = time.time()
@ -139,7 +152,8 @@ class WaveFlow():
sample = config.sample
output = "{}/{}/iter-{}".format(config.output, config.name, iteration)
os.makedirs(output, exist_ok=True)
if not os.path.exists(output):
os.makedirs(output)
mels_list = [mels for _, mels in self.validloader()]
if sample is not None:
@ -155,8 +169,8 @@ class WaveFlow():
audio = audio[0]
audio_time = audio.shape[0] / self.config.sample_rate
print("audio time {:.4f}, synthesis time {:.4f}".format(
audio_time, syn_time))
print("audio time {:.4f}, synthesis time {:.4f}".format(audio_time,
syn_time))
# Denormalize audio from [-1, 1] to [-32768, 32768] int16 range.
audio = audio.numpy() * 32768.0
@ -180,8 +194,8 @@ class WaveFlow():
syn_time = time.time() - start_time
audio_time = audio.shape[1] * batch_size / self.config.sample_rate
print("audio time {:.4f}, synthesis time {:.4f}".format(
audio_time, syn_time))
print("audio time {:.4f}, synthesis time {:.4f}".format(audio_time,
syn_time))
print("{} X real-time".format(audio_time / syn_time))
def save(self, iteration):

View File

@ -3,22 +3,23 @@ import itertools
import numpy as np
import paddle.fluid.dygraph as dg
from paddle import fluid
from parakeet.modules import conv, modules, weight_norm
from parakeet.modules import weight_norm
def set_param_attr(layer, c_in=1):
if isinstance(layer, (weight_norm.Conv2DTranspose, weight_norm.Conv2D)):
k = np.sqrt(1.0 / (c_in * np.prod(layer._filter_size)))
def get_param_attr(layer_type, filter_size, c_in=1):
if layer_type == "weight_norm":
k = np.sqrt(1.0 / (c_in * np.prod(filter_size)))
weight_init = fluid.initializer.UniformInitializer(low=-k, high=k)
bias_init = fluid.initializer.UniformInitializer(low=-k, high=k)
elif isinstance(layer, dg.Conv2D):
elif layer_type == "common":
weight_init = fluid.initializer.ConstantInitializer(0.0)
bias_init = fluid.initializer.ConstantInitializer(0.0)
else:
raise TypeError("Unsupported layer type.")
layer._param_attr = fluid.ParamAttr(initializer=weight_init)
layer._bias_attr = fluid.ParamAttr(initializer=bias_init)
param_attr = fluid.ParamAttr(initializer=weight_init)
bias_attr = fluid.ParamAttr(initializer=bias_init)
return param_attr, bias_attr
def unfold(x, n_group):
@ -48,20 +49,23 @@ class WaveFlowLoss:
class Conditioner(dg.Layer):
def __init__(self, name_scope):
super(Conditioner, self).__init__(name_scope)
def __init__(self):
super(Conditioner, self).__init__()
upsample_factors = [16, 16]
self.upsample_conv2d = []
for s in upsample_factors:
in_channel = 1
conv_trans2d = modules.Conv2DTranspose(
self.full_name(),
param_attr, bias_attr = get_param_attr(
"weight_norm", (3, 2 * s), c_in=in_channel)
conv_trans2d = weight_norm.Conv2DTranspose(
num_channels=in_channel,
num_filters=1,
filter_size=(3, 2 * s),
padding=(1, s // 2),
stride=(1, s))
set_param_attr(conv_trans2d, c_in=in_channel)
stride=(1, s),
param_attr=param_attr,
bias_attr=bias_attr)
self.upsample_conv2d.append(conv_trans2d)
for i, layer in enumerate(self.upsample_conv2d):
@ -86,8 +90,8 @@ class Conditioner(dg.Layer):
class Flow(dg.Layer):
def __init__(self, name_scope, config):
super(Flow, self).__init__(name_scope)
def __init__(self, config):
super(Flow, self).__init__()
self.n_layers = config.n_layers
self.n_channels = config.n_channels
self.kernel_h = config.kernel_h
@ -95,27 +99,34 @@ class Flow(dg.Layer):
# Transform audio: [batch, 1, n_group, time/n_group]
# => [batch, n_channels, n_group, time/n_group]
param_attr, bias_attr = get_param_attr("weight_norm", (1, 1), c_in=1)
self.start = weight_norm.Conv2D(
self.full_name(),
num_channels=1,
num_filters=self.n_channels,
filter_size=(1, 1))
set_param_attr(self.start, c_in=1)
filter_size=(1, 1),
param_attr=param_attr,
bias_attr=bias_attr)
# Initializing last layer to 0 makes the affine coupling layers
# do nothing at first. This helps with training stability
# output shape: [batch, 2, n_group, time/n_group]
param_attr, bias_attr = get_param_attr(
"common", (1, 1), c_in=self.n_channels)
self.end = dg.Conv2D(
self.full_name(),
num_channels=self.n_channels,
num_filters=2,
filter_size=(1, 1))
set_param_attr(self.end)
filter_size=(1, 1),
param_attr=param_attr,
bias_attr=bias_attr)
# receiptive fileds: (kernel - 1) * sum(dilations) + 1 >= squeeze
dilation_dict = {8: [1, 1, 1, 1, 1, 1, 1, 1],
16: [1, 1, 1, 1, 1, 1, 1, 1],
32: [1, 2, 4, 1, 2, 4, 1, 2],
64: [1, 2, 4, 8, 16, 1, 2, 4],
128: [1, 2, 4, 8, 16, 32, 64, 1]}
dilation_dict = {
8: [1, 1, 1, 1, 1, 1, 1, 1],
16: [1, 1, 1, 1, 1, 1, 1, 1],
32: [1, 2, 4, 1, 2, 4, 1, 2],
64: [1, 2, 4, 8, 16, 1, 2, 4],
128: [1, 2, 4, 8, 16, 32, 64, 1]
}
self.dilation_h_list = dilation_dict[config.n_group]
self.in_layers = []
@ -123,32 +134,42 @@ class Flow(dg.Layer):
self.res_skip_layers = []
for i in range(self.n_layers):
dilation_h = self.dilation_h_list[i]
dilation_w = 2 ** i
dilation_w = 2**i
param_attr, bias_attr = get_param_attr(
"weight_norm", (self.kernel_h, self.kernel_w),
c_in=self.n_channels)
in_layer = weight_norm.Conv2D(
self.full_name(),
num_channels=self.n_channels,
num_filters=2 * self.n_channels,
filter_size=(self.kernel_h, self.kernel_w),
dilation=(dilation_h, dilation_w))
set_param_attr(in_layer, c_in=self.n_channels)
dilation=(dilation_h, dilation_w),
param_attr=param_attr,
bias_attr=bias_attr)
self.in_layers.append(in_layer)
param_attr, bias_attr = get_param_attr(
"weight_norm", (1, 1), c_in=config.mel_bands)
cond_layer = weight_norm.Conv2D(
self.full_name(),
num_channels=config.mel_bands,
num_filters=2 * self.n_channels,
filter_size=(1, 1))
set_param_attr(cond_layer, c_in=config.mel_bands)
filter_size=(1, 1),
param_attr=param_attr,
bias_attr=bias_attr)
self.cond_layers.append(cond_layer)
if i < self.n_layers - 1:
res_skip_channels = 2 * self.n_channels
else:
res_skip_channels = self.n_channels
param_attr, bias_attr = get_param_attr(
"weight_norm", (1, 1), c_in=self.n_channels)
res_skip_layer = weight_norm.Conv2D(
self.full_name(),
num_channels=self.n_channels,
num_filters=res_skip_channels,
filter_size=(1, 1))
set_param_attr(res_skip_layer, c_in=self.n_channels)
filter_size=(1, 1),
param_attr=param_attr,
bias_attr=bias_attr)
self.res_skip_layers.append(res_skip_layer)
self.add_sublayer("in_layer_{}".format(i), in_layer)
@ -162,14 +183,14 @@ class Flow(dg.Layer):
for i in range(self.n_layers):
dilation_h = self.dilation_h_list[i]
dilation_w = 2 ** i
dilation_w = 2**i
# Pad height dim (n_group): causal convolution
# Pad width dim (time): dialated non-causal convolution
pad_top, pad_bottom = (self.kernel_h - 1) * dilation_h, 0
pad_left = pad_right = int((self.kernel_w-1) * dilation_w / 2)
audio_pad = fluid.layers.pad2d(audio,
paddings=[pad_top, pad_bottom, pad_left, pad_right])
pad_left = pad_right = int((self.kernel_w - 1) * dilation_w / 2)
audio_pad = fluid.layers.pad2d(
audio, paddings=[pad_top, pad_bottom, pad_left, pad_right])
hidden = self.in_layers[i](audio_pad)
cond_hidden = self.cond_layers[i](mel)
@ -196,7 +217,7 @@ class Flow(dg.Layer):
for i in range(self.n_layers):
dilation_h = self.dilation_h_list[i]
dilation_w = 2 ** i
dilation_w = 2**i
state_size = dilation_h * (self.kernel_h - 1)
queue = queues[i]
@ -206,7 +227,7 @@ class Flow(dg.Layer):
queue.append(fluid.layers.zeros_like(audio))
state = queue[0:state_size]
state = fluid.layers.concat([*state, audio], axis=2)
state = fluid.layers.concat(state + [audio], axis=2)
queue.pop(0)
queue.append(audio)
@ -214,10 +235,10 @@ class Flow(dg.Layer):
# Pad height dim (n_group): causal convolution
# Pad width dim (time): dialated non-causal convolution
pad_top, pad_bottom = 0, 0
pad_left = int((self.kernel_w-1) * dilation_w / 2)
pad_right = int((self.kernel_w-1) * dilation_w / 2)
state = fluid.layers.pad2d(state,
paddings=[pad_top, pad_bottom, pad_left, pad_right])
pad_left = int((self.kernel_w - 1) * dilation_w / 2)
pad_right = int((self.kernel_w - 1) * dilation_w / 2)
state = fluid.layers.pad2d(
state, paddings=[pad_top, pad_bottom, pad_left, pad_right])
hidden = self.in_layers[i](state)
cond_hidden = self.cond_layers[i](mel)
@ -241,18 +262,18 @@ class Flow(dg.Layer):
class WaveFlowModule(dg.Layer):
def __init__(self, name_scope, config):
super(WaveFlowModule, self).__init__(name_scope)
def __init__(self, config):
super(WaveFlowModule, self).__init__()
self.n_flows = config.n_flows
self.n_group = config.n_group
self.n_layers = config.n_layers
assert self.n_group % 2 == 0
assert self.n_flows % 2 == 0
self.conditioner = Conditioner(self.full_name())
self.conditioner = Conditioner()
self.flows = []
for i in range(self.n_flows):
flow = Flow(self.full_name(), config)
flow = Flow(config)
self.flows.append(flow)
self.add_sublayer("flow_{}".format(i), flow)
@ -284,7 +305,6 @@ class WaveFlowModule(dg.Layer):
audio = fluid.layers.transpose(unfold(audio, self.n_group), [0, 2, 1])
# [bs, 1, n_group, time/n_group]
audio = fluid.layers.unsqueeze(audio, 1)
log_s_list = []
for i in range(self.n_flows):
inputs = audio[:, :, :-1, :]
@ -305,7 +325,6 @@ class WaveFlowModule(dg.Layer):
mel = fluid.layers.stack(mel_slices, axis=2)
z = fluid.layers.squeeze(audio, [1])
return z, log_s_list
def synthesize(self, mel, sigma=1.0):
@ -331,7 +350,7 @@ class WaveFlowModule(dg.Layer):
for h in range(1, self.n_group):
inputs = audio_h
conds = mel[:, :, h:(h+1), :]
conds = mel[:, :, h:(h + 1), :]
outputs = self.flows[i].infer(inputs, conds, queues)
log_s = outputs[:, 0:1, :, :]

View File

@ -40,8 +40,8 @@ def norm_except(param, dim, power):
def compute_weight(v, g, dim, power):
assert len(g.shape) == 1, "magnitude should be a vector"
v_normalized = F.elementwise_div(v, (norm_except(v, dim, power) + 1e-12),
axis=dim)
v_normalized = F.elementwise_div(
v, (norm_except(v, dim, power) + 1e-12), axis=dim)
weight = F.elementwise_mul(v_normalized, g, axis=dim)
return weight
@ -63,20 +63,21 @@ class WeightNormWrapper(dg.Layer):
original_weight = getattr(layer, param_name)
self.add_parameter(
w_v,
self.create_parameter(shape=original_weight.shape,
dtype=original_weight.dtype))
self.create_parameter(
shape=original_weight.shape, dtype=original_weight.dtype))
F.assign(original_weight, getattr(self, w_v))
delattr(layer, param_name)
temp = norm_except(getattr(self, w_v), self.dim, self.power)
self.add_parameter(
w_g, self.create_parameter(shape=temp.shape, dtype=temp.dtype))
w_g, self.create_parameter(
shape=temp.shape, dtype=temp.dtype))
F.assign(temp, getattr(self, w_g))
# also set this when setting up
setattr(
self.layer, self.param_name,
compute_weight(getattr(self, w_v), getattr(self, w_g), self.dim,
self.power))
setattr(self.layer, self.param_name,
compute_weight(
getattr(self, w_v),
getattr(self, w_g), self.dim, self.power))
self.weigth_norm_applied = True
@ -84,10 +85,10 @@ class WeightNormWrapper(dg.Layer):
def hook(self):
w_v = self.param_name + "_v"
w_g = self.param_name + "_g"
setattr(
self.layer, self.param_name,
compute_weight(getattr(self, w_v), getattr(self, w_g), self.dim,
self.power))
setattr(self.layer, self.param_name,
compute_weight(
getattr(self, w_v),
getattr(self, w_g), self.dim, self.power))
def remove_weight_norm(self):
self.hook()