Merge branch 'master' of upstream.
This commit is contained in:
commit
02f742d914
|
@ -3,8 +3,8 @@ audio:
|
|||
n_fft: 2048
|
||||
sr: 22050
|
||||
preemphasis: 0.97
|
||||
hop_length: 275
|
||||
win_length: 1102
|
||||
hop_length: 256
|
||||
win_length: 1024
|
||||
power: 1.2
|
||||
min_level_db: -100
|
||||
ref_level_db: 20
|
|
@ -52,6 +52,12 @@ def add_config_options_to_parser(parser):
|
|||
type=int,
|
||||
default=0,
|
||||
help="use data parallel or not during training.")
|
||||
parser.add_argument(
|
||||
'--alpha',
|
||||
type=float,
|
||||
default=1.0,
|
||||
help="The hyperparameter to determine the length of the expanded sequence \
|
||||
mel, thereby controlling the voice speed.")
|
||||
|
||||
parser.add_argument(
|
||||
'--data_path',
|
||||
|
|
|
@ -24,6 +24,7 @@ import paddle.fluid.dygraph as dg
|
|||
from parakeet.g2p.en import text_to_sequence
|
||||
from parakeet import audio
|
||||
from parakeet.models.fastspeech.fastspeech import FastSpeech
|
||||
from parakeet.models.transformer_tts.utils import *
|
||||
|
||||
|
||||
def load_checkpoint(step, model_path):
|
||||
|
@ -59,12 +60,26 @@ def synthesis(text_input, args):
|
|||
model.eval()
|
||||
|
||||
text = np.asarray(text_to_sequence(text_input))
|
||||
text = fluid.layers.unsqueeze(dg.to_variable(text), [0])
|
||||
text = np.expand_dims(text, axis=0)
|
||||
pos_text = np.arange(1, text.shape[1] + 1)
|
||||
pos_text = fluid.layers.unsqueeze(dg.to_variable(pos_text), [0])
|
||||
pos_text = np.expand_dims(pos_text, axis=0)
|
||||
enc_non_pad_mask = get_non_pad_mask(pos_text).astype(np.float32)
|
||||
enc_slf_attn_mask = get_attn_key_pad_mask(pos_text,
|
||||
text).astype(np.float32)
|
||||
|
||||
text = dg.to_variable(text)
|
||||
pos_text = dg.to_variable(pos_text)
|
||||
enc_non_pad_mask = dg.to_variable(enc_non_pad_mask)
|
||||
enc_slf_attn_mask = dg.to_variable(enc_slf_attn_mask)
|
||||
|
||||
mel_output, mel_output_postnet = model(
|
||||
text, pos_text, alpha=args.alpha)
|
||||
text,
|
||||
pos_text,
|
||||
alpha=args.alpha,
|
||||
enc_non_pad_mask=enc_non_pad_mask,
|
||||
enc_slf_attn_mask=enc_slf_attn_mask,
|
||||
dec_non_pad_mask=None,
|
||||
dec_slf_attn_mask=None)
|
||||
|
||||
_ljspeech_processor = audio.AudioProcessor(
|
||||
sample_rate=cfg['audio']['sr'],
|
||||
|
|
|
@ -21,6 +21,7 @@ from parse import add_config_options_to_parser
|
|||
from pprint import pprint
|
||||
from ruamel import yaml
|
||||
from tqdm import tqdm
|
||||
from matplotlib import cm
|
||||
from collections import OrderedDict
|
||||
from tensorboardX import SummaryWriter
|
||||
import paddle.fluid.dygraph as dg
|
||||
|
@ -66,12 +67,12 @@ def main(args):
|
|||
|
||||
with dg.guard(place):
|
||||
with fluid.unique_name.guard():
|
||||
transformerTTS = TransformerTTS(cfg)
|
||||
transformer_tts = TransformerTTS(cfg)
|
||||
model_dict, _ = load_checkpoint(
|
||||
str(args.transformer_step),
|
||||
os.path.join(args.transtts_path, "transformer"))
|
||||
transformerTTS.set_dict(model_dict)
|
||||
transformerTTS.eval()
|
||||
transformer_tts.set_dict(model_dict)
|
||||
transformer_tts.eval()
|
||||
|
||||
model = FastSpeech(cfg)
|
||||
model.train()
|
||||
|
@ -100,13 +101,33 @@ def main(args):
|
|||
|
||||
for i, data in enumerate(pbar):
|
||||
pbar.set_description('Processing at epoch %d' % epoch)
|
||||
character, mel, mel_input, pos_text, pos_mel, text_length, mel_lens = data
|
||||
(character, mel, mel_input, pos_text, pos_mel, text_length,
|
||||
mel_lens, enc_slf_mask, enc_query_mask, dec_slf_mask,
|
||||
enc_dec_mask, dec_query_slf_mask, dec_query_mask) = data
|
||||
|
||||
_, _, attn_probs, _, _, _ = transformerTTS(
|
||||
character, mel_input, pos_text, pos_mel)
|
||||
alignment = dg.to_variable(
|
||||
get_alignment(attn_probs, mel_lens, cfg[
|
||||
'transformer_head'])).astype(np.float32)
|
||||
_, _, attn_probs, _, _, _ = transformer_tts(
|
||||
character,
|
||||
mel_input,
|
||||
pos_text,
|
||||
pos_mel,
|
||||
dec_slf_mask=dec_slf_mask,
|
||||
enc_slf_mask=enc_slf_mask,
|
||||
enc_query_mask=enc_query_mask,
|
||||
enc_dec_mask=enc_dec_mask,
|
||||
dec_query_slf_mask=dec_query_slf_mask,
|
||||
dec_query_mask=dec_query_mask)
|
||||
alignment, max_attn = get_alignment(attn_probs, mel_lens,
|
||||
cfg['transformer_head'])
|
||||
alignment = dg.to_variable(alignment).astype(np.float32)
|
||||
|
||||
if local_rank == 0 and global_step % 5 == 1:
|
||||
x = np.uint8(
|
||||
cm.viridis(max_attn[8, :mel_lens.numpy()[8]]) * 255)
|
||||
writer.add_image(
|
||||
'Attention_%d_0' % global_step,
|
||||
x,
|
||||
0,
|
||||
dataformats="HWC")
|
||||
|
||||
global_step += 1
|
||||
|
||||
|
@ -115,7 +136,11 @@ def main(args):
|
|||
character,
|
||||
pos_text,
|
||||
mel_pos=pos_mel,
|
||||
length_target=alignment)
|
||||
length_target=alignment,
|
||||
enc_non_pad_mask=enc_query_mask,
|
||||
enc_slf_attn_mask=enc_slf_mask,
|
||||
dec_non_pad_mask=dec_query_slf_mask,
|
||||
dec_slf_attn_mask=dec_slf_mask)
|
||||
mel_output, mel_output_postnet, duration_predictor_output, _, _ = result
|
||||
mel_loss = layers.mse_loss(mel_output, mel)
|
||||
mel_postnet_loss = layers.mse_loss(mel_output_postnet, mel)
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
# train model
|
||||
# if you wish to resume from an exists model, uncomment --checkpoint_path and --fastspeech_step
|
||||
CUDA_VISIBLE_DEVICES=0\
|
||||
export CUDA_VISIBLE_DEVICES=0
|
||||
python -u train.py \
|
||||
--batch_size=32 \
|
||||
--epochs=10000 \
|
||||
|
|
|
@ -8,4 +8,7 @@ audio:
|
|||
power: 1.2
|
||||
min_level_db: -100
|
||||
ref_level_db: 20
|
||||
outputs_per_step: 1
|
||||
outputs_per_step: 1
|
||||
|
||||
hidden_size: 256
|
||||
embedding_size: 512
|
|
@ -23,7 +23,8 @@ from parakeet import audio
|
|||
from parakeet.data.sampler import *
|
||||
from parakeet.data.datacargo import DataCargo
|
||||
from parakeet.data.batch import TextIDBatcher, SpecBatcher
|
||||
from parakeet.data.dataset import DatasetMixin, TransformDataset
|
||||
from parakeet.data.dataset import DatasetMixin, TransformDataset, CacheDataset
|
||||
from parakeet.models.transformer_tts.utils import *
|
||||
|
||||
|
||||
class LJSpeechLoader:
|
||||
|
@ -40,6 +41,8 @@ class LJSpeechLoader:
|
|||
metadata = LJSpeechMetaData(LJSPEECH_ROOT)
|
||||
transformer = LJSpeech(config)
|
||||
dataset = TransformDataset(metadata, transformer)
|
||||
dataset = CacheDataset(dataset)
|
||||
|
||||
sampler = DistributedSampler(
|
||||
len(metadata), nranks, rank, shuffle=shuffle)
|
||||
|
||||
|
@ -196,8 +199,18 @@ def batch_examples(batch):
|
|||
SpecBatcher(pad_value=0.)(mels), axes=(0, 2, 1)) #(B,T,num_mels)
|
||||
mel_inputs = np.transpose(
|
||||
SpecBatcher(pad_value=0.)(mel_inputs), axes=(0, 2, 1)) #(B,T,num_mels)
|
||||
enc_slf_mask = get_attn_key_pad_mask(pos_texts, texts).astype(np.float32)
|
||||
enc_query_mask = get_non_pad_mask(pos_texts).astype(np.float32)
|
||||
dec_slf_mask = get_dec_attn_key_pad_mask(pos_mels,
|
||||
mel_inputs).astype(np.float32)
|
||||
enc_dec_mask = get_attn_key_pad_mask(enc_query_mask[:, :, 0],
|
||||
mel_inputs).astype(np.float32)
|
||||
dec_query_slf_mask = get_non_pad_mask(pos_mels).astype(np.float32)
|
||||
dec_query_mask = get_non_pad_mask(pos_mels).astype(np.float32)
|
||||
|
||||
return (texts, mels, mel_inputs, pos_texts, pos_mels, np.array(text_lens),
|
||||
np.array(mel_lens))
|
||||
np.array(mel_lens), enc_slf_mask, enc_query_mask, dec_slf_mask,
|
||||
enc_dec_mask, dec_query_slf_mask, dec_query_mask)
|
||||
|
||||
|
||||
def batch_examples_vocoder(batch):
|
||||
|
|
|
@ -16,6 +16,7 @@ from scipy.io.wavfile import write
|
|||
from parakeet.g2p.en import text_to_sequence
|
||||
import numpy as np
|
||||
from tqdm import tqdm
|
||||
from matplotlib import cm
|
||||
from tensorboardX import SummaryWriter
|
||||
from ruamel import yaml
|
||||
import paddle.fluid as fluid
|
||||
|
@ -25,6 +26,7 @@ import argparse
|
|||
from parse import add_config_options_to_parser
|
||||
from pprint import pprint
|
||||
from collections import OrderedDict
|
||||
from parakeet.models.transformer_tts.utils import *
|
||||
from parakeet import audio
|
||||
from parakeet.models.transformer_tts.vocoder import Vocoder
|
||||
from parakeet.models.transformer_tts.transformer_tts import TransformerTTS
|
||||
|
@ -78,14 +80,18 @@ def synthesis(text_input, args):
|
|||
pos_text = fluid.layers.unsqueeze(dg.to_variable(pos_text), [0])
|
||||
|
||||
pbar = tqdm(range(args.max_len))
|
||||
|
||||
for i in pbar:
|
||||
dec_slf_mask = get_triu_tensor(
|
||||
mel_input.numpy(), mel_input.numpy()).astype(np.float32)
|
||||
dec_slf_mask = fluid.layers.cast(
|
||||
dg.to_variable(dec_slf_mask == 0), np.float32)
|
||||
pos_mel = np.arange(1, mel_input.shape[1] + 1)
|
||||
pos_mel = fluid.layers.unsqueeze(dg.to_variable(pos_mel), [0])
|
||||
mel_pred, postnet_pred, attn_probs, stop_preds, attn_enc, attn_dec = model(
|
||||
text, mel_input, pos_text, pos_mel)
|
||||
text, mel_input, pos_text, pos_mel, dec_slf_mask)
|
||||
mel_input = fluid.layers.concat(
|
||||
[mel_input, postnet_pred[:, -1:, :]], axis=1)
|
||||
|
||||
mag_pred = model_vocoder(postnet_pred)
|
||||
|
||||
_ljspeech_processor = audio.AudioProcessor(
|
||||
|
@ -111,6 +117,33 @@ def synthesis(text_input, args):
|
|||
wav = _ljspeech_processor.inv_spectrogram(
|
||||
fluid.layers.transpose(
|
||||
fluid.layers.squeeze(mag_pred, [0]), [1, 0]).numpy())
|
||||
global_step = 0
|
||||
for i, prob in enumerate(attn_probs):
|
||||
for j in range(4):
|
||||
x = np.uint8(cm.viridis(prob.numpy()[j]) * 255)
|
||||
writer.add_image(
|
||||
'Attention_%d_0' % global_step,
|
||||
x,
|
||||
i * 4 + j,
|
||||
dataformats="HWC")
|
||||
|
||||
for i, prob in enumerate(attn_enc):
|
||||
for j in range(4):
|
||||
x = np.uint8(cm.viridis(prob.numpy()[j]) * 255)
|
||||
writer.add_image(
|
||||
'Attention_enc_%d_0' % global_step,
|
||||
x,
|
||||
i * 4 + j,
|
||||
dataformats="HWC")
|
||||
|
||||
for i, prob in enumerate(attn_dec):
|
||||
for j in range(4):
|
||||
x = np.uint8(cm.viridis(prob.numpy()[j]) * 255)
|
||||
writer.add_image(
|
||||
'Attention_dec_%d_0' % global_step,
|
||||
x,
|
||||
i * 4 + j,
|
||||
dataformats="HWC")
|
||||
writer.add_audio(text_input, wav, 0, cfg['audio']['sr'])
|
||||
if not os.path.exists(args.sample_path):
|
||||
os.mkdir(args.sample_path)
|
||||
|
@ -124,4 +157,6 @@ if __name__ == '__main__':
|
|||
parser = argparse.ArgumentParser(description="Synthesis model")
|
||||
add_config_options_to_parser(parser)
|
||||
args = parser.parse_args()
|
||||
synthesis("Transformer model is so fast!", args)
|
||||
synthesis(
|
||||
"They emphasized the necessity that the information now being furnished be handled with judgment and care.",
|
||||
args)
|
||||
|
|
|
@ -2,10 +2,10 @@
|
|||
# train model
|
||||
CUDA_VISIBLE_DEVICES=0 \
|
||||
python -u synthesis.py \
|
||||
--max_len=50 \
|
||||
--max_len=600 \
|
||||
--transformer_step=160000 \
|
||||
--vocoder_step=70000 \
|
||||
--use_gpu=1
|
||||
--vocoder_step=90000 \
|
||||
--use_gpu=1 \
|
||||
--checkpoint_path='./checkpoint' \
|
||||
--log_dir='./log' \
|
||||
--sample_path='./sample' \
|
||||
|
|
|
@ -14,7 +14,7 @@
|
|||
import os
|
||||
from tqdm import tqdm
|
||||
from tensorboardX import SummaryWriter
|
||||
from pathlib import Path
|
||||
#from pathlib import Path
|
||||
from collections import OrderedDict
|
||||
import argparse
|
||||
from parse import add_config_options_to_parser
|
||||
|
@ -89,21 +89,31 @@ def main(args):
|
|||
pbar = tqdm(reader)
|
||||
for i, data in enumerate(pbar):
|
||||
pbar.set_description('Processing at epoch %d' % epoch)
|
||||
character, mel, mel_input, pos_text, pos_mel, text_length, _ = data
|
||||
character, mel, mel_input, pos_text, pos_mel, text_length, _, enc_slf_mask, enc_query_mask, dec_slf_mask, enc_dec_mask, dec_query_slf_mask, dec_query_mask = data
|
||||
|
||||
global_step += 1
|
||||
mel_pred, postnet_pred, attn_probs, stop_preds, attn_enc, attn_dec = model(
|
||||
character, mel_input, pos_text, pos_mel)
|
||||
|
||||
label = (pos_mel == 0).astype(np.float32)
|
||||
mel_pred, postnet_pred, attn_probs, stop_preds, attn_enc, attn_dec = model(
|
||||
character,
|
||||
mel_input,
|
||||
pos_text,
|
||||
pos_mel,
|
||||
dec_slf_mask=dec_slf_mask,
|
||||
enc_slf_mask=enc_slf_mask,
|
||||
enc_query_mask=enc_query_mask,
|
||||
enc_dec_mask=enc_dec_mask,
|
||||
dec_query_slf_mask=dec_query_slf_mask,
|
||||
dec_query_mask=dec_query_mask)
|
||||
|
||||
mel_loss = layers.mean(
|
||||
layers.abs(layers.elementwise_sub(mel_pred, mel)))
|
||||
post_mel_loss = layers.mean(
|
||||
layers.abs(layers.elementwise_sub(postnet_pred, mel)))
|
||||
loss = mel_loss + post_mel_loss
|
||||
|
||||
# Note: When used stop token loss the learning did not work.
|
||||
if args.stop_token:
|
||||
label = (pos_mel == 0).astype(np.float32)
|
||||
stop_loss = cross_entropy(stop_preds, label)
|
||||
loss = loss + stop_loss
|
||||
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
|
||||
# train model
|
||||
# if you wish to resume from an exists model, uncomment --checkpoint_path and --transformer_step
|
||||
CUDA_VISIBLE_DEVICES=0 \
|
||||
export CUDA_VISIBLE_DEVICES=2
|
||||
python -u train_transformer.py \
|
||||
--batch_size=32 \
|
||||
--epochs=10000 \
|
||||
|
|
|
@ -4,7 +4,7 @@ PaddlePaddle dynamic graph implementation of [WaveFlow: A Compact Flow-based Mod
|
|||
|
||||
- WaveFlow can synthesize 22.05 kHz high-fidelity speech around 40x faster than real-time on a Nvidia V100 GPU without engineered inference kernels, which is faster than [WaveGlow] (https://github.com/NVIDIA/waveglow) and serveral orders of magnitude faster than WaveNet.
|
||||
- WaveFlow is a small-footprint flow-based model for raw audio. It has only 5.9M parameters, which is 15x smalller than WaveGlow (87.9M) and comparable to WaveNet (4.6M).
|
||||
- WaveFlow is directly trained with maximum likelihood without probability density distillation and auxiliary losses as used in Parallel WaveNet and ClariNet, which simplifies the training pipeline and reduces the cost of development.
|
||||
- WaveFlow is directly trained with maximum likelihood without probability density distillation and auxiliary losses as used in Parallel WaveNet and ClariNet, which simplifies the training pipeline and reduces the cost of development.
|
||||
|
||||
## Project Structure
|
||||
```text
|
||||
|
@ -99,7 +99,7 @@ python -u synthesis.py \
|
|||
--sigma=1.0
|
||||
```
|
||||
|
||||
In this example, `--output` specifies where to save the synthesized audios and `--sample` specifies which sample in the valid dataset (a split from the whole LJSpeech dataset, by default contains the first 16 audio samples) to synthesize based on the mel-spectrograms computed from the ground truth sample audio, e.g., `--sample=0` means to synthesize the first audio in the valid dataset.
|
||||
In this example, `--output` specifies where to save the synthesized audios and `--sample` (<16) specifies which sample in the valid dataset (a split from the whole LJSpeech dataset, by default contains the first 16 audio samples) to synthesize based on the mel-spectrograms computed from the ground truth sample audio, e.g., `--sample=0` means to synthesize the first audio in the valid dataset.
|
||||
|
||||
### Benchmarking
|
||||
|
||||
|
|
|
@ -109,6 +109,16 @@ def add_yaml_config(config):
|
|||
|
||||
|
||||
def load_latest_checkpoint(checkpoint_dir, rank=0):
|
||||
"""Get the iteration number corresponding to the latest saved checkpoint
|
||||
|
||||
Args:
|
||||
checkpoint_dir (str): the directory where checkpoint is saved.
|
||||
rank (int, optional): the rank of the process in multi-process setting.
|
||||
Defaults to 0.
|
||||
|
||||
Returns:
|
||||
int: the latest iteration number.
|
||||
"""
|
||||
checkpoint_path = os.path.join(checkpoint_dir, "checkpoint")
|
||||
# Create checkpoint index file if not exist.
|
||||
if (not os.path.isfile(checkpoint_path)) and rank == 0:
|
||||
|
@ -129,6 +139,15 @@ def load_latest_checkpoint(checkpoint_dir, rank=0):
|
|||
|
||||
|
||||
def save_latest_checkpoint(checkpoint_dir, iteration):
|
||||
"""Save the iteration number of the latest model to be checkpointed.
|
||||
|
||||
Args:
|
||||
checkpoint_dir (str): the directory where checkpoint is saved.
|
||||
iteration (int): the latest iteration number.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
checkpoint_path = os.path.join(checkpoint_dir, "checkpoint")
|
||||
# Update the latest checkpoint index.
|
||||
with open(checkpoint_path, "w") as handle:
|
||||
|
@ -142,6 +161,24 @@ def load_parameters(checkpoint_dir,
|
|||
iteration=None,
|
||||
file_path=None,
|
||||
dtype="float32"):
|
||||
"""Load a specific model checkpoint from disk.
|
||||
|
||||
Args:
|
||||
checkpoint_dir (str): the directory where checkpoint is saved.
|
||||
rank (int): the rank of the process in multi-process setting.
|
||||
model (obj): model to load parameters.
|
||||
optimizer (obj, optional): optimizer to load states if needed.
|
||||
Defaults to None.
|
||||
iteration (int, optional): if specified, load the specific checkpoint,
|
||||
if not specified, load the latest one. Defaults to None.
|
||||
file_path (str, optional): if specified, load the checkpoint
|
||||
stored in the file_path. Defaults to None.
|
||||
dtype (str, optional): precision of the model parameters.
|
||||
Defaults to float32.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
if file_path is None:
|
||||
if iteration is None:
|
||||
iteration = load_latest_checkpoint(checkpoint_dir, rank)
|
||||
|
@ -165,6 +202,18 @@ def load_parameters(checkpoint_dir,
|
|||
|
||||
|
||||
def save_latest_parameters(checkpoint_dir, iteration, model, optimizer=None):
|
||||
"""Checkpoint the latest trained model parameters.
|
||||
|
||||
Args:
|
||||
checkpoint_dir (str): the directory where checkpoint is saved.
|
||||
iteration (int): the latest iteration number.
|
||||
model (obj): model to be checkpointed.
|
||||
optimizer (obj, optional): optimizer to be checkpointed.
|
||||
Defaults to None.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
file_path = "{}/step-{}".format(checkpoint_dir, iteration)
|
||||
model_dict = model.state_dict()
|
||||
dg.save_dygraph(model_dict, file_path)
|
||||
|
|
|
@ -14,6 +14,7 @@
|
|||
|
||||
import six
|
||||
import numpy as np
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
class DatasetMixin(object):
|
||||
|
|
|
@ -32,6 +32,7 @@ class Decoder(dg.Layer):
|
|||
super(Decoder, self).__init__()
|
||||
|
||||
n_position = len_max_seq + 1
|
||||
self.n_head = n_head
|
||||
self.pos_inp = get_sinusoid_encoding_table(
|
||||
n_position, d_model, padding_idx=0)
|
||||
self.position_enc = dg.Embedding(
|
||||
|
@ -55,7 +56,7 @@ class Decoder(dg.Layer):
|
|||
for i, layer in enumerate(self.layer_stack):
|
||||
self.add_sublayer('fft_{}'.format(i), layer)
|
||||
|
||||
def forward(self, enc_seq, enc_pos):
|
||||
def forward(self, enc_seq, enc_pos, non_pad_mask, slf_attn_mask=None):
|
||||
"""
|
||||
Decoder layer of FastSpeech.
|
||||
|
||||
|
@ -69,10 +70,7 @@ class Decoder(dg.Layer):
|
|||
dec_slf_attn_list (Variable), Shape(B, mel_T, mel_T), the decoder self attention list.
|
||||
"""
|
||||
dec_slf_attn_list = []
|
||||
|
||||
# -- Prepare masks
|
||||
slf_attn_mask = get_attn_key_pad_mask(seq_k=enc_pos, seq_q=enc_pos)
|
||||
non_pad_mask = get_non_pad_mask(enc_pos)
|
||||
slf_attn_mask = layers.expand(slf_attn_mask, [self.n_head, 1, 1])
|
||||
|
||||
# -- Forward
|
||||
dec_output = enc_seq + self.position_enc(enc_pos)
|
||||
|
|
|
@ -32,14 +32,17 @@ class Encoder(dg.Layer):
|
|||
dropout=0.1):
|
||||
super(Encoder, self).__init__()
|
||||
n_position = len_max_seq + 1
|
||||
self.n_head = n_head
|
||||
|
||||
self.src_word_emb = dg.Embedding(
|
||||
size=[n_src_vocab, d_model], padding_idx=0)
|
||||
size=[n_src_vocab, d_model],
|
||||
padding_idx=0,
|
||||
param_attr=fluid.initializer.Normal(
|
||||
loc=0.0, scale=1.0))
|
||||
self.pos_inp = get_sinusoid_encoding_table(
|
||||
n_position, d_model, padding_idx=0)
|
||||
self.position_enc = dg.Embedding(
|
||||
size=[n_position, d_model],
|
||||
padding_idx=0,
|
||||
param_attr=fluid.ParamAttr(
|
||||
initializer=fluid.initializer.NumpyArrayInitializer(
|
||||
self.pos_inp),
|
||||
|
@ -58,7 +61,7 @@ class Encoder(dg.Layer):
|
|||
for i, layer in enumerate(self.layer_stack):
|
||||
self.add_sublayer('fft_{}'.format(i), layer)
|
||||
|
||||
def forward(self, character, text_pos):
|
||||
def forward(self, character, text_pos, non_pad_mask, slf_attn_mask=None):
|
||||
"""
|
||||
Encoder layer of FastSpeech.
|
||||
|
||||
|
@ -74,10 +77,7 @@ class Encoder(dg.Layer):
|
|||
enc_slf_attn_list (list<Variable>), Len(n_layers), Shape(B * n_head, text_T, text_T), the encoder self attention list.
|
||||
"""
|
||||
enc_slf_attn_list = []
|
||||
# -- prepare masks
|
||||
# shape character (N, T)
|
||||
slf_attn_mask = get_attn_key_pad_mask(seq_k=character, seq_q=character)
|
||||
non_pad_mask = get_non_pad_mask(character)
|
||||
slf_attn_mask = layers.expand(slf_attn_mask, [self.n_head, 1, 1])
|
||||
|
||||
# -- Forward
|
||||
enc_output = self.src_word_emb(character) + self.position_enc(
|
||||
|
@ -90,4 +90,4 @@ class Encoder(dg.Layer):
|
|||
slf_attn_mask=slf_attn_mask)
|
||||
enc_slf_attn_list += [enc_slf_attn]
|
||||
|
||||
return enc_output, non_pad_mask, enc_slf_attn_list
|
||||
return enc_output, enc_slf_attn_list
|
||||
|
|
|
@ -12,9 +12,11 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import math
|
||||
import numpy as np
|
||||
import paddle.fluid.dygraph as dg
|
||||
import paddle.fluid as fluid
|
||||
from parakeet.g2p.text.symbols import symbols
|
||||
from parakeet.models.transformer_tts.utils import *
|
||||
from parakeet.models.transformer_tts.post_convnet import PostConvNet
|
||||
from parakeet.models.fastspeech.length_regulator import LengthRegulator
|
||||
from parakeet.models.fastspeech.encoder import Encoder
|
||||
|
@ -78,6 +80,10 @@ class FastSpeech(dg.Layer):
|
|||
def forward(self,
|
||||
character,
|
||||
text_pos,
|
||||
enc_non_pad_mask,
|
||||
dec_non_pad_mask,
|
||||
enc_slf_attn_mask=None,
|
||||
dec_slf_attn_mask=None,
|
||||
mel_pos=None,
|
||||
length_target=None,
|
||||
alpha=1.0):
|
||||
|
@ -106,14 +112,20 @@ class FastSpeech(dg.Layer):
|
|||
dec_slf_attn_list (Variable), Shape(B, mel_T, mel_T), the decoder self attention list.
|
||||
"""
|
||||
|
||||
encoder_output, non_pad_mask, enc_slf_attn_list = self.encoder(
|
||||
character, text_pos)
|
||||
encoder_output, enc_slf_attn_list = self.encoder(
|
||||
character,
|
||||
text_pos,
|
||||
enc_non_pad_mask,
|
||||
slf_attn_mask=enc_slf_attn_mask)
|
||||
if fluid.framework._dygraph_tracer()._train_mode:
|
||||
|
||||
length_regulator_output, duration_predictor_output = self.length_regulator(
|
||||
encoder_output, target=length_target, alpha=alpha)
|
||||
decoder_output, dec_slf_attn_list = self.decoder(
|
||||
length_regulator_output, mel_pos)
|
||||
length_regulator_output,
|
||||
mel_pos,
|
||||
dec_non_pad_mask,
|
||||
slf_attn_mask=dec_slf_attn_mask)
|
||||
|
||||
mel_output = self.mel_linear(decoder_output)
|
||||
mel_output_postnet = self.postnet(mel_output) + mel_output
|
||||
|
@ -122,8 +134,18 @@ class FastSpeech(dg.Layer):
|
|||
else:
|
||||
length_regulator_output, decoder_pos = self.length_regulator(
|
||||
encoder_output, alpha=alpha)
|
||||
decoder_output, _ = self.decoder(length_regulator_output,
|
||||
decoder_pos)
|
||||
slf_attn_mask = get_triu_tensor(
|
||||
decoder_pos.numpy(), decoder_pos.numpy()).astype(np.float32)
|
||||
slf_attn_mask = fluid.layers.cast(
|
||||
dg.to_variable(slf_attn_mask == 0), np.float32)
|
||||
slf_attn_mask = dg.to_variable(slf_attn_mask)
|
||||
dec_non_pad_mask = fluid.layers.unsqueeze(
|
||||
(decoder_pos != 0).astype(np.float32), [-1])
|
||||
decoder_output, _ = self.decoder(
|
||||
length_regulator_output,
|
||||
decoder_pos,
|
||||
dec_non_pad_mask,
|
||||
slf_attn_mask=slf_attn_mask)
|
||||
mel_output = self.mel_linear(decoder_output)
|
||||
mel_output_postnet = self.postnet(mel_output) + mel_output
|
||||
|
||||
|
|
|
@ -46,7 +46,7 @@ class FFTBlock(dg.Layer):
|
|||
padding=padding,
|
||||
dropout=dropout)
|
||||
|
||||
def forward(self, enc_input, non_pad_mask=None, slf_attn_mask=None):
|
||||
def forward(self, enc_input, non_pad_mask, slf_attn_mask=None):
|
||||
"""
|
||||
Feed Forward Transformer block in FastSpeech.
|
||||
|
||||
|
@ -63,6 +63,7 @@ class FFTBlock(dg.Layer):
|
|||
"""
|
||||
output, slf_attn = self.slf_attn(
|
||||
enc_input, enc_input, enc_input, mask=slf_attn_mask)
|
||||
|
||||
output *= non_pad_mask
|
||||
|
||||
output = self.pos_ffn(output)
|
||||
|
|
|
@ -146,11 +146,17 @@ class DurationPredictor(dg.Layer):
|
|||
out = layers.transpose(encoder_output, [0, 2, 1])
|
||||
out = self.conv1(out)
|
||||
out = layers.transpose(out, [0, 2, 1])
|
||||
out = layers.dropout(layers.relu(self.layer_norm1(out)), self.dropout)
|
||||
out = layers.dropout(
|
||||
layers.relu(self.layer_norm1(out)),
|
||||
self.dropout,
|
||||
dropout_implementation='upscale_in_train')
|
||||
out = layers.transpose(out, [0, 2, 1])
|
||||
out = self.conv2(out)
|
||||
out = layers.transpose(out, [0, 2, 1])
|
||||
out = layers.dropout(layers.relu(self.layer_norm2(out)), self.dropout)
|
||||
out = layers.dropout(
|
||||
layers.relu(self.layer_norm2(out)),
|
||||
self.dropout,
|
||||
dropout_implementation='upscale_in_train')
|
||||
out = layers.relu(self.linear(out))
|
||||
out = layers.squeeze(out, axes=[-1])
|
||||
|
||||
|
|
|
@ -18,7 +18,6 @@ def get_alignment(attn_probs, mel_lens, n_head):
|
|||
max_F = 0
|
||||
assert attn_probs[0].shape[0] % n_head == 0
|
||||
batch_size = int(attn_probs[0].shape[0] // n_head)
|
||||
#max_attn = attn_probs[0].numpy()[0,batch_size]
|
||||
for i in range(len(attn_probs)):
|
||||
multi_attn = attn_probs[i].numpy()
|
||||
for j in range(n_head):
|
||||
|
@ -28,7 +27,7 @@ def get_alignment(attn_probs, mel_lens, n_head):
|
|||
max_F = F
|
||||
max_attn = attn
|
||||
alignment = compute_duration(max_attn, mel_lens)
|
||||
return alignment
|
||||
return alignment, max_attn
|
||||
|
||||
|
||||
def score_F(attn):
|
||||
|
|
|
@ -14,7 +14,7 @@
|
|||
import math
|
||||
import paddle.fluid.dygraph as dg
|
||||
import paddle.fluid as fluid
|
||||
from parakeet.modules.utils import *
|
||||
from parakeet.models.transformer_tts.utils import *
|
||||
from parakeet.modules.multihead_attention import MultiheadAttention
|
||||
from parakeet.modules.ffn import PositionwiseFeedForward
|
||||
from parakeet.models.transformer_tts.prenet import PreNet
|
||||
|
@ -25,6 +25,7 @@ class Decoder(dg.Layer):
|
|||
def __init__(self, num_hidden, config, num_head=4):
|
||||
super(Decoder, self).__init__()
|
||||
self.num_hidden = num_hidden
|
||||
self.num_head = num_head
|
||||
param = fluid.ParamAttr()
|
||||
self.alpha = self.create_parameter(
|
||||
shape=(1, ),
|
||||
|
@ -98,30 +99,29 @@ class Decoder(dg.Layer):
|
|||
outputs_per_step=config['audio']['outputs_per_step'],
|
||||
use_cudnn=True)
|
||||
|
||||
def forward(self, key, value, query, c_mask, positional):
|
||||
def forward(self,
|
||||
key,
|
||||
value,
|
||||
query,
|
||||
positional,
|
||||
mask,
|
||||
m_mask=None,
|
||||
m_self_mask=None,
|
||||
zero_mask=None):
|
||||
|
||||
# get decoder mask with triangular matrix
|
||||
|
||||
if fluid.framework._dygraph_tracer()._train_mode:
|
||||
m_mask = get_non_pad_mask(positional)
|
||||
mask = get_attn_key_pad_mask((positional == 0).astype(np.float32),
|
||||
query)
|
||||
triu_tensor = dg.to_variable(
|
||||
get_triu_tensor(query.numpy(), query.numpy())).astype(
|
||||
np.float32)
|
||||
mask = mask + triu_tensor
|
||||
mask = fluid.layers.cast(mask == 0, np.float32)
|
||||
m_mask = layers.expand(m_mask, [self.num_head, 1, key.shape[1]])
|
||||
m_self_mask = layers.expand(m_self_mask,
|
||||
[self.num_head, 1, query.shape[1]])
|
||||
mask = layers.expand(mask, [self.num_head, 1, 1])
|
||||
zero_mask = layers.expand(zero_mask, [self.num_head, 1, 1])
|
||||
|
||||
# (batch_size, decoder_len, encoder_len)
|
||||
zero_mask = get_attn_key_pad_mask(
|
||||
layers.squeeze(c_mask, [-1]), query)
|
||||
else:
|
||||
mask = get_triu_tensor(query.numpy(),
|
||||
query.numpy()).astype(np.float32)
|
||||
mask = fluid.layers.cast(dg.to_variable(mask == 0), np.float32)
|
||||
m_mask, zero_mask = None, None
|
||||
m_mask, m_self_mask, zero_mask = None, None, None
|
||||
|
||||
# Decoder pre-network
|
||||
# Decoder pre-network
|
||||
query = self.decoder_prenet(query)
|
||||
|
||||
# Centered position
|
||||
|
@ -132,7 +132,8 @@ class Decoder(dg.Layer):
|
|||
query = positional * self.alpha + query
|
||||
|
||||
#positional dropout
|
||||
query = fluid.layers.dropout(query, 0.1)
|
||||
query = fluid.layers.dropout(
|
||||
query, 0.1, dropout_implementation='upscale_in_train')
|
||||
|
||||
# Attention decoder-decoder, encoder-decoder
|
||||
selfattn_list = list()
|
||||
|
@ -141,12 +142,13 @@ class Decoder(dg.Layer):
|
|||
for selfattn, attn, ffn in zip(self.selfattn_layers, self.attn_layers,
|
||||
self.ffns):
|
||||
query, attn_dec = selfattn(
|
||||
query, query, query, mask=mask, query_mask=m_mask)
|
||||
query, query, query, mask=mask, query_mask=m_self_mask)
|
||||
query, attn_dot = attn(
|
||||
key, value, query, mask=zero_mask, query_mask=m_mask)
|
||||
query = ffn(query)
|
||||
selfattn_list.append(attn_dec)
|
||||
attn_list.append(attn_dot)
|
||||
|
||||
# Mel linear projection
|
||||
mel_out = self.mel_linear(query)
|
||||
# Post Mel Network
|
||||
|
|
|
@ -23,6 +23,7 @@ class Encoder(dg.Layer):
|
|||
def __init__(self, embedding_size, num_hidden, num_head=4):
|
||||
super(Encoder, self).__init__()
|
||||
self.num_hidden = num_hidden
|
||||
self.num_head = num_head
|
||||
param = fluid.ParamAttr(initializer=fluid.initializer.Constant(
|
||||
value=1.0))
|
||||
self.alpha = self.create_parameter(
|
||||
|
@ -31,7 +32,6 @@ class Encoder(dg.Layer):
|
|||
1024, self.num_hidden, padding_idx=0)
|
||||
self.pos_emb = dg.Embedding(
|
||||
size=[1024, num_hidden],
|
||||
padding_idx=0,
|
||||
param_attr=fluid.ParamAttr(
|
||||
initializer=fluid.initializer.NumpyArrayInitializer(
|
||||
self.pos_inp),
|
||||
|
@ -56,13 +56,15 @@ class Encoder(dg.Layer):
|
|||
for i, layer in enumerate(self.ffns):
|
||||
self.add_sublayer("ffns_{}".format(i), layer)
|
||||
|
||||
def forward(self, x, positional):
|
||||
def forward(self, x, positional, mask=None, query_mask=None):
|
||||
|
||||
if fluid.framework._dygraph_tracer()._train_mode:
|
||||
query_mask = get_non_pad_mask(positional)
|
||||
mask = get_attn_key_pad_mask(positional, x)
|
||||
seq_len_key = x.shape[1]
|
||||
query_mask = layers.expand(query_mask,
|
||||
[self.num_head, 1, seq_len_key])
|
||||
mask = layers.expand(mask, [self.num_head, 1, 1])
|
||||
else:
|
||||
query_mask, mask = None, None
|
||||
|
||||
# Encoder pre_network
|
||||
x = self.encoder_prenet(x) #(N,T,C)
|
||||
|
||||
|
@ -72,7 +74,7 @@ class Encoder(dg.Layer):
|
|||
x = positional * self.alpha + x #(N, T, C)
|
||||
|
||||
# Positional dropout
|
||||
x = layers.dropout(x, 0.1)
|
||||
x = layers.dropout(x, 0.1, dropout_implementation='upscale_in_train')
|
||||
|
||||
# Self attention encoder
|
||||
attentions = list()
|
||||
|
@ -81,4 +83,4 @@ class Encoder(dg.Layer):
|
|||
x = ffn(x)
|
||||
attentions.append(attention)
|
||||
|
||||
return x, query_mask, attentions
|
||||
return x, attentions
|
||||
|
|
|
@ -27,7 +27,10 @@ class EncoderPrenet(dg.Layer):
|
|||
self.num_hidden = num_hidden
|
||||
self.use_cudnn = use_cudnn
|
||||
self.embedding = dg.Embedding(
|
||||
size=[len(symbols), embedding_size], padding_idx=None)
|
||||
size=[len(symbols), embedding_size],
|
||||
padding_idx=0,
|
||||
param_attr=fluid.initializer.Normal(
|
||||
loc=0.0, scale=1.0))
|
||||
self.conv_list = []
|
||||
k = math.sqrt(1 / embedding_size)
|
||||
self.conv_list.append(
|
||||
|
@ -78,10 +81,14 @@ class EncoderPrenet(dg.Layer):
|
|||
low=-k, high=k)))
|
||||
|
||||
def forward(self, x):
|
||||
|
||||
x = self.embedding(x) #(batch_size, seq_len, embending_size)
|
||||
x = layers.transpose(x, [0, 2, 1])
|
||||
for batch_norm, conv in zip(self.batch_norm_list, self.conv_list):
|
||||
x = layers.dropout(layers.relu(batch_norm(conv(x))), 0.2)
|
||||
x = layers.dropout(
|
||||
layers.relu(batch_norm(conv(x))),
|
||||
0.2,
|
||||
dropout_implementation='upscale_in_train')
|
||||
x = layers.transpose(x, [0, 2, 1]) #(N,T,C)
|
||||
x = self.projection(x)
|
||||
|
||||
|
|
|
@ -108,11 +108,16 @@ class PostConvNet(dg.Layer):
|
|||
conv = self.conv_list[i]
|
||||
|
||||
input = layers.dropout(
|
||||
layers.tanh(batch_norm(conv(input)[:, :, :len])), self.dropout)
|
||||
layers.tanh(batch_norm(conv(input)[:, :, :len])),
|
||||
self.dropout,
|
||||
dropout_implementation='upscale_in_train')
|
||||
conv = self.conv_list[self.num_conv - 1]
|
||||
input = conv(input)[:, :, :len]
|
||||
if self.batchnorm_last:
|
||||
batch_norm = self.batch_norm_list[self.num_conv - 1]
|
||||
input = layers.dropout(batch_norm(input), self.dropout)
|
||||
input = layers.dropout(
|
||||
batch_norm(input),
|
||||
self.dropout,
|
||||
dropout_implementation='upscale_in_train')
|
||||
output = layers.transpose(input, [0, 2, 1])
|
||||
return output
|
||||
|
|
|
@ -56,6 +56,12 @@ class PreNet(dg.Layer):
|
|||
Returns:
|
||||
x (Variable), Shape(B, T, C), the result after pernet.
|
||||
"""
|
||||
x = layers.dropout(layers.relu(self.linear1(x)), self.dropout_rate)
|
||||
x = layers.dropout(layers.relu(self.linear2(x)), self.dropout_rate)
|
||||
x = layers.dropout(
|
||||
layers.relu(self.linear1(x)),
|
||||
self.dropout_rate,
|
||||
dropout_implementation='upscale_in_train')
|
||||
x = layers.dropout(
|
||||
layers.relu(self.linear2(x)),
|
||||
self.dropout_rate,
|
||||
dropout_implementation='upscale_in_train')
|
||||
return x
|
||||
|
|
|
@ -24,11 +24,29 @@ class TransformerTTS(dg.Layer):
|
|||
self.decoder = Decoder(config['hidden_size'], config)
|
||||
self.config = config
|
||||
|
||||
def forward(self, characters, mel_input, pos_text, pos_mel):
|
||||
|
||||
key, c_mask, attns_enc = self.encoder(characters, pos_text)
|
||||
def forward(self,
|
||||
characters,
|
||||
mel_input,
|
||||
pos_text,
|
||||
pos_mel,
|
||||
dec_slf_mask,
|
||||
enc_slf_mask=None,
|
||||
enc_query_mask=None,
|
||||
enc_dec_mask=None,
|
||||
dec_query_slf_mask=None,
|
||||
dec_query_mask=None):
|
||||
key, attns_enc = self.encoder(
|
||||
characters, pos_text, mask=enc_slf_mask, query_mask=enc_query_mask)
|
||||
|
||||
mel_output, postnet_output, attn_probs, stop_preds, attns_dec = self.decoder(
|
||||
key, key, mel_input, c_mask, pos_mel)
|
||||
key,
|
||||
key,
|
||||
mel_input,
|
||||
pos_mel,
|
||||
mask=dec_slf_mask,
|
||||
zero_mask=enc_dec_mask,
|
||||
m_self_mask=dec_query_slf_mask,
|
||||
m_mask=dec_query_mask)
|
||||
return mel_output, postnet_output, attn_probs, stop_preds, attns_enc, attns_dec
|
||||
|
||||
return mel_output, postnet_output, attn_probs, stop_preds, attns_enc, attns_dec
|
||||
|
|
|
@ -51,7 +51,9 @@ def get_sinusoid_encoding_table(n_position, d_hid, padding_idx=None):
|
|||
|
||||
|
||||
def get_non_pad_mask(seq):
|
||||
return layers.unsqueeze((seq != 0).astype(np.float32), [-1])
|
||||
mask = (seq != 0).astype(np.float32)
|
||||
mask = np.expand_dims(mask, axis=-1)
|
||||
return mask
|
||||
|
||||
|
||||
def get_attn_key_pad_mask(seq_k, seq_q):
|
||||
|
@ -60,8 +62,22 @@ def get_attn_key_pad_mask(seq_k, seq_q):
|
|||
# Expand to fit the shape of key query attention matrix.
|
||||
len_q = seq_q.shape[1]
|
||||
padding_mask = (seq_k != 0).astype(np.float32)
|
||||
padding_mask = layers.expand(
|
||||
layers.unsqueeze(padding_mask, [1]), [1, len_q, 1])
|
||||
padding_mask = np.expand_dims(padding_mask, axis=1)
|
||||
padding_mask = padding_mask.repeat([len_q], axis=1)
|
||||
padding_mask = (padding_mask == 0).astype(np.float32) * (-2**32 + 1)
|
||||
return padding_mask
|
||||
|
||||
|
||||
def get_dec_attn_key_pad_mask(seq_k, seq_q):
|
||||
''' For masking out the padding part of key sequence. '''
|
||||
|
||||
# Expand to fit the shape of key query attention matrix.
|
||||
len_q = seq_q.shape[1]
|
||||
padding_mask = (seq_k == 0).astype(np.float32)
|
||||
padding_mask = np.expand_dims(padding_mask, axis=1)
|
||||
triu_tensor = get_triu_tensor(seq_q, seq_q)
|
||||
padding_mask = padding_mask.repeat([len_q], axis=1) + triu_tensor
|
||||
padding_mask = (padding_mask != 0).astype(np.float32) * (-2**32 + 1)
|
||||
return padding_mask
|
||||
|
||||
|
||||
|
|
|
@ -80,6 +80,7 @@ class Subset(DatasetMixin):
|
|||
# whole audio for valid set
|
||||
pass
|
||||
else:
|
||||
# Randomly crop segment_length from audios in the training set.
|
||||
# audio shape: [len]
|
||||
if audio.shape[0] >= segment_length:
|
||||
max_audio_start = audio.shape[0] - segment_length
|
||||
|
|
|
@ -28,6 +28,25 @@ from .waveflow_modules import WaveFlowLoss, WaveFlowModule
|
|||
|
||||
|
||||
class WaveFlow():
|
||||
"""Wrapper class of WaveFlow model that supports multiple APIs.
|
||||
|
||||
This module provides APIs for model building, training, validation,
|
||||
inference, benchmarking, and saving.
|
||||
|
||||
Args:
|
||||
config (obj): config info.
|
||||
checkpoint_dir (str): path for checkpointing.
|
||||
parallel (bool, optional): whether use multiple GPUs for training.
|
||||
Defaults to False.
|
||||
rank (int, optional): the rank of the process in a multi-process
|
||||
scenario. Defaults to 0.
|
||||
nranks (int, optional): the total number of processes. Defaults to 1.
|
||||
tb_logger (obj, optional): logger to visualize metrics.
|
||||
Defaults to None.
|
||||
|
||||
Returns:
|
||||
WaveFlow
|
||||
"""
|
||||
def __init__(self,
|
||||
config,
|
||||
checkpoint_dir,
|
||||
|
@ -44,6 +63,15 @@ class WaveFlow():
|
|||
self.dtype = "float16" if config.use_fp16 else "float32"
|
||||
|
||||
def build(self, training=True):
|
||||
"""Initialize the model.
|
||||
|
||||
Args:
|
||||
training (bool, optional): Whether the model is built for training or inference.
|
||||
Defaults to True.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
config = self.config
|
||||
dataset = LJSpeech(config, self.nranks, self.rank)
|
||||
self.trainloader = dataset.trainloader
|
||||
|
@ -99,6 +127,14 @@ class WaveFlow():
|
|||
self.waveflow = waveflow
|
||||
|
||||
def train_step(self, iteration):
|
||||
"""Train the model for one step.
|
||||
|
||||
Args:
|
||||
iteration (int): current iteration number.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
self.waveflow.train()
|
||||
|
||||
start_time = time.time()
|
||||
|
@ -135,6 +171,14 @@ class WaveFlow():
|
|||
|
||||
@dg.no_grad
|
||||
def valid_step(self, iteration):
|
||||
"""Run the model on the validation dataset.
|
||||
|
||||
Args:
|
||||
iteration (int): current iteration number.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
self.waveflow.eval()
|
||||
tb = self.tb_logger
|
||||
|
||||
|
@ -167,6 +211,14 @@ class WaveFlow():
|
|||
|
||||
@dg.no_grad
|
||||
def infer(self, iteration):
|
||||
"""Run the model to synthesize audios.
|
||||
|
||||
Args:
|
||||
iteration (int): iteration number of the loaded checkpoint.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
self.waveflow.eval()
|
||||
|
||||
config = self.config
|
||||
|
@ -179,10 +231,13 @@ class WaveFlow():
|
|||
mels_list = [mels for _, mels in self.validloader()]
|
||||
if sample is not None:
|
||||
mels_list = [mels_list[sample]]
|
||||
else:
|
||||
sample = 0
|
||||
|
||||
for sample, mel in enumerate(mels_list):
|
||||
filename = "{}/valid_{}.wav".format(output, sample)
|
||||
print("Synthesize sample {}, save as {}".format(sample, filename))
|
||||
for idx, mel in enumerate(mels_list):
|
||||
abs_idx = sample + idx
|
||||
filename = "{}/valid_{}.wav".format(output, abs_idx)
|
||||
print("Synthesize sample {}, save as {}".format(abs_idx, filename))
|
||||
|
||||
start_time = time.time()
|
||||
audio = self.waveflow.synthesize(mel, sigma=self.config.sigma)
|
||||
|
@ -200,6 +255,14 @@ class WaveFlow():
|
|||
|
||||
@dg.no_grad
|
||||
def benchmark(self):
|
||||
"""Run the model to benchmark synthesis speed.
|
||||
|
||||
Args:
|
||||
None
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
self.waveflow.eval()
|
||||
|
||||
mels_list = [mels for _, mels in self.validloader()]
|
||||
|
@ -220,6 +283,14 @@ class WaveFlow():
|
|||
print("{} X real-time".format(audio_time / syn_time))
|
||||
|
||||
def save(self, iteration):
|
||||
"""Save model checkpoint.
|
||||
|
||||
Args:
|
||||
iteration (int): iteration number of the model to be saved.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
utils.save_latest_parameters(self.checkpoint_dir, iteration,
|
||||
self.waveflow, self.optimizer)
|
||||
utils.save_latest_checkpoint(self.checkpoint_dir, iteration)
|
||||
|
|
|
@ -293,6 +293,14 @@ class Flow(dg.Layer):
|
|||
|
||||
|
||||
class WaveFlowModule(dg.Layer):
|
||||
"""WaveFlow model implementation.
|
||||
|
||||
Args:
|
||||
config (obj): model configuration parameters.
|
||||
|
||||
Returns:
|
||||
WaveFlowModule
|
||||
"""
|
||||
def __init__(self, config):
|
||||
super(WaveFlowModule, self).__init__()
|
||||
self.n_flows = config.n_flows
|
||||
|
@ -321,6 +329,22 @@ class WaveFlowModule(dg.Layer):
|
|||
self.perms.append(perm)
|
||||
|
||||
def forward(self, audio, mel):
|
||||
"""Training forward pass.
|
||||
|
||||
Use a conditioner to upsample mel spectrograms into hidden states.
|
||||
These hidden states along with the audio are passed to a stack of Flow
|
||||
modules to obtain the final latent variable z and a list of log scaling
|
||||
variables, which are then passed to the WaveFlowLoss module to calculate
|
||||
the negative log likelihood.
|
||||
|
||||
Args:
|
||||
audio (obj): audio samples.
|
||||
mel (obj): mel spectrograms.
|
||||
|
||||
Returns:
|
||||
z (obj): latent variable.
|
||||
log_s_list(list): list of log scaling variables.
|
||||
"""
|
||||
mel = self.conditioner(mel)
|
||||
assert mel.shape[2] >= audio.shape[1]
|
||||
# Prune out the tail of audio/mel so that time/n_group == 0.
|
||||
|
@ -361,6 +385,20 @@ class WaveFlowModule(dg.Layer):
|
|||
return z, log_s_list
|
||||
|
||||
def synthesize(self, mel, sigma=1.0):
|
||||
"""Use model to synthesize waveform.
|
||||
|
||||
Use a conditioner to upsample mel spectrograms into hidden states.
|
||||
These hidden states along with initial random gaussian latent variable
|
||||
are passed to a stack of Flow modules to obtain the audio output.
|
||||
|
||||
Args:
|
||||
mel (obj): mel spectrograms.
|
||||
sigma (float, optional): standard deviation of the guassian latent
|
||||
variable. Defaults to 1.0.
|
||||
|
||||
Returns:
|
||||
audio (obj): synthesized audio.
|
||||
"""
|
||||
if self.dtype == "float16":
|
||||
mel = fluid.layers.cast(mel, self.dtype)
|
||||
mel = self.conditioner.infer(mel)
|
||||
|
|
|
@ -53,11 +53,9 @@ class DynamicGRU(dg.Layer):
|
|||
if self.is_reverse:
|
||||
i = inputs.shape[1] - 1 - i
|
||||
input_ = inputs[:, i:i + 1, :]
|
||||
input_ = layers.reshape(
|
||||
input_, [-1, input_.shape[2]], inplace=False)
|
||||
input_ = layers.reshape(input_, [-1, input_.shape[2]])
|
||||
hidden, reset, gate = self.gru_unit(input_, hidden)
|
||||
hidden_ = layers.reshape(
|
||||
hidden, [-1, 1, hidden.shape[1]], inplace=False)
|
||||
hidden_ = layers.reshape(hidden, [-1, 1, hidden.shape[1]])
|
||||
res.append(hidden_)
|
||||
if self.is_reverse:
|
||||
res = res[::-1]
|
||||
|
|
|
@ -71,7 +71,8 @@ class PositionwiseFeedForward(dg.Layer):
|
|||
x = self.w_2(layers.relu(self.w_1(x)))
|
||||
|
||||
# dropout
|
||||
x = layers.dropout(x, self.dropout)
|
||||
x = layers.dropout(
|
||||
x, self.dropout, dropout_implementation='upscale_in_train')
|
||||
|
||||
x = layers.transpose(x, [0, 2, 1])
|
||||
# residual connection
|
||||
|
|
|
@ -0,0 +1,610 @@
|
|||
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import paddle
|
||||
from paddle import fluid
|
||||
import paddle.fluid.dygraph as dg
|
||||
|
||||
import numpy as np
|
||||
|
||||
from . import conv
|
||||
from . import weight_norm
|
||||
|
||||
|
||||
def FC(name_scope,
|
||||
in_features,
|
||||
size,
|
||||
num_flatten_dims=1,
|
||||
relu=False,
|
||||
dropout=0.0,
|
||||
epsilon=1e-30,
|
||||
act=None,
|
||||
is_test=False,
|
||||
dtype="float32"):
|
||||
"""
|
||||
A special Linear Layer, when it is used with dropout, the weight is
|
||||
initialized as normal(0, std=np.sqrt((1-dropout) / in_features))
|
||||
"""
|
||||
|
||||
# stds
|
||||
if isinstance(in_features, int):
|
||||
in_features = [in_features]
|
||||
|
||||
stds = [np.sqrt((1 - dropout) / in_feature) for in_feature in in_features]
|
||||
if relu:
|
||||
stds = [std * np.sqrt(2.0) for std in stds]
|
||||
|
||||
weight_inits = [
|
||||
fluid.initializer.NormalInitializer(scale=std) for std in stds
|
||||
]
|
||||
bias_init = fluid.initializer.ConstantInitializer(0.0)
|
||||
|
||||
# param attrs
|
||||
weight_attrs = [fluid.ParamAttr(initializer=init) for init in weight_inits]
|
||||
bias_attr = fluid.ParamAttr(initializer=bias_init)
|
||||
|
||||
layer = weight_norm.FC(name_scope,
|
||||
size,
|
||||
num_flatten_dims=num_flatten_dims,
|
||||
param_attr=weight_attrs,
|
||||
bias_attr=bias_attr,
|
||||
act=act,
|
||||
dtype=dtype)
|
||||
return layer
|
||||
|
||||
|
||||
def Conv1D(name_scope,
|
||||
in_channels,
|
||||
num_filters,
|
||||
filter_size=3,
|
||||
dilation=1,
|
||||
groups=None,
|
||||
causal=False,
|
||||
std_mul=1.0,
|
||||
dropout=0.0,
|
||||
use_cudnn=True,
|
||||
act=None,
|
||||
dtype="float32"):
|
||||
"""
|
||||
A special Conv1D Layer, when it is used with dropout, the weight is
|
||||
initialized as
|
||||
normal(0, std=np.sqrt(std_mul * (1-dropout) / (filter_size * in_features)))
|
||||
"""
|
||||
# std
|
||||
std = np.sqrt((std_mul * (1 - dropout)) / (filter_size * in_channels))
|
||||
weight_init = fluid.initializer.NormalInitializer(loc=0.0, scale=std)
|
||||
bias_init = fluid.initializer.ConstantInitializer(0.0)
|
||||
|
||||
# param attrs
|
||||
weight_attr = fluid.ParamAttr(initializer=weight_init)
|
||||
bias_attr = fluid.ParamAttr(initializer=bias_init)
|
||||
|
||||
layer = conv.Conv1D(
|
||||
name_scope,
|
||||
in_channels,
|
||||
num_filters,
|
||||
filter_size,
|
||||
dilation,
|
||||
groups=groups,
|
||||
causal=causal,
|
||||
param_attr=weight_attr,
|
||||
bias_attr=bias_attr,
|
||||
use_cudnn=use_cudnn,
|
||||
act=act,
|
||||
dtype=dtype)
|
||||
return layer
|
||||
|
||||
|
||||
def Embedding(name_scope,
|
||||
num_embeddings,
|
||||
embed_dim,
|
||||
is_sparse=False,
|
||||
is_distributed=False,
|
||||
padding_idx=None,
|
||||
std=0.01,
|
||||
dtype="float32"):
|
||||
# param attrs
|
||||
weight_attr = fluid.ParamAttr(initializer=fluid.initializer.Normal(
|
||||
scale=std))
|
||||
layer = dg.Embedding(
|
||||
name_scope, (num_embeddings, embed_dim),
|
||||
padding_idx=padding_idx,
|
||||
param_attr=weight_attr,
|
||||
dtype=dtype)
|
||||
return layer
|
||||
|
||||
|
||||
class Conv1DGLU(dg.Layer):
|
||||
"""
|
||||
A Convolution 1D block with GLU activation. It also applys dropout for the
|
||||
input x. It fuses speaker embeddings through a FC activated by softsign. It
|
||||
has residual connection from the input x, and scale the output by
|
||||
np.sqrt(0.5).
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
name_scope,
|
||||
n_speakers,
|
||||
speaker_dim,
|
||||
in_channels,
|
||||
num_filters,
|
||||
filter_size,
|
||||
dilation,
|
||||
std_mul=4.0,
|
||||
dropout=0.0,
|
||||
causal=False,
|
||||
residual=True,
|
||||
dtype="float32"):
|
||||
super(Conv1DGLU, self).__init__(name_scope, dtype=dtype)
|
||||
|
||||
# conv spec
|
||||
self.in_channels = in_channels
|
||||
self.n_speakers = n_speakers
|
||||
self.speaker_dim = speaker_dim
|
||||
self.num_filters = num_filters
|
||||
self.filter_size = filter_size
|
||||
self.dilation = dilation
|
||||
self.causal = causal
|
||||
self.residual = residual
|
||||
|
||||
# weight init and dropout
|
||||
self.std_mul = std_mul
|
||||
self.dropout = dropout
|
||||
|
||||
if residual:
|
||||
assert (
|
||||
in_channels == num_filters
|
||||
), "this block uses residual connection"\
|
||||
"the input_channes should equals num_filters"
|
||||
|
||||
self.conv = Conv1D(
|
||||
self.full_name(),
|
||||
in_channels,
|
||||
2 * num_filters,
|
||||
filter_size,
|
||||
dilation,
|
||||
causal=causal,
|
||||
std_mul=std_mul,
|
||||
dropout=dropout,
|
||||
dtype=dtype)
|
||||
|
||||
if n_speakers > 1:
|
||||
assert (speaker_dim is not None
|
||||
), "speaker embed should not be null in multi-speaker case"
|
||||
self.fc = Conv1D(
|
||||
self.full_name(),
|
||||
speaker_dim,
|
||||
num_filters,
|
||||
filter_size=1,
|
||||
dilation=1,
|
||||
causal=False,
|
||||
act="softsign",
|
||||
dtype=dtype)
|
||||
|
||||
def forward(self, x, speaker_embed_bc1t=None):
|
||||
"""
|
||||
Args:
|
||||
x (Variable): Shape(B, C_in, 1, T), the input of Conv1DGLU
|
||||
layer, where B means batch_size, C_in means the input channels
|
||||
T means input time steps.
|
||||
speaker_embed_bct1 (Variable): Shape(B, C_sp, 1, T), expanded
|
||||
speaker embed, where C_sp means speaker embedding size. Note
|
||||
that when using residual connection, the Conv1DGLU does not
|
||||
change the number of channels, so out channels equals input
|
||||
channels.
|
||||
|
||||
Returns:
|
||||
x (Variable): Shape(B, C_out, 1, T), the output of Conv1DGLU, where
|
||||
C_out means the output channels of Conv1DGLU.
|
||||
"""
|
||||
|
||||
residual = x
|
||||
x = fluid.layers.dropout(x, self.dropout)
|
||||
x = self.conv(x)
|
||||
|
||||
content, gate = fluid.layers.split(x, num_or_sections=2, dim=1)
|
||||
|
||||
if speaker_embed_bc1t is not None:
|
||||
sp = self.fc(speaker_embed_bc1t)
|
||||
content = content + sp
|
||||
|
||||
# glu
|
||||
x = fluid.layers.elementwise_mul(fluid.layers.sigmoid(gate), content)
|
||||
|
||||
if self.residual:
|
||||
x = fluid.layers.scale(x + residual, np.sqrt(0.5))
|
||||
return x
|
||||
|
||||
def add_input(self, x, speaker_embed_bc11=None):
|
||||
"""
|
||||
Inputs:
|
||||
x: shape(B, num_filters, 1, time_steps)
|
||||
speaker_embed_bc11: shape(B, speaker_dim, 1, time_steps)
|
||||
|
||||
Outputs:
|
||||
out: shape(B, num_filters, 1, time_steps), where time_steps = 1
|
||||
"""
|
||||
|
||||
residual = x
|
||||
|
||||
# add step input and produce step output
|
||||
x = fluid.layers.dropout(x, self.dropout)
|
||||
x = self.conv.add_input(x)
|
||||
|
||||
content, gate = fluid.layers.split(x, num_or_sections=2, dim=1)
|
||||
|
||||
if speaker_embed_bc11 is not None:
|
||||
sp = self.fc(speaker_embed_bc11)
|
||||
content = content + sp
|
||||
|
||||
x = fluid.layers.elementwise_mul(fluid.layers.sigmoid(gate), content)
|
||||
|
||||
if self.residual:
|
||||
x = fluid.layers.scale(x + residual, np.sqrt(0.5))
|
||||
return x
|
||||
|
||||
|
||||
def Conv1DTranspose(name_scope,
|
||||
in_channels,
|
||||
num_filters,
|
||||
filter_size,
|
||||
padding=0,
|
||||
stride=1,
|
||||
dilation=1,
|
||||
groups=None,
|
||||
std_mul=1.0,
|
||||
dropout=0.0,
|
||||
use_cudnn=True,
|
||||
act=None,
|
||||
dtype="float32"):
|
||||
std = np.sqrt(std_mul * (1 - dropout) / (in_channels * filter_size))
|
||||
weight_init = fluid.initializer.NormalInitializer(scale=std)
|
||||
weight_attr = fluid.ParamAttr(initializer=weight_init)
|
||||
bias_init = fluid.initializer.ConstantInitializer(0.0)
|
||||
bias_attr = fluid.ParamAttr(initializer=bias_init)
|
||||
layer = conv.Conv1DTranspose(
|
||||
name_scope,
|
||||
in_channels,
|
||||
num_filters,
|
||||
filter_size,
|
||||
padding=padding,
|
||||
stride=stride,
|
||||
dilation=dilation,
|
||||
groups=groups,
|
||||
param_attr=weight_attr,
|
||||
bias_attr=bias_attr,
|
||||
use_cudnn=use_cudnn,
|
||||
act=act,
|
||||
dtype=dtype)
|
||||
return layer
|
||||
|
||||
|
||||
def compute_position_embedding(rad):
|
||||
# rad is a transposed radius, shape(embed_dim, n_vocab)
|
||||
embed_dim, n_vocab = rad.shape
|
||||
|
||||
even_dims = dg.to_variable(np.arange(0, embed_dim, 2).astype("int32"))
|
||||
odd_dims = dg.to_variable(np.arange(1, embed_dim, 2).astype("int32"))
|
||||
|
||||
even_rads = fluid.layers.gather(rad, even_dims)
|
||||
odd_rads = fluid.layers.gather(rad, odd_dims)
|
||||
|
||||
sines = fluid.layers.sin(even_rads)
|
||||
cosines = fluid.layers.cos(odd_rads)
|
||||
|
||||
temp = fluid.layers.scatter(rad, even_dims, sines)
|
||||
out = fluid.layers.scatter(temp, odd_dims, cosines)
|
||||
out = fluid.layers.transpose(out, perm=[1, 0])
|
||||
return out
|
||||
|
||||
|
||||
def position_encoding_init(n_position,
|
||||
d_pos_vec,
|
||||
position_rate=1.0,
|
||||
sinusoidal=True):
|
||||
""" Init the sinusoid position encoding table """
|
||||
|
||||
# keep idx 0 for padding token position encoding zero vector
|
||||
position_enc = np.array([[
|
||||
position_rate * pos / np.power(10000, 2 * (i // 2) / d_pos_vec)
|
||||
for i in range(d_pos_vec)
|
||||
] if pos != 0 else np.zeros(d_pos_vec) for pos in range(n_position)])
|
||||
|
||||
if sinusoidal:
|
||||
position_enc[1:, 0::2] = np.sin(position_enc[1:, 0::2]) # dim 2i
|
||||
position_enc[1:, 1::2] = np.cos(position_enc[1:, 1::2]) # dim 2i+1
|
||||
|
||||
return position_enc
|
||||
|
||||
|
||||
class PositionEmbedding(dg.Layer):
|
||||
def __init__(self,
|
||||
name_scope,
|
||||
n_position,
|
||||
d_pos_vec,
|
||||
position_rate=1.0,
|
||||
is_sparse=False,
|
||||
is_distributed=False,
|
||||
param_attr=None,
|
||||
max_norm=None,
|
||||
padding_idx=None,
|
||||
dtype="float32"):
|
||||
super(PositionEmbedding, self).__init__(name_scope, dtype=dtype)
|
||||
self.embed = dg.Embedding(
|
||||
self.full_name(),
|
||||
size=(n_position, d_pos_vec),
|
||||
is_sparse=is_sparse,
|
||||
is_distributed=is_distributed,
|
||||
padding_idx=None,
|
||||
param_attr=param_attr,
|
||||
dtype=dtype)
|
||||
self.set_weight(
|
||||
position_encoding_init(
|
||||
n_position,
|
||||
d_pos_vec,
|
||||
position_rate=position_rate,
|
||||
sinusoidal=False).astype(dtype))
|
||||
|
||||
self._is_sparse = is_sparse
|
||||
self._is_distributed = is_distributed
|
||||
self._remote_prefetch = self._is_sparse and (not self._is_distributed)
|
||||
if self._remote_prefetch:
|
||||
assert self._is_sparse is True and self._is_distributed is False
|
||||
|
||||
self._padding_idx = (-1 if padding_idx is None else padding_idx if
|
||||
padding_idx >= 0 else (n_position + padding_idx))
|
||||
self._position_rate = position_rate
|
||||
self._max_norm = max_norm
|
||||
self._dtype = dtype
|
||||
|
||||
def set_weight(self, array):
|
||||
assert self.embed._w.shape == list(array.shape), "shape does not match"
|
||||
self.embed._w._ivar.value().get_tensor().set(
|
||||
array, fluid.framework._current_expected_place())
|
||||
|
||||
def forward(self, indices, speaker_position_rate=None):
|
||||
"""
|
||||
Args:
|
||||
indices (Variable): Shape (B, T, 1), dtype: int64, position
|
||||
indices, where B means the batch size, T means the time steps.
|
||||
speaker_position_rate (Variable | float, optional), position
|
||||
rate. It can be a float point number or a Variable with
|
||||
shape (1,), then this speaker_position_rate is used for every
|
||||
example. It can also be a Variable with shape (B, 1), which
|
||||
contains a speaker position rate for each speaker.
|
||||
Returns:
|
||||
out (Variable): Shape(B, C_pos), position embedding, where C_pos
|
||||
means position embedding size.
|
||||
"""
|
||||
rad = fluid.layers.transpose(self.embed._w, perm=[1, 0])
|
||||
batch_size = indices.shape[0]
|
||||
|
||||
if speaker_position_rate is None:
|
||||
weight = compute_position_embedding(rad)
|
||||
out = self._helper.create_variable_for_type_inference(self._dtype)
|
||||
self._helper.append_op(
|
||||
type="lookup_table",
|
||||
inputs={"Ids": indices,
|
||||
"W": weight},
|
||||
outputs={"Out": out},
|
||||
attrs={
|
||||
"is_sparse": self._is_sparse,
|
||||
"is_distributed": self._is_distributed,
|
||||
"remote_prefetch": self._remote_prefetch,
|
||||
"padding_idx":
|
||||
self._padding_idx, # special value for lookup table op
|
||||
})
|
||||
return out
|
||||
|
||||
elif (np.isscalar(speaker_position_rate) or
|
||||
isinstance(speaker_position_rate, fluid.framework.Variable) and
|
||||
speaker_position_rate.shape == [1, 1]):
|
||||
# # make a weight
|
||||
# scale the weight (the operand for sin & cos)
|
||||
if np.isscalar(speaker_position_rate):
|
||||
scaled_rad = fluid.layers.scale(rad, speaker_position_rate)
|
||||
else:
|
||||
scaled_rad = fluid.layers.elementwise_mul(
|
||||
rad, speaker_position_rate[0])
|
||||
weight = compute_position_embedding(scaled_rad)
|
||||
out = self._helper.create_variable_for_type_inference(self._dtype)
|
||||
self._helper.append_op(
|
||||
type="lookup_table",
|
||||
inputs={"Ids": indices,
|
||||
"W": weight},
|
||||
outputs={"Out": out},
|
||||
attrs={
|
||||
"is_sparse": self._is_sparse,
|
||||
"is_distributed": self._is_distributed,
|
||||
"remote_prefetch": self._remote_prefetch,
|
||||
"padding_idx":
|
||||
self._padding_idx, # special value for lookup table op
|
||||
})
|
||||
return out
|
||||
|
||||
elif np.prod(speaker_position_rate.shape) > 1:
|
||||
assert speaker_position_rate.shape == [batch_size, 1]
|
||||
outputs = []
|
||||
for i in range(batch_size):
|
||||
rate = speaker_position_rate[i] # rate has shape [1]
|
||||
scaled_rad = fluid.layers.elementwise_mul(rad, rate)
|
||||
weight = compute_position_embedding(scaled_rad)
|
||||
out = self._helper.create_variable_for_type_inference(
|
||||
self._dtype)
|
||||
sequence = indices[i]
|
||||
self._helper.append_op(
|
||||
type="lookup_table",
|
||||
inputs={"Ids": sequence,
|
||||
"W": weight},
|
||||
outputs={"Out": out},
|
||||
attrs={
|
||||
"is_sparse": self._is_sparse,
|
||||
"is_distributed": self._is_distributed,
|
||||
"remote_prefetch": self._remote_prefetch,
|
||||
"padding_idx": -1,
|
||||
})
|
||||
outputs.append(out)
|
||||
out = fluid.layers.stack(outputs)
|
||||
return out
|
||||
else:
|
||||
raise Exception("Then you can just use position rate at init")
|
||||
|
||||
|
||||
class Conv1D_GU(dg.Layer):
|
||||
def __init__(self,
|
||||
name_scope,
|
||||
conditioner_dim,
|
||||
in_channels,
|
||||
num_filters,
|
||||
filter_size,
|
||||
dilation,
|
||||
causal=False,
|
||||
residual=True,
|
||||
dtype="float32"):
|
||||
super(Conv1D_GU, self).__init__(name_scope, dtype=dtype)
|
||||
|
||||
self.conditioner_dim = conditioner_dim
|
||||
self.in_channels = in_channels
|
||||
self.num_filters = num_filters
|
||||
self.filter_size = filter_size
|
||||
self.dilation = dilation
|
||||
self.causal = causal
|
||||
self.residual = residual
|
||||
|
||||
if residual:
|
||||
assert (
|
||||
in_channels == num_filters
|
||||
), "this block uses residual connection"\
|
||||
"the input_channels should equals num_filters"
|
||||
|
||||
self.conv = Conv1D(
|
||||
self.full_name(),
|
||||
in_channels,
|
||||
2 * num_filters,
|
||||
filter_size,
|
||||
dilation,
|
||||
causal=causal,
|
||||
dtype=dtype)
|
||||
|
||||
self.fc = Conv1D(
|
||||
self.full_name(),
|
||||
conditioner_dim,
|
||||
2 * num_filters,
|
||||
filter_size=1,
|
||||
dilation=1,
|
||||
causal=False,
|
||||
dtype=dtype)
|
||||
|
||||
def forward(self, x, skip=None, conditioner=None):
|
||||
"""
|
||||
Args:
|
||||
x (Variable): Shape(B, C_in, 1, T), the input of Conv1D_GU
|
||||
layer, where B means batch_size, C_in means the input channels
|
||||
T means input time steps.
|
||||
skip (Variable): Shape(B, C_in, 1, T), skip connection.
|
||||
conditioner (Variable): Shape(B, C_con, 1, T), expanded mel
|
||||
conditioner, where C_con is conditioner hidden dim which
|
||||
equals the num of mel bands. Note that when using residual
|
||||
connection, the Conv1D_GU does not change the number of
|
||||
channels, so out channels equals input channels.
|
||||
Returns:
|
||||
x (Variable): Shape(B, C_out, 1, T), the output of Conv1D_GU, where
|
||||
C_out means the output channels of Conv1D_GU.
|
||||
skip (Variable): Shape(B, C_out, 1, T), skip connection.
|
||||
"""
|
||||
residual = x
|
||||
x = self.conv(x)
|
||||
|
||||
if conditioner is not None:
|
||||
cond_bias = self.fc(conditioner)
|
||||
x += cond_bias
|
||||
|
||||
content, gate = fluid.layers.split(x, num_or_sections=2, dim=1)
|
||||
|
||||
# Gated Unit.
|
||||
x = fluid.layers.elementwise_mul(
|
||||
fluid.layers.sigmoid(gate), fluid.layers.tanh(content))
|
||||
|
||||
if skip is None:
|
||||
skip = x
|
||||
else:
|
||||
skip = fluid.layers.scale(skip + x, np.sqrt(0.5))
|
||||
|
||||
if self.residual:
|
||||
x = fluid.layers.scale(residual + x, np.sqrt(0.5))
|
||||
|
||||
return x, skip
|
||||
|
||||
def add_input(self, x, skip=None, conditioner=None):
|
||||
"""
|
||||
Inputs:
|
||||
x: shape(B, num_filters, 1, time_steps)
|
||||
skip: shape(B, num_filters, 1, time_steps), skip connection
|
||||
conditioner: shape(B, conditioner_dim, 1, time_steps)
|
||||
Outputs:
|
||||
x: shape(B, num_filters, 1, time_steps), where time_steps = 1
|
||||
skip: skip connection, same shape as x
|
||||
"""
|
||||
residual = x
|
||||
|
||||
# add step input and produce step output
|
||||
x = self.conv.add_input(x)
|
||||
|
||||
if conditioner is not None:
|
||||
cond_bias = self.fc(conditioner)
|
||||
x += cond_bias
|
||||
|
||||
content, gate = fluid.layers.split(x, num_or_sections=2, dim=1)
|
||||
|
||||
# Gated Unit.
|
||||
x = fluid.layers.elementwise_mul(
|
||||
fluid.layers.sigmoid(gate), fluid.layers.tanh(content))
|
||||
|
||||
if skip is None:
|
||||
skip = x
|
||||
else:
|
||||
skip = fluid.layers.scale(skip + x, np.sqrt(0.5))
|
||||
|
||||
if self.residual:
|
||||
x = fluid.layers.scale(residual + x, np.sqrt(0.5))
|
||||
|
||||
return x, skip
|
||||
|
||||
|
||||
def Conv2DTranspose(name_scope,
|
||||
num_filters,
|
||||
filter_size,
|
||||
padding=0,
|
||||
stride=1,
|
||||
dilation=1,
|
||||
use_cudnn=True,
|
||||
act=None,
|
||||
dtype="float32"):
|
||||
val = 1.0 / (filter_size[0] * filter_size[1])
|
||||
weight_init = fluid.initializer.ConstantInitializer(val)
|
||||
weight_attr = fluid.ParamAttr(initializer=weight_init)
|
||||
|
||||
layer = weight_norm.Conv2DTranspose(
|
||||
name_scope,
|
||||
num_filters,
|
||||
filter_size=filter_size,
|
||||
padding=padding,
|
||||
stride=stride,
|
||||
dilation=dilation,
|
||||
param_attr=weight_attr,
|
||||
use_cudnn=use_cudnn,
|
||||
act=act,
|
||||
dtype=dtype)
|
||||
|
||||
return layer
|
|
@ -78,17 +78,15 @@ class ScaledDotProductAttention(dg.Layer):
|
|||
"""
|
||||
# Compute attention score
|
||||
attention = layers.matmul(
|
||||
query, key, transpose_y=True) #transpose the last dim in y
|
||||
attention = attention / math.sqrt(self.d_key)
|
||||
query, key, transpose_y=True, alpha=self.d_key
|
||||
**-0.5) #transpose the last dim in y
|
||||
|
||||
# Mask key to ignore padding
|
||||
if mask is not None:
|
||||
attention = attention * mask
|
||||
mask = (mask == 0).astype(np.float32) * (-2**32 + 1)
|
||||
attention = attention + mask
|
||||
|
||||
attention = layers.softmax(attention)
|
||||
attention = layers.dropout(attention, dropout)
|
||||
attention = layers.dropout(
|
||||
attention, dropout, dropout_implementation='upscale_in_train')
|
||||
|
||||
# Mask query to ignore padding
|
||||
if query_mask is not None:
|
||||
|
@ -142,17 +140,11 @@ class MultiheadAttention(dg.Layer):
|
|||
result (Variable), Shape(B, T, C), the result of mutihead attention.
|
||||
attention (Variable), Shape(n_head * B, T, C), the attention of key.
|
||||
"""
|
||||
|
||||
batch_size = key.shape[0]
|
||||
seq_len_key = key.shape[1]
|
||||
seq_len_query = query_input.shape[1]
|
||||
|
||||
# repeat masks h times
|
||||
if query_mask is not None:
|
||||
query_mask = layers.expand(query_mask,
|
||||
[self.num_head, 1, seq_len_key])
|
||||
if mask is not None:
|
||||
mask = layers.expand(mask, (self.num_head, 1, 1))
|
||||
|
||||
# Make multihead attention
|
||||
# key & value.shape = (batch_size, seq_len, feature)(feature = num_head * num_hidden_per_attn)
|
||||
key = layers.reshape(
|
||||
|
@ -176,6 +168,18 @@ class MultiheadAttention(dg.Layer):
|
|||
result, attention = self.scal_attn(
|
||||
key, value, query, mask=mask, query_mask=query_mask)
|
||||
|
||||
key = layers.reshape(
|
||||
layers.transpose(key, [2, 0, 1, 3]), [-1, seq_len_key, self.d_k])
|
||||
value = layers.reshape(
|
||||
layers.transpose(value, [2, 0, 1, 3]),
|
||||
[-1, seq_len_key, self.d_k])
|
||||
query = layers.reshape(
|
||||
layers.transpose(query, [2, 0, 1, 3]),
|
||||
[-1, seq_len_query, self.d_q])
|
||||
|
||||
result, attention = self.scal_attn(
|
||||
key, value, query, mask=mask, query_mask=query_mask)
|
||||
|
||||
# concat all multihead result
|
||||
result = layers.reshape(
|
||||
result, [self.num_head, batch_size, seq_len_query, self.d_q])
|
||||
|
@ -184,7 +188,10 @@ class MultiheadAttention(dg.Layer):
|
|||
[batch_size, seq_len_query, -1])
|
||||
if self.is_concat:
|
||||
result = layers.concat([query_input, result], axis=-1)
|
||||
result = layers.dropout(self.fc(result), self.dropout)
|
||||
result = layers.dropout(
|
||||
self.fc(result),
|
||||
self.dropout,
|
||||
dropout_implementation='upscale_in_train')
|
||||
result = result + query_input
|
||||
|
||||
result = self.layer_norm(result)
|
||||
|
|
Loading…
Reference in New Issue