From 9dad6c3d41cd7e9618b0d0edf1fd8f55103e7eaa Mon Sep 17 00:00:00 2001 From: chenfeiyu Date: Fri, 12 Jun 2020 10:13:27 +0000 Subject: [PATCH] fix synthesis for transformerTTS and FastSpeech, use int64 explicitly --- examples/fastspeech/synthesis.py | 4 ++-- examples/transformer_tts/synthesis.py | 8 +++++--- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/examples/fastspeech/synthesis.py b/examples/fastspeech/synthesis.py index de726bd..81b55c5 100644 --- a/examples/fastspeech/synthesis.py +++ b/examples/fastspeech/synthesis.py @@ -83,8 +83,8 @@ def synthesis(text_input, args): pos_text = np.arange(1, text.shape[1] + 1) pos_text = np.expand_dims(pos_text, axis=0) - text = dg.to_variable(text) - pos_text = dg.to_variable(pos_text) + text = dg.to_variable(text).astype(np.int64) + pos_text = dg.to_variable(pos_text).astype(np.int64) _, mel_output_postnet = model(text, pos_text, alpha=args.alpha) diff --git a/examples/transformer_tts/synthesis.py b/examples/transformer_tts/synthesis.py index 7d7f965..9a1b0e8 100644 --- a/examples/transformer_tts/synthesis.py +++ b/examples/transformer_tts/synthesis.py @@ -92,15 +92,17 @@ def synthesis(text_input, args): model_vocoder.eval() # init input text = np.asarray(text_to_sequence(text_input)) - text = fluid.layers.unsqueeze(dg.to_variable(text), [0]) + text = fluid.layers.unsqueeze(dg.to_variable(text).astype(np.int64), [0]) mel_input = dg.to_variable(np.zeros([1, 1, 80])).astype(np.float32) pos_text = np.arange(1, text.shape[1] + 1) - pos_text = fluid.layers.unsqueeze(dg.to_variable(pos_text), [0]) + pos_text = fluid.layers.unsqueeze( + dg.to_variable(pos_text).astype(np.int64), [0]) pbar = tqdm(range(args.max_len)) for i in pbar: pos_mel = np.arange(1, mel_input.shape[1] + 1) - pos_mel = fluid.layers.unsqueeze(dg.to_variable(pos_mel), [0]) + pos_mel = fluid.layers.unsqueeze( + dg.to_variable(pos_mel).astype(np.int64), [0]) mel_pred, postnet_pred, attn_probs, stop_preds, attn_enc, attn_dec = model( text, mel_input, pos_text, pos_mel) mel_input = fluid.layers.concat(