From ad4b248af84ecca03510761b93e507f3032a3896 Mon Sep 17 00:00:00 2001 From: lifuchen Date: Tue, 7 Apr 2020 09:34:27 +0000 Subject: [PATCH] fix some bug of mask in fastspeech --- examples/fastspeech/synthesis.py | 16 +++++++++++++--- examples/fastspeech/synthesis.sh | 2 +- parakeet/models/fastspeech/decoder.py | 3 ++- parakeet/models/fastspeech/fastspeech.py | 1 + parakeet/models/transformer_tts/decoder.py | 1 + 5 files changed, 18 insertions(+), 5 deletions(-) diff --git a/examples/fastspeech/synthesis.py b/examples/fastspeech/synthesis.py index 774a67f..8eb0328 100644 --- a/examples/fastspeech/synthesis.py +++ b/examples/fastspeech/synthesis.py @@ -18,6 +18,7 @@ import argparse from parse import add_config_options_to_parser from pprint import pprint from ruamel import yaml +from matplotlib import cm import numpy as np import paddle.fluid as fluid import paddle.fluid.dygraph as dg @@ -64,8 +65,7 @@ def synthesis(text_input, args): pos_text = np.arange(1, text.shape[1] + 1) pos_text = np.expand_dims(pos_text, axis=0) enc_non_pad_mask = get_non_pad_mask(pos_text).astype(np.float32) - enc_slf_attn_mask = get_attn_key_pad_mask(pos_text, - text).astype(np.float32) + enc_slf_attn_mask = get_attn_key_pad_mask(pos_text).astype(np.float32) text = dg.to_variable(text) pos_text = dg.to_variable(pos_text) @@ -101,8 +101,17 @@ def synthesis(text_input, args): do_trim_silence=False, sound_norm=False) + np.save('mel_output', mel_output_postnet.numpy()) mel_output_postnet = fluid.layers.transpose( fluid.layers.squeeze(mel_output_postnet, [0]), [1, 0]) + x = np.uint8(cm.viridis(mel_output_postnet.numpy()) * 255) + writer.add_image('mel_0_0', x, 0, dataformats="HWC") + ground_truth = _ljspeech_processor.load_wav( + str('/paddle/Parakeet/dataset/LJSpeech-1.1/wavs/LJ001-0175.wav')) + ground_truth = _ljspeech_processor.melspectrogram(ground_truth).astype( + np.float32) + x = np.uint8(cm.viridis(ground_truth) * 255) + writer.add_image('mel_gt_0', x, 0, dataformats="HWC") wav = _ljspeech_processor.inv_melspectrogram(mel_output_postnet.numpy( )) writer.add_audio(text_input, wav, 0, cfg['audio']['sr']) @@ -114,4 +123,5 @@ if __name__ == '__main__': parser = argparse.ArgumentParser(description="Train Fastspeech model") add_config_options_to_parser(parser) args = parser.parse_args() - synthesis("Transformer model is so fast!", args) + synthesis("Simple as this proposition is, it is necessary to be stated,", + args) diff --git a/examples/fastspeech/synthesis.sh b/examples/fastspeech/synthesis.sh index 4daef57..c74df24 100644 --- a/examples/fastspeech/synthesis.sh +++ b/examples/fastspeech/synthesis.sh @@ -4,7 +4,7 @@ python -u synthesis.py \ --use_gpu=1 \ --alpha=1.0 \ --checkpoint_path='checkpoint/' \ ---fastspeech_step=71000 \ +--fastspeech_step=89000 \ --log_dir='./log' \ --config_path='configs/synthesis.yaml' \ diff --git a/parakeet/models/fastspeech/decoder.py b/parakeet/models/fastspeech/decoder.py index 30432d0..4c5768c 100644 --- a/parakeet/models/fastspeech/decoder.py +++ b/parakeet/models/fastspeech/decoder.py @@ -88,7 +88,8 @@ class Decoder(dg.Layer): dec_slf_attn_list (list[Variable]): len(n_layers), the decoder self attention list. """ dec_slf_attn_list = [] - slf_attn_mask = layers.expand(slf_attn_mask, [self.n_head, 1, 1]) + if slf_attn_mask: + slf_attn_mask = layers.expand(slf_attn_mask, [self.n_head, 1, 1]) # -- Forward dec_output = enc_seq + self.position_enc(enc_pos) diff --git a/parakeet/models/fastspeech/fastspeech.py b/parakeet/models/fastspeech/fastspeech.py index 96d5074..a590b56 100644 --- a/parakeet/models/fastspeech/fastspeech.py +++ b/parakeet/models/fastspeech/fastspeech.py @@ -142,6 +142,7 @@ class FastSpeech(dg.Layer): encoder_output, alpha=alpha) slf_attn_mask = get_triu_tensor( decoder_pos.numpy(), decoder_pos.numpy()).astype(np.float32) + slf_attn_mask = np.expand_dims(slf_attn_mask, axis=0) slf_attn_mask = fluid.layers.cast( dg.to_variable(slf_attn_mask == 0), np.float32) slf_attn_mask = dg.to_variable(slf_attn_mask) diff --git a/parakeet/models/transformer_tts/decoder.py b/parakeet/models/transformer_tts/decoder.py index 4275a56..c65280b 100644 --- a/parakeet/models/transformer_tts/decoder.py +++ b/parakeet/models/transformer_tts/decoder.py @@ -149,6 +149,7 @@ class Decoder(dg.Layer): zero_mask = layers.expand(zero_mask, [self.num_head, 1, 1]) else: + mask = layers.expand(mask, [self.num_head, 1, 1]) m_mask, m_self_mask, zero_mask = None, None, None # Decoder pre-network