fix some bug of mask in fastspeech

This commit is contained in:
lifuchen 2020-04-07 09:34:27 +00:00
parent 75d464221c
commit ad4b248af8
5 changed files with 18 additions and 5 deletions

View File

@ -18,6 +18,7 @@ import argparse
from parse import add_config_options_to_parser
from pprint import pprint
from ruamel import yaml
from matplotlib import cm
import numpy as np
import paddle.fluid as fluid
import paddle.fluid.dygraph as dg
@ -64,8 +65,7 @@ def synthesis(text_input, args):
pos_text = np.arange(1, text.shape[1] + 1)
pos_text = np.expand_dims(pos_text, axis=0)
enc_non_pad_mask = get_non_pad_mask(pos_text).astype(np.float32)
enc_slf_attn_mask = get_attn_key_pad_mask(pos_text,
text).astype(np.float32)
enc_slf_attn_mask = get_attn_key_pad_mask(pos_text).astype(np.float32)
text = dg.to_variable(text)
pos_text = dg.to_variable(pos_text)
@ -101,8 +101,17 @@ def synthesis(text_input, args):
do_trim_silence=False,
sound_norm=False)
np.save('mel_output', mel_output_postnet.numpy())
mel_output_postnet = fluid.layers.transpose(
fluid.layers.squeeze(mel_output_postnet, [0]), [1, 0])
x = np.uint8(cm.viridis(mel_output_postnet.numpy()) * 255)
writer.add_image('mel_0_0', x, 0, dataformats="HWC")
ground_truth = _ljspeech_processor.load_wav(
str('/paddle/Parakeet/dataset/LJSpeech-1.1/wavs/LJ001-0175.wav'))
ground_truth = _ljspeech_processor.melspectrogram(ground_truth).astype(
np.float32)
x = np.uint8(cm.viridis(ground_truth) * 255)
writer.add_image('mel_gt_0', x, 0, dataformats="HWC")
wav = _ljspeech_processor.inv_melspectrogram(mel_output_postnet.numpy(
))
writer.add_audio(text_input, wav, 0, cfg['audio']['sr'])
@ -114,4 +123,5 @@ if __name__ == '__main__':
parser = argparse.ArgumentParser(description="Train Fastspeech model")
add_config_options_to_parser(parser)
args = parser.parse_args()
synthesis("Transformer model is so fast!", args)
synthesis("Simple as this proposition is, it is necessary to be stated,",
args)

View File

@ -4,7 +4,7 @@ python -u synthesis.py \
--use_gpu=1 \
--alpha=1.0 \
--checkpoint_path='checkpoint/' \
--fastspeech_step=71000 \
--fastspeech_step=89000 \
--log_dir='./log' \
--config_path='configs/synthesis.yaml' \

View File

@ -88,7 +88,8 @@ class Decoder(dg.Layer):
dec_slf_attn_list (list[Variable]): len(n_layers), the decoder self attention list.
"""
dec_slf_attn_list = []
slf_attn_mask = layers.expand(slf_attn_mask, [self.n_head, 1, 1])
if slf_attn_mask:
slf_attn_mask = layers.expand(slf_attn_mask, [self.n_head, 1, 1])
# -- Forward
dec_output = enc_seq + self.position_enc(enc_pos)

View File

@ -142,6 +142,7 @@ class FastSpeech(dg.Layer):
encoder_output, alpha=alpha)
slf_attn_mask = get_triu_tensor(
decoder_pos.numpy(), decoder_pos.numpy()).astype(np.float32)
slf_attn_mask = np.expand_dims(slf_attn_mask, axis=0)
slf_attn_mask = fluid.layers.cast(
dg.to_variable(slf_attn_mask == 0), np.float32)
slf_attn_mask = dg.to_variable(slf_attn_mask)

View File

@ -149,6 +149,7 @@ class Decoder(dg.Layer):
zero_mask = layers.expand(zero_mask, [self.num_head, 1, 1])
else:
mask = layers.expand(mask, [self.num_head, 1, 1])
m_mask, m_self_mask, zero_mask = None, None, None
# Decoder pre-network