diff --git a/examples/transformer_tts/data.py b/examples/transformer_tts/data.py
index 410df7e..42be552 100644
--- a/examples/transformer_tts/data.py
+++ b/examples/transformer_tts/data.py
@@ -141,7 +141,6 @@ def batch_examples(batch):
     texts = []
     mels = []
     mel_inputs = []
-    mel_lens = []
     text_lens = []
     pos_texts = []
     pos_mels = []
@@ -151,7 +150,6 @@ def batch_examples(batch):
             np.concatenate(
                 [np.zeros([mel.shape[0], 1], np.float32), mel[:, :-1]],
                 axis=-1))
-        mel_lens.append(mel.shape[1])
         text_lens.append(len(text))
         pos_texts.append(np.arange(1, len(text) + 1))
         pos_mels.append(np.arange(1, mel.shape[1] + 1))
@@ -174,11 +172,6 @@ def batch_examples(batch):
         for i, _ in sorted(
             zip(mel_inputs, text_lens), key=lambda x: x[1], reverse=True)
     ]
-    mel_lens = [
-        i
-        for i, _ in sorted(
-            zip(mel_lens, text_lens), key=lambda x: x[1], reverse=True)
-    ]
     pos_texts = [
         i
         for i, _ in sorted(
@@ -200,18 +193,7 @@ def batch_examples(batch):
     mel_inputs = np.transpose(
         SpecBatcher(pad_value=0.)(mel_inputs), axes=(0, 2, 1))  #(B,T,num_mels)

-    enc_slf_mask = get_attn_key_pad_mask(pos_texts).astype(np.float32)
-    enc_query_mask = get_non_pad_mask(pos_texts).astype(np.float32)
-    dec_slf_mask = get_dec_attn_key_pad_mask(pos_mels,
-                                             mel_inputs).astype(np.float32)
-    enc_dec_mask = get_attn_key_pad_mask(enc_query_mask[:, :, 0]).astype(
-        np.float32)
-    dec_query_slf_mask = get_non_pad_mask(pos_mels).astype(np.float32)
-    dec_query_mask = get_non_pad_mask(pos_mels).astype(np.float32)
-
-    return (texts, mels, mel_inputs, pos_texts, pos_mels, np.array(text_lens),
-            np.array(mel_lens), enc_slf_mask, enc_query_mask, dec_slf_mask,
-            enc_dec_mask, dec_query_slf_mask, dec_query_mask)
+    return (texts, mels, mel_inputs, pos_texts, pos_mels)


 def batch_examples_vocoder(batch):
diff --git a/examples/transformer_tts/synthesis.py b/examples/transformer_tts/synthesis.py
index 27d3340..9d2b012 100644
--- a/examples/transformer_tts/synthesis.py
+++ b/examples/transformer_tts/synthesis.py
@@ -101,15 +101,10 @@ def synthesis(text_input, args):

         pbar = tqdm(range(args.max_len))
         for i in pbar:
-            dec_slf_mask = get_triu_tensor(
-                mel_input.numpy(), mel_input.numpy()).astype(np.float32)
-            dec_slf_mask = np.expand_dims(dec_slf_mask, axis=0)
-            dec_slf_mask = fluid.layers.cast(
-                dg.to_variable(dec_slf_mask != 0), np.float32) * (-2**32 + 1)
             pos_mel = np.arange(1, mel_input.shape[1] + 1)
             pos_mel = fluid.layers.unsqueeze(dg.to_variable(pos_mel), [0])
             mel_pred, postnet_pred, attn_probs, stop_preds, attn_enc, attn_dec = model(
-                text, mel_input, pos_text, pos_mel, dec_slf_mask)
+                text, mel_input, pos_text, pos_mel)
             mel_input = fluid.layers.concat(
                 [mel_input, postnet_pred[:, -1:, :]], axis=1)

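
With the masks dropped from the batch, batch_examples returns only (texts, mels, mel_inputs, pos_texts, pos_mels), and the synthesis loop no longer rebuilds a triangular mask per step. Positions start at 1, so a 0 in pos_texts / pos_mels marks padding, which is all the network needs to reconstruct its masks. A minimal NumPy sketch of that idea (illustrative only, not part of the patch):

    import numpy as np

    # pos_texts after batching: real steps are numbered 1..T, padded steps are 0.
    pos_texts = np.array([[1, 2, 3, 4, 0, 0],
                          [1, 2, 3, 4, 5, 6]])

    # Query (non-pad) mask and additive key-padding mask, recovered from the
    # positions alone -- the same information the precomputed masks carried.
    non_pad_mask = (pos_texts != 0).astype(np.float32)[:, :, None]          # (B, T, 1)
    key_pad_mask = (pos_texts == 0).astype(np.float32)[:, None, :] * -1e30  # (B, 1, T)
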
diff --git a/examples/transformer_tts/train_transformer.py b/examples/transformer_tts/train_transformer.py
index a954c79..0bac0a7 100644
--- a/examples/transformer_tts/train_transformer.py
+++ b/examples/transformer_tts/train_transformer.py
@@ -79,7 +79,9 @@ def main(args):
                 (cfg['train']['warm_up_step'] *
                  (cfg['train']['learning_rate']**2)),
                 cfg['train']['warm_up_step']),
-            parameter_list=model.parameters())
+            parameter_list=model.parameters(),
+            grad_clip=fluid.clip.GradientClipByGlobalNorm(cfg['train'][
+                'grad_clip_thresh']))

         # Load parameters.
         global_step = io.load_parameters(
@@ -107,21 +109,12 @@ def main(args):
         pbar = tqdm(reader)
         for i, data in enumerate(pbar):
             pbar.set_description('Processing at epoch %d' % epoch)
-            character, mel, mel_input, pos_text, pos_mel, text_length, _, enc_slf_mask, enc_query_mask, dec_slf_mask, enc_dec_mask, dec_query_slf_mask, dec_query_mask = data
+            character, mel, mel_input, pos_text, pos_mel = data

             global_step += 1

             mel_pred, postnet_pred, attn_probs, stop_preds, attn_enc, attn_dec = model(
-                character,
-                mel_input,
-                pos_text,
-                pos_mel,
-                dec_slf_mask=dec_slf_mask,
-                enc_slf_mask=enc_slf_mask,
-                enc_query_mask=enc_query_mask,
-                enc_dec_mask=enc_dec_mask,
-                dec_query_slf_mask=dec_query_slf_mask,
-                dec_query_mask=dec_query_mask)
+                character, mel_input, pos_text, pos_mel)

             mel_loss = layers.mean(
                 layers.abs(layers.elementwise_sub(mel_pred, mel)))
@@ -202,10 +195,7 @@ def main(args):
                 model.apply_collective_grads()
             else:
                 loss.backward()
-            optimizer.minimize(
-                loss,
-                grad_clip=fluid.dygraph_grad_clip.GradClipByGlobalNorm(cfg[
-                    'train']['grad_clip_thresh']))
+            optimizer.minimize(loss)
             model.clear_gradients()

             # save checkpoint
diff --git a/examples/transformer_tts/train_vocoder.py b/examples/transformer_tts/train_vocoder.py
index 20e39e4..ad74149 100644
--- a/examples/transformer_tts/train_vocoder.py
+++ b/examples/transformer_tts/train_vocoder.py
@@ -74,7 +74,9 @@ def main(args):
                 (cfg['train']['warm_up_step'] *
                  (cfg['train']['learning_rate']**2)),
                 cfg['train']['warm_up_step']),
-            parameter_list=model.parameters())
+            parameter_list=model.parameters(),
+            grad_clip=fluid.clip.GradientClipByGlobalNorm(cfg['train'][
+                'grad_clip_thresh']))

         # Load parameters.
         global_step = io.load_parameters(
@@ -117,10 +119,7 @@ def main(args):
                 model.apply_collective_grads()
             else:
                 loss.backward()
-            optimizer.minimize(
-                loss,
-                grad_clip=fluid.dygraph_grad_clip.GradClipByGlobalNorm(cfg[
-                    'train']['grad_clip_thresh']))
+            optimizer.minimize(loss)
             model.clear_gradients()

             if local_rank == 0:
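
Gradient clipping moves from each optimizer.minimize(...) call (the older fluid.dygraph_grad_clip.GradClipByGlobalNorm) into the optimizer constructor through fluid.clip.GradientClipByGlobalNorm. A hedged sketch of the new pattern, using a stand-in dg.Linear model and made-up hyperparameters rather than the real TransformerTTS / Vocoder setup:

    import paddle.fluid as fluid
    import paddle.fluid.dygraph as dg

    with dg.guard():
        model = dg.Linear(8, 8)  # stand-in for TransformerTTS / Vocoder
        optimizer = fluid.optimizer.AdamOptimizer(
            learning_rate=dg.NoamDecay(1 / (4000 * (0.001**2)), 4000),
            parameter_list=model.parameters(),
            grad_clip=fluid.clip.GradientClipByGlobalNorm(1.0))
        # The training step itself is unchanged:
        #     loss.backward()
        #     optimizer.minimize(loss)
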
diff --git a/parakeet/models/transformer_tts/decoder.py b/parakeet/models/transformer_tts/decoder.py
index d33196f..41e11a0 100644
--- a/parakeet/models/transformer_tts/decoder.py
+++ b/parakeet/models/transformer_tts/decoder.py
@@ -113,15 +113,7 @@ class Decoder(dg.Layer):
             outputs_per_step=outputs_per_step,
             use_cudnn=True)

-    def forward(self,
-                key,
-                value,
-                query,
-                positional,
-                mask,
-                m_mask=None,
-                m_self_mask=None,
-                zero_mask=None):
+    def forward(self, key, value, query, positional, c_mask):
         """
         Compute decoder outputs.

@@ -132,11 +124,7 @@ class Decoder(dg.Layer):
             query (Variable): shape(B, T_mel, C), dtype float32, the input query of decoder,
                 where T_mel means the timesteps of input spectrum,
             positional (Variable): shape(B, T_mel), dtype int64, the spectrum position.
-            mask (Variable): shape(B, T_mel, T_mel), dtype int64, the mask of decoder self attention.
-            m_mask (Variable, optional): shape(B, T_mel, 1), dtype int64, the query mask of encoder-decoder attention. Defaults to None.
-            m_self_mask (Variable, optional): shape(B, T_mel, 1), dtype int64, the query mask of decoder self attention. Defaults to None.
-            zero_mask (Variable, optional): shape(B, T_mel, T_text), dtype int64, query mask of encoder-decoder attention. Defaults to None.
-
+            c_mask (Variable): shape(B, T_text, 1), dtype float32, query mask returned from encoder.
         Returns:
             mel_out (Variable): shape(B, T_mel, C), the decoder output after mel linear projection.
             out (Variable): shape(B, T_mel, C), the decoder output after post mel network.
@@ -148,15 +136,20 @@ class Decoder(dg.Layer):

         # get decoder mask with triangular matrix
         if fluid.framework._dygraph_tracer()._train_mode:
-            m_mask = layers.expand(m_mask, [self.num_head, 1, key.shape[1]])
-            m_self_mask = layers.expand(m_self_mask,
-                                        [self.num_head, 1, query.shape[1]])
-            mask = layers.expand(mask, [self.num_head, 1, 1])
-            zero_mask = layers.expand(zero_mask, [self.num_head, 1, 1])
+            mask = get_dec_attn_key_pad_mask(positional, self.num_head,
+                                             query.dtype)
+            m_mask = get_non_pad_mask(positional, self.num_head, query.dtype)
+            zero_mask = layers.cast(c_mask == 0, dtype=query.dtype) * -1e30
+            zero_mask = layers.transpose(zero_mask, perm=[0, 2, 1])

         else:
-            mask = layers.expand(mask, [self.num_head, 1, 1])
-            m_mask, m_self_mask, zero_mask = None, None, None
+            len_q = query.shape[1]
+            mask = layers.triu(
+                layers.ones(
+                    shape=[len_q, len_q], dtype=query.dtype),
+                diagonal=1)
+            mask = layers.cast(mask != 0, dtype=query.dtype) * -1e30
+            m_mask, zero_mask = None, None

         # Decoder pre-network
         query = self.decoder_prenet(query)
@@ -179,7 +172,7 @@ class Decoder(dg.Layer):
         for selfattn, attn, ffn in zip(self.selfattn_layers, self.attn_layers,
                                        self.ffns):
             query, attn_dec = selfattn(
-                query, query, query, mask=mask, query_mask=m_self_mask)
+                query, query, query, mask=mask, query_mask=m_mask)
             query, attn_dot = attn(
                 key, value, query, mask=zero_mask, query_mask=m_mask)
             query = ffn(query)
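
At inference time there is no padding, so the decoder only needs the causal part of the mask: an additive bias with -1e30 above the diagonal, rebuilt each step from the current mel_input length. A NumPy illustration of the matrix the layers.triu branch produces (illustrative only, not part of the patch):

    import numpy as np

    len_q = 4  # number of decoder steps generated so far
    causal_mask = np.triu(np.ones([len_q, len_q], dtype=np.float32), k=1) * -1e30
    # Row i adds 0 to positions <= i and -1e30 (effectively -inf before the
    # softmax) to positions > i, so each step attends only to itself and the past.
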
""" - if fluid.framework._dygraph_tracer()._train_mode: - seq_len_key = x.shape[1] - query_mask = layers.expand(query_mask, - [self.num_head, 1, seq_len_key]) - mask = layers.expand(mask, [self.num_head, 1, 1]) - else: - query_mask, mask = None, None # Encoder pre_network x = self.encoder_prenet(x) + if fluid.framework._dygraph_tracer()._train_mode: + mask = get_attn_key_pad_mask(positional, self.num_head, x.dtype) + query_mask = get_non_pad_mask(positional, self.num_head, x.dtype) + + else: + query_mask, mask = None, None + # Get positional encoding positional = self.pos_emb(positional) @@ -105,4 +103,4 @@ class Encoder(dg.Layer): x = ffn(x) attentions.append(attention) - return x, attentions + return x, attentions, query_mask diff --git a/parakeet/models/transformer_tts/transformer_tts.py b/parakeet/models/transformer_tts/transformer_tts.py index 3dcde18..e1d9418 100644 --- a/parakeet/models/transformer_tts/transformer_tts.py +++ b/parakeet/models/transformer_tts/transformer_tts.py @@ -45,17 +45,7 @@ class TransformerTTS(dg.Layer): self.decoder = Decoder(num_hidden, n_mels, outputs_per_step, decoder_num_head, decoder_n_layers) - def forward(self, - characters, - mel_input, - pos_text, - pos_mel, - dec_slf_mask, - enc_slf_mask=None, - enc_query_mask=None, - enc_dec_mask=None, - dec_query_slf_mask=None, - dec_query_mask=None): + def forward(self, characters, mel_input, pos_text, pos_mel): """ TransformerTTS network. @@ -65,13 +55,6 @@ class TransformerTTS(dg.Layer): mel_input (Variable): shape(B, T_mel, C), dtype float32, the input query of decoder, where T_mel means the timesteps of input spectrum, pos_text (Variable): shape(B, T_text), dtype int64, the characters position. - dec_slf_mask (Variable): shape(B, T_mel), dtype int64, the spectrum position. - mask (Variable): shape(B, T_mel, T_mel), dtype int64, the mask of decoder self attention. - enc_slf_mask (Variable, optional): shape(B, T_text, T_text), dtype int64, the mask of encoder self attention. Defaults to None. - enc_query_mask (Variable, optional): shape(B, T_text, 1), dtype int64, the query mask of encoder self attention. Defaults to None. - dec_query_mask (Variable, optional): shape(B, T_mel, 1), dtype int64, the query mask of encoder-decoder attention. Defaults to None. - dec_query_slf_mask (Variable, optional): shape(B, T_mel, 1), dtype int64, the query mask of decoder self attention. Defaults to None. - enc_dec_mask (Variable, optional): shape(B, T_mel, T_text), dtype int64, query mask of encoder-decoder attention. Defaults to None. Returns: mel_output (Variable): shape(B, T_mel, C), the decoder output after mel linear projection. @@ -81,16 +64,8 @@ class TransformerTTS(dg.Layer): attns_enc (list[Variable]): len(n_layers), the encoder self attention list. attns_dec (list[Variable]): len(n_layers), the decoder self attention list. 
""" - key, attns_enc = self.encoder( - characters, pos_text, mask=enc_slf_mask, query_mask=enc_query_mask) + key, attns_enc, query_mask = self.encoder(characters, pos_text) mel_output, postnet_output, attn_probs, stop_preds, attns_dec = self.decoder( - key, - key, - mel_input, - pos_mel, - mask=dec_slf_mask, - zero_mask=enc_dec_mask, - m_self_mask=dec_query_slf_mask, - m_mask=dec_query_mask) + key, key, mel_input, pos_mel, query_mask) return mel_output, postnet_output, attn_probs, stop_preds, attns_enc, attns_dec diff --git a/parakeet/models/transformer_tts/utils.py b/parakeet/models/transformer_tts/utils.py index 30c42df..3fa4c63 100644 --- a/parakeet/models/transformer_tts/utils.py +++ b/parakeet/models/transformer_tts/utils.py @@ -50,41 +50,37 @@ def get_sinusoid_encoding_table(n_position, d_hid, padding_idx=None): return sinusoid_table -def get_non_pad_mask(seq): - mask = (seq != 0).astype(np.float32) - mask = np.expand_dims(mask, axis=-1) +def get_non_pad_mask(seq, num_head, dtype): + mask = layers.cast(seq != 0, dtype=dtype) + mask = layers.unsqueeze(mask, axes=[-1]) + mask = layers.expand(mask, [num_head, 1, 1]) return mask -def get_attn_key_pad_mask(seq_k): +def get_attn_key_pad_mask(seq_k, num_head, dtype): ''' For masking out the padding part of key sequence. ''' # Expand to fit the shape of key query attention matrix. - padding_mask = (seq_k != 0).astype(np.float32) - padding_mask = np.expand_dims(padding_mask, axis=1) - padding_mask = ( - padding_mask == 0).astype(np.float32) * -1e30 #* (-2**32 + 1) + padding_mask = layers.cast(seq_k == 0, dtype=dtype) * -1e30 + padding_mask = layers.unsqueeze(padding_mask, axes=[1]) + padding_mask = layers.expand(padding_mask, [num_head, 1, 1]) return padding_mask -def get_dec_attn_key_pad_mask(seq_k, seq_q): +def get_dec_attn_key_pad_mask(seq_k, num_head, dtype): ''' For masking out the padding part of key sequence. ''' # Expand to fit the shape of key query attention matrix. - padding_mask = (seq_k == 0).astype(np.float32) - padding_mask = np.expand_dims(padding_mask, axis=1) - triu_tensor = get_triu_tensor(seq_q, seq_q) - padding_mask = padding_mask + triu_tensor - padding_mask = ( - padding_mask != 0).astype(np.float32) * -1e30 #* (-2**32 + 1) - return padding_mask - - -def get_triu_tensor(seq_k, seq_q): - ''' For make a triu tensor ''' + padding_mask = layers.cast(seq_k == 0, dtype=dtype) + padding_mask = layers.unsqueeze(padding_mask, axes=[1]) len_k = seq_k.shape[1] - len_q = seq_q.shape[1] - triu_tensor = np.triu(np.ones([len_k, len_q]), 1) - return triu_tensor + triu = layers.triu( + layers.ones( + shape=[len_k, len_k], dtype=dtype), diagonal=1) + padding_mask = padding_mask + triu + padding_mask = layers.cast( + padding_mask != 0, dtype=dtype) * -1e30 #* (-2**32 + 1) + padding_mask = layers.expand(padding_mask, [num_head, 1, 1]) + return padding_mask def guided_attention(N, T, g=0.2):