From 078d22e51cb064300c8b60a21380f20c5acf5f66 Mon Sep 17 00:00:00 2001 From: lifuchen Date: Thu, 5 Mar 2020 07:08:12 +0000 Subject: [PATCH 1/4] Modified data.py to generate masks as model inputs --- examples/fastspeech/README.md | 2 +- .../{config => configs}/fastspeech.yaml | 0 .../{config => configs}/synthesis.yaml | 4 +-- examples/fastspeech/parse.py | 3 ++ examples/fastspeech/synthesis.py | 18 ++++++++-- examples/fastspeech/train.py | 23 ++++++++++--- examples/fastspeech/train.sh | 2 +- examples/transformer_tts/README.md | 6 ++-- .../{config => configs}/synthesis.yaml | 5 ++- .../train_transformer.yaml | 0 .../{config => configs}/train_vocoder.yaml | 0 examples/transformer_tts/data.py | 15 +++++++-- examples/transformer_tts/synthesis.py | 27 ++++++++++++--- examples/transformer_tts/synthesis.sh | 6 ++-- examples/transformer_tts/train_transformer.py | 17 +++++++--- examples/transformer_tts/train_transformer.sh | 2 +- parakeet/data/dataset.py | 12 +++++++ parakeet/models/fastspeech/decoder.py | 8 ++--- parakeet/models/fastspeech/encoder.py | 14 ++++---- parakeet/models/fastspeech/fastspeech.py | 19 ++++++++--- parakeet/models/fastspeech/fft_block.py | 3 +- .../models/fastspeech/length_regulator.py | 4 +-- parakeet/models/fastspeech/utils.py | 2 +- parakeet/models/transformer_tts/cbhg.py | 1 + parakeet/models/transformer_tts/decoder.py | 33 ++++++++++--------- parakeet/models/transformer_tts/encoder.py | 19 ++++++----- .../models/transformer_tts/encoderprenet.py | 7 ++-- .../models/transformer_tts/post_convnet.py | 6 ++-- parakeet/models/transformer_tts/prenet.py | 4 +-- .../models/transformer_tts/transformer_tts.py | 13 ++++---- parakeet/models/transformer_tts/utils.py | 20 +++++++++-- parakeet/modules/dynamic_gru.py | 4 +-- parakeet/modules/ffn.py | 2 +- parakeet/modules/modules.py | 4 +-- parakeet/modules/multihead_attention.py | 21 ++++-------- 35 files changed, 216 insertions(+), 110 deletions(-) rename examples/fastspeech/{config => configs}/fastspeech.yaml (100%) rename examples/fastspeech/{config => configs}/synthesis.yaml (88%) rename examples/transformer_tts/{config => configs}/synthesis.yaml (72%) rename examples/transformer_tts/{config => configs}/train_transformer.yaml (100%) rename examples/transformer_tts/{config => configs}/train_vocoder.yaml (100%) diff --git a/examples/fastspeech/README.md b/examples/fastspeech/README.md index 007b6b2..1908cd6 100644 --- a/examples/fastspeech/README.md +++ b/examples/fastspeech/README.md @@ -55,7 +55,7 @@ python -m paddle.distributed.launch --selected_gpus=0,1,2,3 --log_dir ./mylog tr --config_path='config/fastspeech.yaml' \ ``` -if you wish to resume from an exists model, please set ``--checkpoint_path`` and ``--fastspeech_step`` +If you wish to resume from an existing model, please set ``--checkpoint_path`` and ``--fastspeech_step`` For more help on arguments: ``python train.py --help``. 
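As context for this patch: the masks that data.py now generates are plain NumPy arrays computed before the batch ever reaches the model. A minimal sketch of the two padding masks involved, mirroring `get_non_pad_mask` / `get_attn_key_pad_mask` as rewritten in `parakeet/models/transformer_tts/utils.py` later in this patch (the token ids below are made up):

```python
import numpy as np

# Minimal sketch of the padding masks this patch precomputes on the data
# side; token ids are toy values and 0 marks padding.
seq = np.array([[3, 7, 5, 0, 0]])  # (B, T)

# Per-position mask used to zero out model outputs at padded steps: (B, T, 1).
non_pad_mask = np.expand_dims((seq != 0).astype(np.float32), axis=-1)

# Additive attention mask: a large negative value on padded keys so that
# softmax assigns them ~zero weight, broadcast over all query positions.
len_q = seq.shape[1]
key_mask = np.expand_dims((seq != 0).astype(np.float32), axis=1)  # (B, 1, T)
key_mask = key_mask.repeat(len_q, axis=1)                         # (B, T, T)
attn_mask = (key_mask == 0).astype(np.float32) * (-2 ** 32 + 1)
```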
diff --git a/examples/fastspeech/config/fastspeech.yaml b/examples/fastspeech/configs/fastspeech.yaml similarity index 100% rename from examples/fastspeech/config/fastspeech.yaml rename to examples/fastspeech/configs/fastspeech.yaml diff --git a/examples/fastspeech/config/synthesis.yaml b/examples/fastspeech/configs/synthesis.yaml similarity index 88% rename from examples/fastspeech/config/synthesis.yaml rename to examples/fastspeech/configs/synthesis.yaml index 9a43dff..ab9dbb4 100644 --- a/examples/fastspeech/config/synthesis.yaml +++ b/examples/fastspeech/configs/synthesis.yaml @@ -3,8 +3,8 @@ audio: n_fft: 2048 sr: 22050 preemphasis: 0.97 - hop_length: 275 - win_length: 1102 + hop_length: 256 + win_length: 1024 power: 1.2 min_level_db: -100 ref_level_db: 20 diff --git a/examples/fastspeech/parse.py b/examples/fastspeech/parse.py index a6c2d99..87a804d 100644 --- a/examples/fastspeech/parse.py +++ b/examples/fastspeech/parse.py @@ -17,6 +17,9 @@ def add_config_options_to_parser(parser): help="use gpu or not during training.") parser.add_argument('--use_data_parallel', type=int, default=0, help="use data parallel or not during training.") + parser.add_argument('--alpha', type=float, default=1.0, + help="The hyperparameter that determines the length of the expanded mel \ + sequence, thereby controlling the voice speed.") parser.add_argument('--data_path', type=str, default='./dataset/LJSpeech-1.1', help="the path of dataset.") diff --git a/examples/fastspeech/synthesis.py b/examples/fastspeech/synthesis.py index 6a3d146..f9e944b 100644 --- a/examples/fastspeech/synthesis.py +++ b/examples/fastspeech/synthesis.py @@ -11,6 +11,7 @@ import paddle.fluid.dygraph as dg from parakeet.g2p.en import text_to_sequence from parakeet import audio from parakeet.models.fastspeech.fastspeech import FastSpeech +from parakeet.models.transformer_tts.utils import * def load_checkpoint(step, model_path): model_dict, _ = fluid.dygraph.load_dygraph(os.path.join(model_path, step)) @@ -41,11 +42,22 @@ def synthesis(text_input, args): model.eval() text = np.asarray(text_to_sequence(text_input)) - text = fluid.layers.unsqueeze(dg.to_variable(text),[0]) + text = np.expand_dims(text, axis=0) pos_text = np.arange(1, text.shape[1]+1) - pos_text = fluid.layers.unsqueeze(dg.to_variable(pos_text),[0]) + pos_text = np.expand_dims(pos_text, axis=0) + enc_non_pad_mask = get_non_pad_mask(pos_text).astype(np.float32) + enc_slf_attn_mask = get_attn_key_pad_mask(pos_text, text).astype(np.float32) + + text = dg.to_variable(text) + pos_text = dg.to_variable(pos_text) + enc_non_pad_mask = dg.to_variable(enc_non_pad_mask) + enc_slf_attn_mask = dg.to_variable(enc_slf_attn_mask) - mel_output, mel_output_postnet = model(text, pos_text, alpha=args.alpha) + mel_output, mel_output_postnet = model(text, pos_text, alpha=args.alpha, + enc_non_pad_mask=enc_non_pad_mask, + enc_slf_attn_mask=enc_slf_attn_mask, + dec_non_pad_mask=None, + dec_slf_attn_mask=None) _ljspeech_processor = audio.AudioProcessor( sample_rate=cfg['audio']['sr'], diff --git a/examples/fastspeech/train.py b/examples/fastspeech/train.py index 52b5725..a5b5bea 100644 --- a/examples/fastspeech/train.py +++ b/examples/fastspeech/train.py @@ -8,6 +8,7 @@ from parse import add_config_options_to_parser from pprint import pprint from ruamel import yaml from tqdm import tqdm +from matplotlib import cm from collections import OrderedDict from tensorboardX import SummaryWriter import paddle.fluid.dygraph as dg @@ -77,18 +78,32 @@ def main(args): for i, data in enumerate(pbar): 
pbar.set_description('Processing at epoch %d'%epoch) - character, mel, mel_input, pos_text, pos_mel, text_length, mel_lens = data + (character, mel, mel_input, pos_text, pos_mel, text_length, mel_lens, + enc_slf_mask, enc_query_mask, dec_slf_mask, enc_dec_mask, dec_query_slf_mask, dec_query_mask) = data - _, _, attn_probs, _, _, _ = transformerTTS(character, mel_input, pos_text, pos_mel) - alignment = dg.to_variable(get_alignment(attn_probs, mel_lens, cfg['transformer_head'])).astype(np.float32) + _, _, attn_probs, _, _, _ = transformerTTS(character, mel_input, pos_text, pos_mel, + dec_slf_mask=dec_slf_mask, + enc_slf_mask=enc_slf_mask, enc_query_mask=enc_query_mask, + enc_dec_mask=enc_dec_mask, dec_query_slf_mask=dec_query_slf_mask, + dec_query_mask=dec_query_mask) + alignment, max_attn = get_alignment(attn_probs, mel_lens, cfg['transformer_head']) + alignment = dg.to_variable(alignment).astype(np.float32) + if local_rank==0 and global_step % 5 == 1: + x = np.uint8(cm.viridis(max_attn[8,:mel_lens.numpy()[8]]) * 255) + writer.add_image('Attention_%d_0'%global_step, x, 0, dataformats="HWC") + global_step += 1 #Forward result= model(character, pos_text, mel_pos=pos_mel, - length_target=alignment) + length_target=alignment, + enc_non_pad_mask=enc_query_mask, + enc_slf_attn_mask=enc_slf_mask, + dec_non_pad_mask=dec_query_slf_mask, + dec_slf_attn_mask=dec_slf_mask) mel_output, mel_output_postnet, duration_predictor_output, _, _ = result mel_loss = layers.mse_loss(mel_output, mel) mel_postnet_loss = layers.mse_loss(mel_output_postnet, mel) diff --git a/examples/fastspeech/train.sh b/examples/fastspeech/train.sh index d293c0c..11e78c4 100644 --- a/examples/fastspeech/train.sh +++ b/examples/fastspeech/train.sh @@ -1,6 +1,6 @@ # train model # if you wish to resume from an exists model, uncomment --checkpoint_path and --fastspeech_step -CUDA_VISIBLE_DEVICES=0\ +export CUDA_VISIBLE_DEVICES=0 python -u train.py \ --batch_size=32 \ --epochs=10000 \ diff --git a/examples/transformer_tts/README.md b/examples/transformer_tts/README.md index afdfdd2..ab0aed4 100644 --- a/examples/transformer_tts/README.md +++ b/examples/transformer_tts/README.md @@ -1,5 +1,5 @@ # TransformerTTS -Paddle fluid implementation of TransformerTTS, a neural TTS with Transformer. The implementation is based on [Neural Speech Synthesis with Transformer Network](https://arxiv.org/abs/1809.08895). +PaddlePaddle fluid implementation of TransformerTTS, a neural TTS with Transformer. The implementation is based on [Neural Speech Synthesis with Transformer Network](https://arxiv.org/abs/1809.08895). ## Dataset @@ -48,7 +48,7 @@ python -m paddle.distributed.launch --selected_gpus=0,1,2,3 --log_dir ./mylog tr --config_path='config/train_transformer.yaml' \ ``` -if you wish to resume from an exists model, please set ``--checkpoint_path`` and ``--transformer_step`` +If you wish to resume from an existing model, please set ``--checkpoint_path`` and ``--transformer_step`` For more help on arguments: ``python train_transformer.py --help``. @@ -76,7 +76,7 @@ python -m paddle.distributed.launch --selected_gpus=0,1,2,3 --log_dir ./mylog tr --data_path=${DATAPATH} \ --config_path='config/train_vocoder.yaml' \ ``` -if you wish to resume from an exists model, please set ``--checkpoint_path`` and ``--vocoder_step`` +If you wish to resume from an existing model, please set ``--checkpoint_path`` and ``--vocoder_step`` For more help on arguments: ``python train_vocoder.py --help``. 
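The train.py hunk above turns an attention map into an image with `cm.viridis` and logs it through tensorboardX. A standalone sketch of that pattern, with made-up shapes and tag names (only the colormap-to-`add_image` idiom is taken from the patch):

```python
import numpy as np
from matplotlib import cm
from tensorboardX import SummaryWriter

# Standalone sketch of the attention-visualization pattern used above:
# map a (T_dec, T_enc) attention matrix through the viridis colormap and
# log it as an HWC image. Shapes and tag names are illustrative.
writer = SummaryWriter('./log')
attn = np.random.rand(120, 40).astype(np.float32)  # toy attention weights
img = np.uint8(cm.viridis(attn) * 255)             # (120, 40, 4) RGBA image
writer.add_image('Attention_example_0', img, 0, dataformats="HWC")
writer.close()
```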
diff --git a/examples/transformer_tts/config/synthesis.yaml b/examples/transformer_tts/configs/synthesis.yaml similarity index 72% rename from examples/transformer_tts/config/synthesis.yaml rename to examples/transformer_tts/configs/synthesis.yaml index 217dd85..c23b029 100644 --- a/examples/transformer_tts/config/synthesis.yaml +++ b/examples/transformer_tts/configs/synthesis.yaml @@ -8,4 +8,7 @@ audio: power: 1.2 min_level_db: -100 ref_level_db: 20 - outputs_per_step: 1 \ No newline at end of file + outputs_per_step: 1 + +hidden_size: 256 +embedding_size: 512 \ No newline at end of file diff --git a/examples/transformer_tts/config/train_transformer.yaml b/examples/transformer_tts/configs/train_transformer.yaml similarity index 100% rename from examples/transformer_tts/config/train_transformer.yaml rename to examples/transformer_tts/configs/train_transformer.yaml diff --git a/examples/transformer_tts/config/train_vocoder.yaml b/examples/transformer_tts/configs/train_vocoder.yaml similarity index 100% rename from examples/transformer_tts/config/train_vocoder.yaml rename to examples/transformer_tts/configs/train_vocoder.yaml diff --git a/examples/transformer_tts/data.py b/examples/transformer_tts/data.py index 9401b7b..fcd167f 100644 --- a/examples/transformer_tts/data.py +++ b/examples/transformer_tts/data.py @@ -10,7 +10,8 @@ from parakeet import audio from parakeet.data.sampler import * from parakeet.data.datacargo import DataCargo from parakeet.data.batch import TextIDBatcher, SpecBatcher -from parakeet.data.dataset import DatasetMixin, TransformDataset +from parakeet.data.dataset import DatasetMixin, TransformDataset, CacheDataset +from parakeet.models.transformer_tts.utils import * class LJSpeechLoader: def __init__(self, config, args, nranks, rank, is_vocoder=False, shuffle=True): @@ -20,6 +21,8 @@ class LJSpeechLoader: metadata = LJSpeechMetaData(LJSPEECH_ROOT) transformer = LJSpeech(config) dataset = TransformDataset(metadata, transformer) + dataset = CacheDataset(dataset) + sampler = DistributedSampler(len(metadata), nranks, rank, shuffle=shuffle) assert args.batch_size % nranks == 0 @@ -132,7 +135,15 @@ def batch_examples(batch): pos_mels = TextIDBatcher(pad_id=0)(pos_mels) #(B,T) mels = np.transpose(SpecBatcher(pad_value=0.)(mels), axes=(0,2,1)) #(B,T,num_mels) mel_inputs = np.transpose(SpecBatcher(pad_value=0.)(mel_inputs), axes=(0,2,1))#(B,T,num_mels) - return (texts, mels, mel_inputs, pos_texts, pos_mels, np.array(text_lens), np.array(mel_lens)) + enc_slf_mask = get_attn_key_pad_mask(pos_texts, texts).astype(np.float32) + enc_query_mask = get_non_pad_mask(pos_texts).astype(np.float32) + dec_slf_mask = get_dec_attn_key_pad_mask(pos_mels,mel_inputs).astype(np.float32) + enc_dec_mask = get_attn_key_pad_mask(enc_query_mask[:,:,0], mel_inputs).astype(np.float32) + dec_query_slf_mask = get_non_pad_mask(pos_mels).astype(np.float32) + dec_query_mask = get_non_pad_mask(pos_mels).astype(np.float32) + + return (texts, mels, mel_inputs, pos_texts, pos_mels, np.array(text_lens), np.array(mel_lens), + enc_slf_mask, enc_query_mask, dec_slf_mask, enc_dec_mask, dec_query_slf_mask, dec_query_mask) def batch_examples_vocoder(batch): mels=[] diff --git a/examples/transformer_tts/synthesis.py b/examples/transformer_tts/synthesis.py index fb1bd2f..f9a4823 100644 --- a/examples/transformer_tts/synthesis.py +++ b/examples/transformer_tts/synthesis.py @@ -3,6 +3,7 @@ from scipy.io.wavfile import write from parakeet.g2p.en import text_to_sequence import numpy as np from tqdm import tqdm +from 
matplotlib import cm from tensorboardX import SummaryWriter from ruamel import yaml import paddle.fluid as fluid @@ -12,6 +13,7 @@ import argparse from parse import add_config_options_to_parser from pprint import pprint from collections import OrderedDict +from parakeet.models.transformer_tts.utils import * from parakeet import audio from parakeet.models.transformer_tts.vocoder import Vocoder from parakeet.models.transformer_tts.transformer_tts import TransformerTTS @@ -55,15 +57,17 @@ def synthesis(text_input, args): mel_input = dg.to_variable(np.zeros([1,1,80])).astype(np.float32) pos_text = np.arange(1, text.shape[1]+1) pos_text = fluid.layers.unsqueeze(dg.to_variable(pos_text),[0]) - + pbar = tqdm(range(args.max_len)) - for i in pbar: + dec_slf_mask = get_triu_tensor(mel_input.numpy(), mel_input.numpy()).astype(np.float32) + dec_slf_mask = fluid.layers.cast(dg.to_variable(dec_slf_mask == 0), np.float32) pos_mel = np.arange(1, mel_input.shape[1]+1) pos_mel = fluid.layers.unsqueeze(dg.to_variable(pos_mel),[0]) - mel_pred, postnet_pred, attn_probs, stop_preds, attn_enc, attn_dec = model(text, mel_input, pos_text, pos_mel) + mel_pred, postnet_pred, attn_probs, stop_preds, attn_enc, attn_dec = model(text, mel_input, pos_text, pos_mel, dec_slf_mask) mel_input = fluid.layers.concat([mel_input, postnet_pred[:,-1:,:]], axis=1) + mag_pred = model_vocoder(postnet_pred) _ljspeech_processor = audio.AudioProcessor( @@ -87,6 +91,21 @@ def synthesis(text_input, args): sound_norm=False) wav = _ljspeech_processor.inv_spectrogram(fluid.layers.transpose(fluid.layers.squeeze(mag_pred,[0]), [1,0]).numpy()) + global_step = 0 + for i, prob in enumerate(attn_probs): + for j in range(4): + x = np.uint8(cm.viridis(prob.numpy()[j]) * 255) + writer.add_image('Attention_%d_0'%global_step, x, i*4+j, dataformats="HWC") + + for i, prob in enumerate(attn_enc): + for j in range(4): + x = np.uint8(cm.viridis(prob.numpy()[j]) * 255) + writer.add_image('Attention_enc_%d_0'%global_step, x, i*4+j, dataformats="HWC") + + for i, prob in enumerate(attn_dec): + for j in range(4): + x = np.uint8(cm.viridis(prob.numpy()[j]) * 255) + writer.add_image('Attention_dec_%d_0'%global_step, x, i*4+j, dataformats="HWC") writer.add_audio(text_input, wav, 0, cfg['audio']['sr']) if not os.path.exists(args.sample_path): os.mkdir(args.sample_path) @@ -97,4 +116,4 @@ if __name__ == '__main__': parser = argparse.ArgumentParser(description="Synthesis model") add_config_options_to_parser(parser) args = parser.parse_args() - synthesis("Transformer model is so fast!", args) + synthesis("They emphasized the necessity that the information now being furnished be handled with judgment and care.", args) diff --git a/examples/transformer_tts/synthesis.sh b/examples/transformer_tts/synthesis.sh index 8cb137a..42b704d 100644 --- a/examples/transformer_tts/synthesis.sh +++ b/examples/transformer_tts/synthesis.sh @@ -2,10 +2,10 @@ # train model CUDA_VISIBLE_DEVICES=0 \ python -u synthesis.py \ ---max_len=50 \ +--max_len=600 \ --transformer_step=160000 \ ---vocoder_step=70000 \ ---use_gpu=1 +--vocoder_step=90000 \ +--use_gpu=1 \ --checkpoint_path='./checkpoint' \ --log_dir='./log' \ --sample_path='./sample' \ diff --git a/examples/transformer_tts/train_transformer.py b/examples/transformer_tts/train_transformer.py index cbca569..79c67c1 100644 --- a/examples/transformer_tts/train_transformer.py +++ b/examples/transformer_tts/train_transformer.py @@ -69,19 +69,24 @@ def main(args): pbar = tqdm(reader) for i, data in enumerate(pbar): 
pbar.set_description('Processing at epoch %d'%epoch) - character, mel, mel_input, pos_text, pos_mel, text_length, _ = data + character, mel, mel_input, pos_text, pos_mel, text_length, _, enc_slf_mask, enc_query_mask, dec_slf_mask, enc_dec_mask, dec_query_slf_mask, dec_query_mask = data global_step += 1 - mel_pred, postnet_pred, attn_probs, stop_preds, attn_enc, attn_dec = model(character, mel_input, pos_text, pos_mel) - - label = (pos_mel == 0).astype(np.float32) - + mel_pred, postnet_pred, attn_probs, stop_preds, attn_enc, attn_dec = model(character, mel_input, pos_text, pos_mel, dec_slf_mask=dec_slf_mask, + enc_slf_mask=enc_slf_mask, enc_query_mask=enc_query_mask, + enc_dec_mask=enc_dec_mask, dec_query_slf_mask=dec_query_slf_mask, + dec_query_mask=dec_query_mask) + + mel_loss = layers.mean(layers.abs(layers.elementwise_sub(mel_pred, mel))) post_mel_loss = layers.mean(layers.abs(layers.elementwise_sub(postnet_pred, mel))) loss = mel_loss + post_mel_loss + + # Note: training did not converge when the stop token loss was used. if args.stop_token: + label = (pos_mel == 0).astype(np.float32) stop_loss = cross_entropy(stop_preds, label) loss = loss + stop_loss @@ -123,6 +128,7 @@ def main(args): x = np.uint8(cm.viridis(prob.numpy()[j*16]) * 255) writer.add_image('Attention_dec_%d_0'%global_step, x, i*4+j, dataformats="HWC") + if args.use_data_parallel: loss = model.scale_loss(loss) loss.backward() @@ -132,6 +138,7 @@ def main(args): optimizer.minimize(loss, grad_clip = fluid.dygraph_grad_clip.GradClipByGlobalNorm(cfg['grad_clip_thresh'])) model.clear_gradients() + # save checkpoint if local_rank==0 and global_step % args.save_step == 0: if not os.path.exists(args.save_path): diff --git a/examples/transformer_tts/train_transformer.sh b/examples/transformer_tts/train_transformer.sh index cdb24cf..346d351 100644 --- a/examples/transformer_tts/train_transformer.sh +++ b/examples/transformer_tts/train_transformer.sh @@ -1,7 +1,7 @@ # train model # if you wish to resume from an exists model, uncomment --checkpoint_path and --transformer_step -CUDA_VISIBLE_DEVICES=0 \ +export CUDA_VISIBLE_DEVICES=2 python -u train_transformer.py \ --batch_size=32 \ --epochs=10000 \ diff --git a/parakeet/data/dataset.py b/parakeet/data/dataset.py index d9f9a1f..90c1360 100644 --- a/parakeet/data/dataset.py +++ b/parakeet/data/dataset.py @@ -1,5 +1,6 @@ import six import numpy as np +from tqdm import tqdm class DatasetMixin(object): @@ -45,6 +46,17 @@ class TransformDataset(DatasetMixin): in_data = self._dataset[i] return self._transform(in_data) +class CacheDataset(DatasetMixin): + def __init__(self, dataset): + self._dataset = dataset + pbar = tqdm(range(len(self._dataset))) + self._cache = [self._dataset[i] for i in pbar] + + def __len__(self): + return len(self._dataset) + + def get_example(self, i): + return self._cache[i] class TupleDataset(object): def __init__(self, *datasets): diff --git a/parakeet/models/fastspeech/decoder.py b/parakeet/models/fastspeech/decoder.py index 732fed4..8bcda34 100644 --- a/parakeet/models/fastspeech/decoder.py +++ b/parakeet/models/fastspeech/decoder.py @@ -18,6 +18,7 @@ class Decoder(dg.Layer): super(Decoder, self).__init__() n_position = len_max_seq + 1 + self.n_head = n_head self.pos_inp = get_sinusoid_encoding_table(n_position, d_model, padding_idx=0) self.position_enc = dg.Embedding(size=[n_position, d_model], padding_idx=0, @@ -28,7 +29,7 @@ class Decoder(dg.Layer): for i, layer in enumerate(self.layer_stack): self.add_sublayer('fft_{}'.format(i), layer) - def forward(self, enc_seq, 
enc_pos): + def forward(self, enc_seq, enc_pos, non_pad_mask, slf_attn_mask=None): """ Decoder layer of FastSpeech. @@ -42,10 +43,7 @@ class Decoder(dg.Layer): dec_slf_attn_list (Variable), Shape(B, mel_T, mel_T), the decoder self attention list. """ dec_slf_attn_list = [] - - # -- Prepare masks - slf_attn_mask = get_attn_key_pad_mask(seq_k=enc_pos, seq_q=enc_pos) - non_pad_mask = get_non_pad_mask(enc_pos) + slf_attn_mask = layers.expand(slf_attn_mask, [self.n_head, 1, 1]) # -- Forward dec_output = enc_seq + self.position_enc(enc_pos) diff --git a/parakeet/models/fastspeech/encoder.py b/parakeet/models/fastspeech/encoder.py index ac96e39..36e0f99 100644 --- a/parakeet/models/fastspeech/encoder.py +++ b/parakeet/models/fastspeech/encoder.py @@ -18,11 +18,12 @@ class Encoder(dg.Layer): dropout=0.1): super(Encoder, self).__init__() n_position = len_max_seq + 1 + self.n_head = n_head - self.src_word_emb = dg.Embedding(size=[n_src_vocab, d_model], padding_idx=0) + self.src_word_emb = dg.Embedding(size=[n_src_vocab, d_model], padding_idx=0, + param_attr=fluid.initializer.Normal(loc=0.0, scale=1.0)) self.pos_inp = get_sinusoid_encoding_table(n_position, d_model, padding_idx=0) self.position_enc = dg.Embedding(size=[n_position, d_model], - padding_idx=0, param_attr=fluid.ParamAttr( initializer=fluid.initializer.NumpyArrayInitializer(self.pos_inp), trainable=False)) @@ -30,7 +31,7 @@ class Encoder(dg.Layer): for i, layer in enumerate(self.layer_stack): self.add_sublayer('fft_{}'.format(i), layer) - def forward(self, character, text_pos): + def forward(self, character, text_pos, non_pad_mask, slf_attn_mask=None): """ Encoder layer of FastSpeech. @@ -46,10 +47,7 @@ class Encoder(dg.Layer): enc_slf_attn_list (list), Len(n_layers), Shape(B * n_head, text_T, text_T), the encoder self attention list. """ enc_slf_attn_list = [] - # -- prepare masks - # shape character (N, T) - slf_attn_mask = get_attn_key_pad_mask(seq_k=character, seq_q=character) - non_pad_mask = get_non_pad_mask(character) + slf_attn_mask = layers.expand(slf_attn_mask, [self.n_head, 1, 1]) # -- Forward enc_output = self.src_word_emb(character) + self.position_enc(text_pos) #(N, T, C) @@ -61,4 +59,4 @@ class Encoder(dg.Layer): slf_attn_mask=slf_attn_mask) enc_slf_attn_list += [enc_slf_attn] - return enc_output, non_pad_mask, enc_slf_attn_list \ No newline at end of file + return enc_output, enc_slf_attn_list \ No newline at end of file diff --git a/parakeet/models/fastspeech/fastspeech.py b/parakeet/models/fastspeech/fastspeech.py index 4a01b95..c6edb50 100644 --- a/parakeet/models/fastspeech/fastspeech.py +++ b/parakeet/models/fastspeech/fastspeech.py @@ -1,7 +1,9 @@ import math +import numpy as np import paddle.fluid.dygraph as dg import paddle.fluid as fluid from parakeet.g2p.text.symbols import symbols +from parakeet.models.transformer_tts.utils import * from parakeet.models.transformer_tts.post_convnet import PostConvNet from parakeet.models.fastspeech.length_regulator import LengthRegulator from parakeet.models.fastspeech.encoder import Encoder @@ -54,7 +56,9 @@ class FastSpeech(dg.Layer): dropout=0.1, batchnorm_last=True) - def forward(self, character, text_pos, mel_pos=None, length_target=None, alpha=1.0): + def forward(self, character, text_pos, enc_non_pad_mask, dec_non_pad_mask, + enc_slf_attn_mask=None, dec_slf_attn_mask=None, + mel_pos=None, length_target=None, alpha=1.0): """ FastSpeech model. 
@@ -80,13 +84,15 @@ class FastSpeech(dg.Layer): dec_slf_attn_list (Variable), Shape(B, mel_T, mel_T), the decoder self attention list. """ - encoder_output, non_pad_mask, enc_slf_attn_list = self.encoder(character, text_pos) + encoder_output, enc_slf_attn_list = self.encoder(character, text_pos, enc_non_pad_mask, slf_attn_mask=enc_slf_attn_mask) if fluid.framework._dygraph_tracer()._train_mode: length_regulator_output, duration_predictor_output = self.length_regulator(encoder_output, target=length_target, alpha=alpha) - decoder_output, dec_slf_attn_list = self.decoder(length_regulator_output, mel_pos) + decoder_output, dec_slf_attn_list = self.decoder(length_regulator_output, mel_pos, + dec_non_pad_mask, + slf_attn_mask=dec_slf_attn_mask) mel_output = self.mel_linear(decoder_output) mel_output_postnet = self.postnet(mel_output) + mel_output @@ -94,7 +100,12 @@ class FastSpeech(dg.Layer): return mel_output, mel_output_postnet, duration_predictor_output, enc_slf_attn_list, dec_slf_attn_list else: length_regulator_output, decoder_pos = self.length_regulator(encoder_output, alpha=alpha) - decoder_output, _ = self.decoder(length_regulator_output, decoder_pos) + slf_attn_mask = get_triu_tensor(decoder_pos.numpy(), decoder_pos.numpy()).astype(np.float32) + slf_attn_mask = fluid.layers.cast(dg.to_variable(slf_attn_mask == 0), np.float32) + slf_attn_mask = dg.to_variable(slf_attn_mask) + dec_non_pad_mask = fluid.layers.unsqueeze((decoder_pos != 0).astype(np.float32), [-1]) + decoder_output, _ = self.decoder(length_regulator_output, decoder_pos, dec_non_pad_mask, + slf_attn_mask=slf_attn_mask) mel_output = self.mel_linear(decoder_output) mel_output_postnet = self.postnet(mel_output) + mel_output diff --git a/parakeet/models/fastspeech/fft_block.py b/parakeet/models/fastspeech/fft_block.py index ea86328..82c9d0f 100644 --- a/parakeet/models/fastspeech/fft_block.py +++ b/parakeet/models/fastspeech/fft_block.py @@ -12,7 +12,7 @@ class FFTBlock(dg.Layer): self.slf_attn = MultiheadAttention(d_model, d_k, d_v, num_head=n_head, is_bias=True, dropout=dropout, is_concat=False) self.pos_ffn = PositionwiseFeedForward(d_model, d_inner, filter_size =filter_size, padding =padding, dropout=dropout) - def forward(self, enc_input, non_pad_mask=None, slf_attn_mask=None): + def forward(self, enc_input, non_pad_mask, slf_attn_mask=None): """ Feed Forward Transformer block in FastSpeech. @@ -28,6 +28,7 @@ class FFTBlock(dg.Layer): slf_attn (Variable), Shape(B * n_head, T, T), the self attention. 
""" output, slf_attn = self.slf_attn(enc_input, enc_input, enc_input, mask=slf_attn_mask) + output *= non_pad_mask output = self.pos_ffn(output) diff --git a/parakeet/models/fastspeech/length_regulator.py b/parakeet/models/fastspeech/length_regulator.py index d90eaa5..3ad88ff 100644 --- a/parakeet/models/fastspeech/length_regulator.py +++ b/parakeet/models/fastspeech/length_regulator.py @@ -121,11 +121,11 @@ class DurationPredictor(dg.Layer): out = layers.transpose(encoder_output, [0,2,1]) out = self.conv1(out) out = layers.transpose(out, [0,2,1]) - out = layers.dropout(layers.relu(self.layer_norm1(out)), self.dropout) + out = layers.dropout(layers.relu(self.layer_norm1(out)), self.dropout, dropout_implementation='upscale_in_train') out = layers.transpose(out, [0,2,1]) out = self.conv2(out) out = layers.transpose(out, [0,2,1]) - out = layers.dropout(layers.relu(self.layer_norm2(out)), self.dropout) + out = layers.dropout(layers.relu(self.layer_norm2(out)), self.dropout, dropout_implementation='upscale_in_train') out = layers.relu(self.linear(out)) out = layers.squeeze(out, axes=[-1]) diff --git a/parakeet/models/fastspeech/utils.py b/parakeet/models/fastspeech/utils.py index a94de8d..b1fff09 100644 --- a/parakeet/models/fastspeech/utils.py +++ b/parakeet/models/fastspeech/utils.py @@ -14,7 +14,7 @@ def get_alignment(attn_probs, mel_lens, n_head): max_F = F max_attn = attn alignment = compute_duration(max_attn, mel_lens) - return alignment + return alignment, max_attn def score_F(attn): max = np.max(attn, axis=-1) diff --git a/parakeet/models/transformer_tts/cbhg.py b/parakeet/models/transformer_tts/cbhg.py index 94b907f..29a5a9a 100644 --- a/parakeet/models/transformer_tts/cbhg.py +++ b/parakeet/models/transformer_tts/cbhg.py @@ -124,6 +124,7 @@ class CBHG(dg.Layer): conv_list = [] conv_input = input_ + for i, (conv, batchnorm) in enumerate(zip(self.conv_list, self.batchnorm_list)): conv_input = self._conv_fit_dim(conv(conv_input), i+1) conv_input = layers.relu(batchnorm(conv_input)) diff --git a/parakeet/models/transformer_tts/decoder.py b/parakeet/models/transformer_tts/decoder.py index b0da788..970557c 100644 --- a/parakeet/models/transformer_tts/decoder.py +++ b/parakeet/models/transformer_tts/decoder.py @@ -1,7 +1,7 @@ import math import paddle.fluid.dygraph as dg import paddle.fluid as fluid -from parakeet.modules.utils import * +from parakeet.models.transformer_tts.utils import * from parakeet.modules.multihead_attention import MultiheadAttention from parakeet.modules.ffn import PositionwiseFeedForward from parakeet.models.transformer_tts.prenet import PreNet @@ -11,6 +11,7 @@ class Decoder(dg.Layer): def __init__(self, num_hidden, config, num_head=4): super(Decoder, self).__init__() self.num_hidden = num_hidden + self.num_head = num_head param = fluid.ParamAttr() self.alpha = self.create_parameter(shape=(1,), attr=param, dtype='float32', default_initializer = fluid.initializer.ConstantInitializer(value=1.0)) @@ -48,25 +49,21 @@ class Decoder(dg.Layer): self.postconvnet = PostConvNet(config['audio']['num_mels'], config['hidden_size'], filter_size = 5, padding = 4, num_conv=5, outputs_per_step=config['audio']['outputs_per_step'], - use_cudnn = True) + use_cudnn=True) - def forward(self, key, value, query, c_mask, positional): + def forward(self, key, value, query, positional, mask, m_mask=None, m_self_mask=None, zero_mask=None): # get decoder mask with triangular matrix if fluid.framework._dygraph_tracer()._train_mode: - m_mask = get_non_pad_mask(positional) - mask = 
get_attn_key_pad_mask((positional==0).astype(np.float32), query) - triu_tensor = dg.to_variable(get_triu_tensor(query.numpy(), query.numpy())).astype(np.float32) - mask = mask + triu_tensor - mask = fluid.layers.cast(mask == 0, np.float32) - - # (batch_size, decoder_len, encoder_len) - zero_mask = get_attn_key_pad_mask(layers.squeeze(c_mask,[-1]), query) + m_mask = layers.expand(m_mask, [self.num_head, 1, key.shape[1]]) + m_self_mask = layers.expand(m_self_mask, [self.num_head, 1, query.shape[1]]) + mask = layers.expand(mask, [self.num_head, 1, 1]) + zero_mask = layers.expand(zero_mask, [self.num_head, 1, 1]) + else: - mask = get_triu_tensor(query.numpy(), query.numpy()).astype(np.float32) - mask = fluid.layers.cast(dg.to_variable(mask == 0), np.float32) - m_mask, zero_mask = None, None + m_mask, m_self_mask, zero_mask = None, None, None + # Decoder pre-network query = self.decoder_prenet(query) @@ -79,18 +76,22 @@ class Decoder(dg.Layer): query = positional * self.alpha + query #positional dropout - query = fluid.layers.dropout(query, 0.1) + query = fluid.layers.dropout(query, 0.1, dropout_implementation='upscale_in_train') + # Attention decoder-decoder, encoder-decoder selfattn_list = list() attn_list = list() + for selfattn, attn, ffn in zip(self.selfattn_layers, self.attn_layers, self.ffns): - query, attn_dec = selfattn(query, query, query, mask = mask, query_mask = m_mask) + query, attn_dec = selfattn(query, query, query, mask = mask, query_mask = m_self_mask) query, attn_dot = attn(key, value, query, mask = zero_mask, query_mask = m_mask) query = ffn(query) selfattn_list.append(attn_dec) attn_list.append(attn_dot) + + # Mel linear projection mel_out = self.mel_linear(query) # Post Mel Network diff --git a/parakeet/models/transformer_tts/encoder.py b/parakeet/models/transformer_tts/encoder.py index 8cd37b2..84926c7 100644 --- a/parakeet/models/transformer_tts/encoder.py +++ b/parakeet/models/transformer_tts/encoder.py @@ -9,11 +9,11 @@ class Encoder(dg.Layer): def __init__(self, embedding_size, num_hidden, num_head=4): super(Encoder, self).__init__() self.num_hidden = num_hidden + self.num_head = num_head param = fluid.ParamAttr(initializer=fluid.initializer.Constant(value=1.0)) self.alpha = self.create_parameter(shape=(1, ), attr=param, dtype='float32') self.pos_inp = get_sinusoid_encoding_table(1024, self.num_hidden, padding_idx=0) self.pos_emb = dg.Embedding(size=[1024, num_hidden], - padding_idx=0, param_attr=fluid.ParamAttr( initializer=fluid.initializer.NumpyArrayInitializer(self.pos_inp), trainable=False)) @@ -23,17 +23,20 @@ class Encoder(dg.Layer): self.layers = [MultiheadAttention(num_hidden, num_hidden//num_head, num_hidden//num_head) for _ in range(3)] for i, layer in enumerate(self.layers): self.add_sublayer("self_attn_{}".format(i), layer) - self.ffns = [PositionwiseFeedForward(num_hidden, num_hidden*num_head, filter_size=1, use_cudnn = True) for _ in range(3)] + self.ffns = [PositionwiseFeedForward(num_hidden, num_hidden*num_head, filter_size=1, use_cudnn=True) for _ in range(3)] for i, layer in enumerate(self.ffns): self.add_sublayer("ffns_{}".format(i), layer) - def forward(self, x, positional): + def forward(self, x, positional, mask=None, query_mask=None): + if fluid.framework._dygraph_tracer()._train_mode: - query_mask = get_non_pad_mask(positional) - mask = get_attn_key_pad_mask(positional, x) + seq_len_key = x.shape[1] + query_mask = layers.expand(query_mask, [self.num_head, 1, seq_len_key]) + mask = layers.expand(mask, [self.num_head, 1, 1]) else: query_mask, 
mask = None, None + # Encoder pre_network x = self.encoder_prenet(x) #(N,T,C) @@ -43,9 +46,9 @@ class Encoder(dg.Layer): x = positional * self.alpha + x #(N, T, C) - + # Positional dropout - x = layers.dropout(x, 0.1) + x = layers.dropout(x, 0.1, dropout_implementation='upscale_in_train') # Self attention encoder attentions = list() @@ -54,4 +57,4 @@ class Encoder(dg.Layer): x = ffn(x) attentions.append(attention) - return x, query_mask, attentions \ No newline at end of file + return x, attentions \ No newline at end of file diff --git a/parakeet/models/transformer_tts/encoderprenet.py b/parakeet/models/transformer_tts/encoderprenet.py index b27f2fe..76c820c 100644 --- a/parakeet/models/transformer_tts/encoderprenet.py +++ b/parakeet/models/transformer_tts/encoderprenet.py @@ -14,7 +14,8 @@ class EncoderPrenet(dg.Layer): self.num_hidden = num_hidden self.use_cudnn = use_cudnn self.embedding = dg.Embedding( size = [len(symbols), embedding_size], - padding_idx = None) + padding_idx = 0, + param_attr=fluid.initializer.Normal(loc=0.0, scale=1.0)) self.conv_list = [] k = math.sqrt(1 / embedding_size) self.conv_list.append(Conv1D(num_channels = embedding_size, @@ -49,10 +50,12 @@ class EncoderPrenet(dg.Layer): bias_attr=fluid.ParamAttr(initializer = fluid.initializer.Uniform(low=-k, high=k))) def forward(self, x): + x = self.embedding(x) #(batch_size, seq_len, embending_size) x = layers.transpose(x,[0,2,1]) for batch_norm, conv in zip(self.batch_norm_list, self.conv_list): - x = layers.dropout(layers.relu(batch_norm(conv(x))), 0.2) + x = layers.dropout(layers.relu(batch_norm(conv(x))), 0.2, + dropout_implementation='upscale_in_train') x = layers.transpose(x,[0,2,1]) #(N,T,C) x = self.projection(x) diff --git a/parakeet/models/transformer_tts/post_convnet.py b/parakeet/models/transformer_tts/post_convnet.py index 3e393ee..ff4a558 100644 --- a/parakeet/models/transformer_tts/post_convnet.py +++ b/parakeet/models/transformer_tts/post_convnet.py @@ -76,11 +76,13 @@ class PostConvNet(dg.Layer): batch_norm = self.batch_norm_list[i] conv = self.conv_list[i] - input = layers.dropout(layers.tanh(batch_norm(conv(input)[:,:,:len])), self.dropout) + input = layers.dropout(layers.tanh(batch_norm(conv(input)[:,:,:len])), self.dropout, + dropout_implementation='upscale_in_train') conv = self.conv_list[self.num_conv-1] input = conv(input)[:,:,:len] if self.batchnorm_last: batch_norm = self.batch_norm_list[self.num_conv-1] - input = layers.dropout(batch_norm(input), self.dropout) + input = layers.dropout(batch_norm(input), self.dropout, + dropout_implementation='upscale_in_train') output = layers.transpose(input, [0,2,1]) return output \ No newline at end of file diff --git a/parakeet/models/transformer_tts/prenet.py b/parakeet/models/transformer_tts/prenet.py index e9b0667..92dba45 100644 --- a/parakeet/models/transformer_tts/prenet.py +++ b/parakeet/models/transformer_tts/prenet.py @@ -34,6 +34,6 @@ class PreNet(dg.Layer): Returns: x (Variable), Shape(B, T, C), the result after pernet. 
""" - x = layers.dropout(layers.relu(self.linear1(x)), self.dropout_rate) - x = layers.dropout(layers.relu(self.linear2(x)), self.dropout_rate) + x = layers.dropout(layers.relu(self.linear1(x)), self.dropout_rate, dropout_implementation='upscale_in_train') + x = layers.dropout(layers.relu(self.linear2(x)), self.dropout_rate, dropout_implementation='upscale_in_train') return x diff --git a/parakeet/models/transformer_tts/transformer_tts.py b/parakeet/models/transformer_tts/transformer_tts.py index bf2924a..fd33e9b 100644 --- a/parakeet/models/transformer_tts/transformer_tts.py +++ b/parakeet/models/transformer_tts/transformer_tts.py @@ -10,15 +10,14 @@ class TransformerTTS(dg.Layer): self.decoder = Decoder(config['hidden_size'], config) self.config = config - def forward(self, characters, mel_input, pos_text, pos_mel): - - key, c_mask, attns_enc = self.encoder(characters, pos_text) - - mel_output, postnet_output, attn_probs, stop_preds, attns_dec = self.decoder(key, key, mel_input, c_mask, pos_mel) - + def forward(self, characters, mel_input, pos_text, pos_mel, dec_slf_mask, enc_slf_mask=None, enc_query_mask=None, enc_dec_mask=None, dec_query_slf_mask=None, dec_query_mask=None): + key, attns_enc = self.encoder(characters, pos_text, mask=enc_slf_mask, query_mask=enc_query_mask) + + mel_output, postnet_output, attn_probs, stop_preds, attns_dec = self.decoder(key, key, mel_input, pos_mel, + mask=dec_slf_mask, zero_mask=enc_dec_mask, + m_self_mask=dec_query_slf_mask, m_mask=dec_query_mask ) return mel_output, postnet_output, attn_probs, stop_preds, attns_enc, attns_dec - diff --git a/parakeet/models/transformer_tts/utils.py b/parakeet/models/transformer_tts/utils.py index ab575f9..f6f567b 100644 --- a/parakeet/models/transformer_tts/utils.py +++ b/parakeet/models/transformer_tts/utils.py @@ -35,7 +35,9 @@ def get_sinusoid_encoding_table(n_position, d_hid, padding_idx=None): return sinusoid_table def get_non_pad_mask(seq): - return layers.unsqueeze((seq != 0).astype(np.float32),[-1]) + mask = (seq != 0).astype(np.float32) + mask = np.expand_dims(mask, axis=-1) + return mask def get_attn_key_pad_mask(seq_k, seq_q): ''' For masking out the padding part of key sequence. ''' @@ -43,7 +45,21 @@ def get_attn_key_pad_mask(seq_k, seq_q): # Expand to fit the shape of key query attention matrix. len_q = seq_q.shape[1] padding_mask = (seq_k != 0).astype(np.float32) - padding_mask = layers.expand(layers.unsqueeze(padding_mask,[1]), [1, len_q, 1]) + padding_mask = np.expand_dims(padding_mask, axis=1) + padding_mask = padding_mask.repeat([len_q],axis=1) + padding_mask = (padding_mask == 0).astype(np.float32) * (-2 ** 32 + 1) + return padding_mask + +def get_dec_attn_key_pad_mask(seq_k, seq_q): + ''' For masking out the padding part of key sequence. ''' + + # Expand to fit the shape of key query attention matrix. 
+ len_q = seq_q.shape[1] + padding_mask = (seq_k == 0).astype(np.float32) + padding_mask = np.expand_dims(padding_mask, axis=1) + triu_tensor = get_triu_tensor(seq_q, seq_q) + padding_mask = padding_mask.repeat([len_q],axis=1) + triu_tensor + padding_mask = (padding_mask != 0).astype(np.float32) * (-2 ** 32 + 1) return padding_mask def get_triu_tensor(seq_k, seq_q): diff --git a/parakeet/modules/dynamic_gru.py b/parakeet/modules/dynamic_gru.py index e84c598..acf420f 100644 --- a/parakeet/modules/dynamic_gru.py +++ b/parakeet/modules/dynamic_gru.py @@ -40,10 +40,10 @@ class DynamicGRU(dg.Layer): i = inputs.shape[1] - 1 - i input_ = inputs[:, i:i + 1, :] input_ = layers.reshape( - input_, [-1, input_.shape[2]], inplace=False) + input_, [-1, input_.shape[2]]) hidden, reset, gate = self.gru_unit(input_, hidden) hidden_ = layers.reshape( - hidden, [-1, 1, hidden.shape[1]], inplace=False) + hidden, [-1, 1, hidden.shape[1]]) res.append(hidden_) if self.is_reverse: res = res[::-1] diff --git a/parakeet/modules/ffn.py b/parakeet/modules/ffn.py index dc413bf..6f2a454 100644 --- a/parakeet/modules/ffn.py +++ b/parakeet/modules/ffn.py @@ -45,7 +45,7 @@ class PositionwiseFeedForward(dg.Layer): x = self.w_2(layers.relu(self.w_1(x))) # dropout - x = layers.dropout(x, self.dropout) + x = layers.dropout(x, self.dropout, dropout_implementation='upscale_in_train') x = layers.transpose(x, [0,2,1]) # residual connection diff --git a/parakeet/modules/modules.py b/parakeet/modules/modules.py index 7aef463..cc9c207 100644 --- a/parakeet/modules/modules.py +++ b/parakeet/modules/modules.py @@ -211,7 +211,7 @@ class Conv1DGLU(dg.Layer): residual = x x = fluid.layers.dropout( - x, self.dropout, dropout_implementation="upscale_in_train") + x, self.dropout) x = self.conv(x) content, gate = fluid.layers.split(x, num_or_sections=2, dim=1) @@ -241,7 +241,7 @@ class Conv1DGLU(dg.Layer): # add step input and produce step output x = fluid.layers.dropout( - x, self.dropout, dropout_implementation="upscale_in_train") + x, self.dropout) x = self.conv.add_input(x) content, gate = fluid.layers.split(x, num_or_sections=2, dim=1) diff --git a/parakeet/modules/multihead_attention.py b/parakeet/modules/multihead_attention.py index 40d8164..edfa193 100644 --- a/parakeet/modules/multihead_attention.py +++ b/parakeet/modules/multihead_attention.py @@ -47,17 +47,13 @@ class ScaledDotProductAttention(dg.Layer): attention (Variable), Shape(n_head * B, T, C), the attention of key. """ # Compute attention score - attention = layers.matmul(query, key, transpose_y=True) #transpose the last dim in y - attention = attention / math.sqrt(self.d_key) + attention = layers.matmul(query, key, transpose_y=True, alpha=self.d_key**-0.5) #transpose the last dim in y # Mask key to ignore padding if mask is not None: - attention = attention * mask - mask = (mask == 0).astype(np.float32) * (-2 ** 32 + 1) attention = attention + mask - attention = layers.softmax(attention) - attention = layers.dropout(attention, dropout) + attention = layers.dropout(attention, dropout, dropout_implementation='upscale_in_train') # Mask query to ignore padding if query_mask is not None: @@ -103,15 +99,10 @@ class MultiheadAttention(dg.Layer): result (Variable), Shape(B, T, C), the result of mutihead attention. attention (Variable), Shape(n_head * B, T, C), the attention of key. 
""" + batch_size = key.shape[0] seq_len_key = key.shape[1] seq_len_query = query_input.shape[1] - - # repeat masks h times - if query_mask is not None: - query_mask = layers.expand(query_mask, [self.num_head, 1, seq_len_key]) - if mask is not None: - mask = layers.expand(mask, (self.num_head, 1, 1)) # Make multihead attention @@ -123,15 +114,15 @@ class MultiheadAttention(dg.Layer): key = layers.reshape(layers.transpose(key, [2, 0, 1, 3]), [-1, seq_len_key, self.d_k]) value = layers.reshape(layers.transpose(value, [2, 0, 1, 3]), [-1, seq_len_key, self.d_k]) query = layers.reshape(layers.transpose(query, [2, 0, 1, 3]), [-1, seq_len_query, self.d_q]) - + result, attention = self.scal_attn(key, value, query, mask=mask, query_mask=query_mask) - + # concat all multihead result result = layers.reshape(result, [self.num_head, batch_size, seq_len_query, self.d_q]) result = layers.reshape(layers.transpose(result, [1,2,0,3]),[batch_size, seq_len_query, -1]) if self.is_concat: result = layers.concat([query_input,result], axis=-1) - result = layers.dropout(self.fc(result), self.dropout) + result = layers.dropout(self.fc(result), self.dropout, dropout_implementation='upscale_in_train') result = result + query_input result = self.layer_norm(result) From 54bd75962563d4b9ce84ef8e2555f9766170a7e6 Mon Sep 17 00:00:00 2001 From: lifuchen Date: Fri, 6 Mar 2020 02:47:16 +0000 Subject: [PATCH 2/4] modified some vars name --- examples/fastspeech/README.md | 2 +- examples/fastspeech/train.py | 8 ++++---- examples/transformer_tts/README.md | 2 +- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/examples/fastspeech/README.md b/examples/fastspeech/README.md index db42c3c..5f88105 100644 --- a/examples/fastspeech/README.md +++ b/examples/fastspeech/README.md @@ -1,5 +1,5 @@ # Fastspeech -Paddle fluid implementation of Fastspeech, a feed-forward network based on Transformer. The implementation is based on [FastSpeech: Fast, Robust and Controllable Text to Speech](https://arxiv.org/abs/1905.09263). +PaddlePaddle fluid implementation of Fastspeech, a feed-forward network based on Transformer. The implementation is based on [FastSpeech: Fast, Robust and Controllable Text to Speech](https://arxiv.org/abs/1905.09263). ## Dataset diff --git a/examples/fastspeech/train.py b/examples/fastspeech/train.py index f4338ed..7565ac9 100644 --- a/examples/fastspeech/train.py +++ b/examples/fastspeech/train.py @@ -67,12 +67,12 @@ def main(args): with dg.guard(place): with fluid.unique_name.guard(): - transformerTTS = TransformerTTS(cfg) + transformer_tts = TransformerTTS(cfg) model_dict, _ = load_checkpoint( str(args.transformer_step), os.path.join(args.transtts_path, "transformer")) - transformerTTS.set_dict(model_dict) - transformerTTS.eval() + transformer_tts.set_dict(model_dict) + transformer_tts.eval() model = FastSpeech(cfg) model.train() @@ -105,7 +105,7 @@ def main(args): mel_lens, enc_slf_mask, enc_query_mask, dec_slf_mask, enc_dec_mask, dec_query_slf_mask, dec_query_mask) = data - _, _, attn_probs, _, _, _ = transformerTTS( + _, _, attn_probs, _, _, _ = transformer_tts( character, mel_input, pos_text, diff --git a/examples/transformer_tts/README.md b/examples/transformer_tts/README.md index d7badad..8c766ad 100644 --- a/examples/transformer_tts/README.md +++ b/examples/transformer_tts/README.md @@ -1,5 +1,5 @@ # TransformerTTS -PaddlePaddle fluid implementation of TransformerTTS, a neural TTS with Transformer. 
The implementation is based on [Neural Speech Synthesis with Transformer Network](https://arxiv.org/abs/1809.08895). +PaddlePaddle fluid implementation of TransformerTTS, a neural TTS model with Transformer. The implementation is based on [Neural Speech Synthesis with Transformer Network](https://arxiv.org/abs/1809.08895). ## Dataset From 8083da21acbf877f2800ff51781fb4ee5eee75d4 Mon Sep 17 00:00:00 2001 From: liuyibing01 Date: Sat, 7 Mar 2020 14:21:35 +0000 Subject: [PATCH 3/4] Fix sample file name --- examples/waveflow/README.md | 4 ++-- parakeet/models/waveflow/waveflow.py | 9 ++++++--- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/examples/waveflow/README.md b/examples/waveflow/README.md index e21039a..d36f0f3 100644 --- a/examples/waveflow/README.md +++ b/examples/waveflow/README.md @@ -4,7 +4,7 @@ PaddlePaddle dynamic graph implementation of [WaveFlow: A Compact Flow-based Mod - WaveFlow can synthesize 22.05 kHz high-fidelity speech around 40x faster than real-time on a Nvidia V100 GPU without engineered inference kernels, which is faster than [WaveGlow] (https://github.com/NVIDIA/waveglow) and serveral orders of magnitude faster than WaveNet. - WaveFlow is a small-footprint flow-based model for raw audio. It has only 5.9M parameters, which is 15x smalller than WaveGlow (87.9M) and comparable to WaveNet (4.6M). -- WaveFlow is directly trained with maximum likelihood without probability density distillation and auxiliary losses as used in Parallel WaveNet and ClariNet, which simplifies the training pipeline and reduces the cost of development. +- WaveFlow is directly trained with maximum likelihood without probability density distillation and auxiliary losses as used in Parallel WaveNet and ClariNet, which simplifies the training pipeline and reduces the cost of development. ## Project Structure ```text @@ -99,7 +99,7 @@ python -u synthesis.py \ --sigma=1.0 ``` -In this example, `--output` specifies where to save the synthesized audios and `--sample` specifies which sample in the valid dataset (a split from the whole LJSpeech dataset, by default contains the first 16 audio samples) to synthesize based on the mel-spectrograms computed from the ground truth sample audio, e.g., `--sample=0` means to synthesize the first audio in the valid dataset. +In this example, `--output` specifies where to save the synthesized audios and `--sample` (<16) specifies which sample in the valid dataset (a split from the whole LJSpeech dataset, by default contains the first 16 audio samples) to synthesize based on the mel-spectrograms computed from the ground truth sample audio, e.g., `--sample=0` means to synthesize the first audio in the valid dataset. 
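For context on the Benchmarking section that follows, the "X real-time" figure it reports is simply the ratio of audio duration to wall-clock synthesis time. A hypothetical helper (names not from the codebase) showing the arithmetic:

```python
# Hypothetical helper (not from the codebase) illustrating the
# "X real-time" figure the Benchmarking section reports:
# audio duration divided by wall-clock synthesis time.
def real_time_factor(num_samples, sample_rate, synthesis_seconds):
    audio_seconds = num_samples / float(sample_rate)
    return audio_seconds / synthesis_seconds

# e.g. 10 s of 22050 Hz audio synthesized in 0.25 s -> 40x real-time
print(real_time_factor(220500, 22050, 0.25))
```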
### Benchmarking diff --git a/parakeet/models/waveflow/waveflow.py b/parakeet/models/waveflow/waveflow.py index a8bd8af..4ef1411 100644 --- a/parakeet/models/waveflow/waveflow.py +++ b/parakeet/models/waveflow/waveflow.py @@ -179,10 +179,13 @@ class WaveFlow(): mels_list = [mels for _, mels in self.validloader()] if sample is not None: mels_list = [mels_list[sample]] + else: + sample = 0 - for sample, mel in enumerate(mels_list): - filename = "{}/valid_{}.wav".format(output, sample) - print("Synthesize sample {}, save as {}".format(sample, filename)) + for idx, mel in enumerate(mels_list): + abs_idx = sample + idx + filename = "{}/valid_{}.wav".format(output, abs_idx) + print("Synthesize sample {}, save as {}".format(abs_idx, filename)) start_time = time.time() audio = self.waveflow.synthesize(mel, sigma=self.config.sigma) From 4f7ded3c89081e93f129323e59e99e5f0946f2cf Mon Sep 17 00:00:00 2001 From: Kexin Zhao Date: Sat, 7 Mar 2020 23:25:04 -0800 Subject: [PATCH 4/4] add docstring --- examples/waveflow/utils.py | 49 ++++++++++++++ parakeet/models/waveflow/data.py | 1 + parakeet/models/waveflow/waveflow.py | 68 ++++++++++++++++++++ parakeet/models/waveflow/waveflow_modules.py | 38 +++++++++++ 4 files changed, 156 insertions(+) diff --git a/examples/waveflow/utils.py b/examples/waveflow/utils.py index da9b4ba..b899073 100644 --- a/examples/waveflow/utils.py +++ b/examples/waveflow/utils.py @@ -109,6 +109,16 @@ def add_yaml_config(config): def load_latest_checkpoint(checkpoint_dir, rank=0): + """Get the iteration number corresponding to the latest saved checkpoint + + Args: + checkpoint_dir (str): the directory where checkpoint is saved. + rank (int, optional): the rank of the process in multi-process setting. + Defaults to 0. + + Returns: + int: the latest iteration number. + """ checkpoint_path = os.path.join(checkpoint_dir, "checkpoint") # Create checkpoint index file if not exist. if (not os.path.isfile(checkpoint_path)) and rank == 0: @@ -129,6 +139,15 @@ def load_latest_checkpoint(checkpoint_dir, rank=0): def save_latest_checkpoint(checkpoint_dir, iteration): + """Save the iteration number of the latest model to be checkpointed. + + Args: + checkpoint_dir (str): the directory where checkpoint is saved. + iteration (int): the latest iteration number. + + Returns: + None + """ checkpoint_path = os.path.join(checkpoint_dir, "checkpoint") # Update the latest checkpoint index. with open(checkpoint_path, "w") as handle: @@ -142,6 +161,24 @@ def load_parameters(checkpoint_dir, iteration=None, file_path=None, dtype="float32"): + """Load a specific model checkpoint from disk. + + Args: + checkpoint_dir (str): the directory where checkpoint is saved. + rank (int): the rank of the process in multi-process setting. + model (obj): model to load parameters. + optimizer (obj, optional): optimizer to load states if needed. + Defaults to None. + iteration (int, optional): if specified, load the specific checkpoint, + if not specified, load the latest one. Defaults to None. + file_path (str, optional): if specified, load the checkpoint + stored in the file_path. Defaults to None. + dtype (str, optional): precision of the model parameters. + Defaults to float32. + + Returns: + None + """ if file_path is None: if iteration is None: iteration = load_latest_checkpoint(checkpoint_dir, rank) @@ -165,6 +202,18 @@ def load_parameters(checkpoint_dir, def save_latest_parameters(checkpoint_dir, iteration, model, optimizer=None): + """Checkpoint the latest trained model parameters. 
+ + Args: + checkpoint_dir (str): the directory where checkpoint is saved. + iteration (int): the latest iteration number. + model (obj): model to be checkpointed. + optimizer (obj, optional): optimizer to be checkpointed. + Defaults to None. + + Returns: + None + """ file_path = "{}/step-{}".format(checkpoint_dir, iteration) model_dict = model.state_dict() dg.save_dygraph(model_dict, file_path) diff --git a/parakeet/models/waveflow/data.py b/parakeet/models/waveflow/data.py index 83438f7..33e2ee5 100644 --- a/parakeet/models/waveflow/data.py +++ b/parakeet/models/waveflow/data.py @@ -80,6 +80,7 @@ class Subset(DatasetMixin): # whole audio for valid set pass else: + # Randomly crop segment_length from audios in the training set. # audio shape: [len] if audio.shape[0] >= segment_length: max_audio_start = audio.shape[0] - segment_length diff --git a/parakeet/models/waveflow/waveflow.py b/parakeet/models/waveflow/waveflow.py index 4ef1411..101bb66 100644 --- a/parakeet/models/waveflow/waveflow.py +++ b/parakeet/models/waveflow/waveflow.py @@ -28,6 +28,25 @@ from .waveflow_modules import WaveFlowLoss, WaveFlowModule class WaveFlow(): + """Wrapper class of WaveFlow model that supports multiple APIs. + + This module provides APIs for model building, training, validation, + inference, benchmarking, and saving. + + Args: + config (obj): config info. + checkpoint_dir (str): path for checkpointing. + parallel (bool, optional): whether to use multiple GPUs for training. + Defaults to False. + rank (int, optional): the rank of the process in a multi-process + scenario. Defaults to 0. + nranks (int, optional): the total number of processes. Defaults to 1. + tb_logger (obj, optional): logger to visualize metrics. + Defaults to None. + + Returns: + WaveFlow + """ def __init__(self, config, checkpoint_dir, @@ -44,6 +63,15 @@ class WaveFlow(): self.dtype = "float16" if config.use_fp16 else "float32" def build(self, training=True): + """Initialize the model. + + Args: + training (bool, optional): Whether the model is built for training or inference. + Defaults to True. + + Returns: + None + """ config = self.config dataset = LJSpeech(config, self.nranks, self.rank) self.trainloader = dataset.trainloader @@ -99,6 +127,14 @@ class WaveFlow(): self.waveflow = waveflow def train_step(self, iteration): + """Train the model for one step. + + Args: + iteration (int): current iteration number. + + Returns: + None + """ self.waveflow.train() start_time = time.time() @@ -135,6 +171,14 @@ class WaveFlow(): @dg.no_grad def valid_step(self, iteration): + """Run the model on the validation dataset. + + Args: + iteration (int): current iteration number. + + Returns: + None + """ self.waveflow.eval() tb = self.tb_logger @@ -167,6 +211,14 @@ class WaveFlow(): @dg.no_grad def infer(self, iteration): + """Run the model to synthesize audios. + + Args: + iteration (int): iteration number of the loaded checkpoint. + + Returns: + None + """ self.waveflow.eval() config = self.config @@ -203,6 +255,14 @@ class WaveFlow(): @dg.no_grad def benchmark(self): + """Run the model to benchmark synthesis speed. + + Args: + None + + Returns: + None + """ self.waveflow.eval() mels_list = [mels for _, mels in self.validloader()] @@ -223,6 +283,14 @@ class WaveFlow(): print("{} X real-time".format(audio_time / syn_time)) def save(self, iteration): + """Save model checkpoint. + + Args: + iteration (int): iteration number of the model to be saved. 
+ + Returns: + None + """ utils.save_latest_parameters(self.checkpoint_dir, iteration, self.waveflow, self.optimizer) utils.save_latest_checkpoint(self.checkpoint_dir, iteration) diff --git a/parakeet/models/waveflow/waveflow_modules.py b/parakeet/models/waveflow/waveflow_modules.py index 46dfba7..f480cd9 100644 --- a/parakeet/models/waveflow/waveflow_modules.py +++ b/parakeet/models/waveflow/waveflow_modules.py @@ -293,6 +293,14 @@ class Flow(dg.Layer): class WaveFlowModule(dg.Layer): + """WaveFlow model implementation. + + Args: + config (obj): model configuration parameters. + + Returns: + WaveFlowModule + """ def __init__(self, config): super(WaveFlowModule, self).__init__() self.n_flows = config.n_flows @@ -321,6 +329,22 @@ class WaveFlowModule(dg.Layer): self.perms.append(perm) def forward(self, audio, mel): + """Training forward pass. + + Use a conditioner to upsample mel spectrograms into hidden states. + These hidden states along with the audio are passed to a stack of Flow + modules to obtain the final latent variable z and a list of log scaling + variables, which are then passed to the WaveFlowLoss module to calculate + the negative log likelihood. + + Args: + audio (obj): audio samples. + mel (obj): mel spectrograms. + + Returns: + z (obj): latent variable. + log_s_list (list): list of log scaling variables. + """ mel = self.conditioner(mel) assert mel.shape[2] >= audio.shape[1] # Prune out the tail of audio/mel so that time/n_group == 0. @@ -361,6 +385,20 @@ class WaveFlowModule(dg.Layer): return z, log_s_list def synthesize(self, mel, sigma=1.0): + """Use model to synthesize waveform. + + Use a conditioner to upsample mel spectrograms into hidden states. + These hidden states along with an initial random gaussian latent variable + are passed to a stack of Flow modules to obtain the audio output. + + Args: + mel (obj): mel spectrograms. + sigma (float, optional): standard deviation of the gaussian latent + variable. Defaults to 1.0. + + Returns: + audio (obj): synthesized audio. + """ if self.dtype == "float16": mel = fluid.layers.cast(mel, self.dtype) mel = self.conditioner.infer(mel)
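As a closing note on why `synthesize()` can exactly invert `forward()`: each Flow applies an affine coupling, which is invertible in closed form. A toy NumPy sketch under the simplifying assumption of element-wise scale and shift (the real Flow layers condition `log_s` and `b` on upsampled mel features and neighboring audio groups inside WaveFlowModule):

```python
import numpy as np

# Toy sketch of the affine-coupling arithmetic behind forward()/synthesize().
# Illustrative only: in the model, log_s and b are predicted by the Flow
# networks rather than sampled at random as here.
rng = np.random.default_rng(0)
x = rng.standard_normal(8).astype(np.float32)            # one "audio" group
log_s = 0.1 * rng.standard_normal(8).astype(np.float32)  # predicted log scale
b = rng.standard_normal(8).astype(np.float32)            # predicted shift

z = x * np.exp(log_s) + b                # training direction: audio -> latent
x_rec = (z - b) * np.exp(-log_s)         # synthesis direction: latent -> audio
assert np.allclose(x, x_rec, atol=1e-5)  # coupling is exactly invertible

# The summed log_s terms are the log-det-Jacobian contributions that
# WaveFlowLoss adds to the Gaussian log-density of z to form the NLL.
```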