diff --git a/examples/fastspeech/data.py b/examples/fastspeech/data.py
index b7d5abe..da1ffec 100644
--- a/examples/fastspeech/data.py
+++ b/examples/fastspeech/data.py
@@ -186,10 +186,4 @@ def batch_examples(batch):
     mels = np.transpose(
         SpecBatcher(pad_value=0.)(mels), axes=(0, 2, 1))  #(B,T,num_mels)
 
-    enc_slf_mask = get_attn_key_pad_mask(pos_texts).astype(np.float32)
-    enc_query_mask = get_non_pad_mask(pos_texts).astype(np.float32)
-    dec_slf_mask = get_dec_attn_key_pad_mask(pos_mels, mels).astype(np.float32)
-    dec_query_slf_mask = get_non_pad_mask(pos_mels).astype(np.float32)
-
-    return (texts, mels, pos_texts, pos_mels, enc_slf_mask, enc_query_mask,
-            dec_slf_mask, dec_query_slf_mask, alignments)
+    return (texts, mels, pos_texts, pos_mels, alignments)
diff --git a/examples/fastspeech/synthesis.py b/examples/fastspeech/synthesis.py
index 781bbcb..6039882 100644
--- a/examples/fastspeech/synthesis.py
+++ b/examples/fastspeech/synthesis.py
@@ -28,7 +28,7 @@ from parakeet.models.fastspeech.fastspeech import FastSpeech
 from parakeet.models.transformer_tts.utils import *
 from parakeet.models.wavenet import WaveNet, UpsampleNet
 from parakeet.models.clarinet import STFT, Clarinet, ParallelWaveNet
-from parakeet.utils.layer_tools import summary, freeze
+from parakeet.utils.layer_tools import freeze
 from parakeet.utils import io
 
 
@@ -82,22 +82,11 @@ def synthesis(text_input, args):
         text = np.expand_dims(text, axis=0)
         pos_text = np.arange(1, text.shape[1] + 1)
         pos_text = np.expand_dims(pos_text, axis=0)
-        enc_non_pad_mask = get_non_pad_mask(pos_text).astype(np.float32)
-        enc_slf_attn_mask = get_attn_key_pad_mask(pos_text).astype(np.float32)
 
         text = dg.to_variable(text)
         pos_text = dg.to_variable(pos_text)
-        enc_non_pad_mask = dg.to_variable(enc_non_pad_mask)
-        enc_slf_attn_mask = dg.to_variable(enc_slf_attn_mask)
 
-        _, mel_output_postnet = model(
-            text,
-            pos_text,
-            alpha=args.alpha,
-            enc_non_pad_mask=enc_non_pad_mask,
-            enc_slf_attn_mask=enc_slf_attn_mask,
-            dec_non_pad_mask=None,
-            dec_slf_attn_mask=None)
+        _, mel_output_postnet = model(text, pos_text, alpha=args.alpha)
 
         result = np.exp(mel_output_postnet.numpy())
         mel_output_postnet = fluid.layers.transpose(
@@ -186,7 +175,6 @@ def synthesis_with_clarinet(config_path, checkpoint, mel_spectrogram, place):
         lmd = config["loss"]["lmd"]
         model = Clarinet(upsample_net, teacher, student, stft,
                          student_log_scale_min, lmd)
-        summary(model)
         io.load_parameters(model=model, checkpoint_path=checkpoint)
 
         if not os.path.exists(args.output):
diff --git a/examples/fastspeech/train.py b/examples/fastspeech/train.py
index 21e8ee9..285063f 100644
--- a/examples/fastspeech/train.py
+++ b/examples/fastspeech/train.py
@@ -79,7 +79,9 @@ def main(args):
                                        (cfg['train']['warm_up_step'] *
                                         (cfg['train']['learning_rate']**2)),
                                        cfg['train']['warm_up_step']),
-            parameter_list=model.parameters())
+            parameter_list=model.parameters(),
+            grad_clip=fluid.clip.GradientClipByGlobalNorm(cfg['train'][
+                'grad_clip_thresh']))
         reader = LJSpeechLoader(
             cfg['audio'],
             place,
@@ -108,9 +110,7 @@ def main(args):
 
             for i, data in enumerate(pbar):
                 pbar.set_description('Processing at epoch %d' % epoch)
-                (character, mel, pos_text, pos_mel, enc_slf_mask,
-                 enc_query_mask, dec_slf_mask, dec_query_slf_mask,
-                 alignment) = data
+                (character, mel, pos_text, pos_mel, alignment) = data
 
                 global_step += 1
 
@@ -119,11 +119,7 @@ def main(args):
                     character,
                     pos_text,
                     mel_pos=pos_mel,
-                    length_target=alignment,
-                    enc_non_pad_mask=enc_query_mask,
-                    enc_slf_attn_mask=enc_slf_mask,
-                    dec_non_pad_mask=dec_query_slf_mask,
-                    dec_slf_attn_mask=dec_slf_mask)
+                    length_target=alignment)
                 mel_output, mel_output_postnet, duration_predictor_output, _, _ = result
                 mel_loss = layers.mse_loss(mel_output, mel)
                 mel_postnet_loss = layers.mse_loss(mel_output_postnet, mel)
@@ -150,10 +146,7 @@ def main(args):
                     model.apply_collective_grads()
                 else:
                     total_loss.backward()
-                optimizer.minimize(
-                    total_loss,
-                    grad_clip=fluid.dygraph_grad_clip.GradClipByGlobalNorm(cfg[
-                        'train']['grad_clip_thresh']))
+                optimizer.minimize(total_loss)
                 model.clear_gradients()
 
                 # save checkpoint
diff --git a/parakeet/models/fastspeech/decoder.py b/parakeet/models/fastspeech/decoder.py
index 397685d..78dae16 100644
--- a/parakeet/models/fastspeech/decoder.py
+++ b/parakeet/models/fastspeech/decoder.py
@@ -70,7 +70,7 @@ class Decoder(dg.Layer):
         for i, layer in enumerate(self.layer_stack):
             self.add_sublayer('fft_{}'.format(i), layer)
 
-    def forward(self, enc_seq, enc_pos, non_pad_mask, slf_attn_mask=None):
+    def forward(self, enc_seq, enc_pos):
         """
         Compute decoder outputs.
         
@@ -79,17 +79,26 @@ class Decoder(dg.Layer):
                 the output of length regulator, where T_mel means the timesteps of input spectrum.
             enc_pos (Variable): shape(B, T_mel), dtype int64, 
                 the spectrum position.
-            non_pad_mask (Variable): shape(B, T_mel, 1), dtype int64, the mask with non pad.
-            slf_attn_mask (Variable, optional): shape(B, T_mel, T_mel), dtype int64, 
-                the mask of mel spectrum. Defaults to None.
 
         Returns:
             dec_output (Variable): shape(B, T_mel, C), the decoder output.
             dec_slf_attn_list (list[Variable]): len(n_layers), the decoder self attention list.
         """
         dec_slf_attn_list = []
-        if slf_attn_mask:
-            slf_attn_mask = layers.expand(slf_attn_mask, [self.n_head, 1, 1])
+        if fluid.framework._dygraph_tracer()._train_mode:
+            slf_attn_mask = get_dec_attn_key_pad_mask(enc_pos, self.n_head,
+                                                      enc_seq.dtype)
+
+        else:
+            len_q = enc_seq.shape[1]
+            slf_attn_mask = layers.triu(
+                layers.ones(
+                    shape=[len_q, len_q], dtype=enc_seq.dtype),
+                diagonal=1)
+            slf_attn_mask = layers.cast(
+                slf_attn_mask != 0, dtype=enc_seq.dtype) * -1e30
+
+        non_pad_mask = get_non_pad_mask(enc_pos, 1, enc_seq.dtype)
 
         # -- Forward
         dec_output = enc_seq + self.position_enc(enc_pos)
diff --git a/parakeet/models/fastspeech/encoder.py b/parakeet/models/fastspeech/encoder.py
index d39fdc1..97ea75e 100644
--- a/parakeet/models/fastspeech/encoder.py
+++ b/parakeet/models/fastspeech/encoder.py
@@ -76,7 +76,7 @@ class Encoder(dg.Layer):
         for i, layer in enumerate(self.layer_stack):
             self.add_sublayer('fft_{}'.format(i), layer)
 
-    def forward(self, character, text_pos, non_pad_mask, slf_attn_mask=None):
+    def forward(self, character, text_pos):
         """
         Encode text sequence.
 
@@ -84,22 +84,21 @@ class Encoder(dg.Layer):
             character (Variable): shape(B, T_text), dtype float32, the input text characters, 
                 where T_text means the timesteps of input characters,
             text_pos (Variable): shape(B, T_text), dtype int64, the input text position. 
-            non_pad_mask (Variable): shape(B, T_text, 1), dtype int64, the mask with non pad.
-            slf_attn_mask (Variable, optional): shape(B, T_text, T_text), dtype int64, 
-                the mask of input characters. Defaults to None.
         
         Returns:
             enc_output (Variable): shape(B, T_text, C), the encoder output. 
-            non_pad_mask (Variable): shape(B, T_text, 1), the mask with non pad.
             enc_slf_attn_list (list[Variable]): len(n_layers), the encoder self attention list.
         """
         enc_slf_attn_list = []
-        slf_attn_mask = layers.expand(slf_attn_mask, [self.n_head, 1, 1])
 
         # -- Forward
         enc_output = self.src_word_emb(character) + self.position_enc(
             text_pos)  #(N, T, C)
 
+        slf_attn_mask = get_attn_key_pad_mask(text_pos, self.n_head,
+                                              enc_output.dtype)
+        non_pad_mask = get_non_pad_mask(text_pos, 1, enc_output.dtype)
+
         for enc_layer in self.layer_stack:
             enc_output, enc_slf_attn = enc_layer(
                 enc_output,
diff --git a/parakeet/models/fastspeech/fastspeech.py b/parakeet/models/fastspeech/fastspeech.py
index 42e9c67..db2fca5 100644
--- a/parakeet/models/fastspeech/fastspeech.py
+++ b/parakeet/models/fastspeech/fastspeech.py
@@ -86,11 +86,7 @@ class FastSpeech(dg.Layer):
     def forward(self,
                 character,
                 text_pos,
-                enc_non_pad_mask,
-                dec_non_pad_mask,
                 mel_pos=None,
-                enc_slf_attn_mask=None,
-                dec_slf_attn_mask=None,
                 length_target=None,
                 alpha=1.0):
         """
@@ -102,12 +98,6 @@ class FastSpeech(dg.Layer):
             text_pos (Variable): shape(B, T_text), dtype int64, the input text position. 
             mel_pos (Variable, optional): shape(B, T_mel), dtype int64, the spectrum position, 
                 where T_mel means the timesteps of input spectrum,  
-            enc_non_pad_mask (Variable): shape(B, T_text, 1), dtype int64, the mask with non pad.
-            dec_non_pad_mask (Variable): shape(B, T_mel, 1), dtype int64, the mask with non pad.
-            enc_slf_attn_mask (Variable, optional): shape(B, T_text, T_text), dtype int64, 
-                the mask of input characters. Defaults to None.
-            slf_attn_mask (Variable, optional): shape(B, T_mel, T_mel), dtype int64,
-                the mask of mel spectrum. Defaults to None.
             length_target (Variable, optional): shape(B, T_text), dtype int64, 
                 the duration of phoneme compute from pretrained transformerTTS. Defaults to None. 
             alpha (float32, optional): The hyperparameter to determine the length of the expanded sequence 
@@ -121,19 +111,12 @@ class FastSpeech(dg.Layer):
             dec_slf_attn_list (List[Variable]): len(dec_n_layers), the decoder self attention list.
         """
 
-        encoder_output, enc_slf_attn_list = self.encoder(
-            character,
-            text_pos,
-            enc_non_pad_mask,
-            slf_attn_mask=enc_slf_attn_mask)
+        encoder_output, enc_slf_attn_list = self.encoder(character, text_pos)
         if fluid.framework._dygraph_tracer()._train_mode:
             length_regulator_output, duration_predictor_output = self.length_regulator(
                 encoder_output, target=length_target, alpha=alpha)
             decoder_output, dec_slf_attn_list = self.decoder(
-                length_regulator_output,
-                mel_pos,
-                dec_non_pad_mask,
-                slf_attn_mask=dec_slf_attn_mask)
+                length_regulator_output, mel_pos)
 
             mel_output = self.mel_linear(decoder_output)
             mel_output_postnet = self.postnet(mel_output) + mel_output
@@ -142,19 +125,8 @@ class FastSpeech(dg.Layer):
         else:
             length_regulator_output, decoder_pos = self.length_regulator(
                 encoder_output, alpha=alpha)
-            slf_attn_mask = get_triu_tensor(
-                decoder_pos.numpy(), decoder_pos.numpy()).astype(np.float32)
-            slf_attn_mask = np.expand_dims(slf_attn_mask, axis=0)
-            slf_attn_mask = fluid.layers.cast(
-                dg.to_variable(slf_attn_mask != 0), np.float32) * (-2**32 + 1)
-            slf_attn_mask = dg.to_variable(slf_attn_mask)
-            dec_non_pad_mask = fluid.layers.unsqueeze(
-                (decoder_pos != 0).astype(np.float32), [-1])
-            decoder_output, _ = self.decoder(
-                length_regulator_output,
-                decoder_pos,
-                dec_non_pad_mask,
-                slf_attn_mask=slf_attn_mask)
+            decoder_output, _ = self.decoder(length_regulator_output,
+                                             decoder_pos)
             mel_output = self.mel_linear(decoder_output)
             mel_output_postnet = self.postnet(mel_output) + mel_output