modified transformer_tts to make sure it works on paddle 1.8
parent 45c07fa42e
commit 46e254d1f8
@@ -141,7 +141,6 @@ def batch_examples(batch):
     texts = []
     mels = []
     mel_inputs = []
-    mel_lens = []
     text_lens = []
     pos_texts = []
     pos_mels = []
@@ -151,7 +150,6 @@ def batch_examples(batch):
             np.concatenate(
                 [np.zeros([mel.shape[0], 1], np.float32), mel[:, :-1]],
                 axis=-1))
-        mel_lens.append(mel.shape[1])
         text_lens.append(len(text))
         pos_texts.append(np.arange(1, len(text) + 1))
         pos_mels.append(np.arange(1, mel.shape[1] + 1))
@@ -174,11 +172,6 @@ def batch_examples(batch):
         for i, _ in sorted(
             zip(mel_inputs, text_lens), key=lambda x: x[1], reverse=True)
     ]
-    mel_lens = [
-        i
-        for i, _ in sorted(
-            zip(mel_lens, text_lens), key=lambda x: x[1], reverse=True)
-    ]
     pos_texts = [
         i
         for i, _ in sorted(
@@ -200,18 +193,7 @@ def batch_examples(batch):
     mel_inputs = np.transpose(
         SpecBatcher(pad_value=0.)(mel_inputs), axes=(0, 2, 1))  #(B,T,num_mels)

-    enc_slf_mask = get_attn_key_pad_mask(pos_texts).astype(np.float32)
-    enc_query_mask = get_non_pad_mask(pos_texts).astype(np.float32)
-    dec_slf_mask = get_dec_attn_key_pad_mask(pos_mels,
-                                             mel_inputs).astype(np.float32)
-    enc_dec_mask = get_attn_key_pad_mask(enc_query_mask[:, :, 0]).astype(
-        np.float32)
-    dec_query_slf_mask = get_non_pad_mask(pos_mels).astype(np.float32)
-    dec_query_mask = get_non_pad_mask(pos_mels).astype(np.float32)
-
-    return (texts, mels, mel_inputs, pos_texts, pos_mels, np.array(text_lens),
-            np.array(mel_lens), enc_slf_mask, enc_query_mask, dec_slf_mask,
-            enc_dec_mask, dec_query_slf_mask, dec_query_mask)
+    return (texts, mels, mel_inputs, pos_texts, pos_mels)


 def batch_examples_vocoder(batch):
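With the mask construction moved into the model, batch_examples now only pads the arrays, sorts each per-example list by descending text length, and returns the five raw inputs. The sort-by-length pattern used above can be seen in a small NumPy sketch (toy data, variable names are illustrative only):

import numpy as np

# Toy batch: three texts of different lengths (illustrative data only).
texts = [np.array([1, 2, 3]), np.array([4, 5]), np.array([6, 7, 8, 9])]
text_lens = [len(t) for t in texts]
pos_texts = [np.arange(1, len(t) + 1) for t in texts]  # 1-based positions; 0 marks padding

# Reorder each per-example list by descending text length, mirroring the
# list comprehensions in batch_examples; sorted() only compares the length keys.
texts = [
    t for t, _ in sorted(
        zip(texts, text_lens), key=lambda x: x[1], reverse=True)
]
pos_texts = [
    p for p, _ in sorted(
        zip(pos_texts, text_lens), key=lambda x: x[1], reverse=True)
]

print([len(t) for t in texts])  # [4, 3, 2]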
@@ -101,15 +101,10 @@ def synthesis(text_input, args):

         pbar = tqdm(range(args.max_len))
         for i in pbar:
-            dec_slf_mask = get_triu_tensor(
-                mel_input.numpy(), mel_input.numpy()).astype(np.float32)
-            dec_slf_mask = np.expand_dims(dec_slf_mask, axis=0)
-            dec_slf_mask = fluid.layers.cast(
-                dg.to_variable(dec_slf_mask != 0), np.float32) * (-2**32 + 1)
             pos_mel = np.arange(1, mel_input.shape[1] + 1)
             pos_mel = fluid.layers.unsqueeze(dg.to_variable(pos_mel), [0])
             mel_pred, postnet_pred, attn_probs, stop_preds, attn_enc, attn_dec = model(
-                text, mel_input, pos_text, pos_mel, dec_slf_mask)
+                text, mel_input, pos_text, pos_mel)
             mel_input = fluid.layers.concat(
                 [mel_input, postnet_pred[:, -1:, :]], axis=1)

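Since the decoder now derives its own causal mask, the inference loop in synthesis reduces to appending the last predicted frame to mel_input every step. A minimal sketch of that append, with a dummy tensor standing in for postnet_pred (shapes are illustrative):

import numpy as np
import paddle.fluid as fluid
import paddle.fluid.dygraph as dg

with dg.guard(fluid.CPUPlace()):
    # Start from one all-zero frame: shape (B=1, T=1, num_mels=80).
    mel_input = dg.to_variable(np.zeros([1, 1, 80], dtype=np.float32))
    for _ in range(3):
        # Dummy stand-in for the postnet_pred the model would return this step.
        postnet_pred = dg.to_variable(
            np.ones([1, mel_input.shape[1], 80], dtype=np.float32))
        # Append only the newest frame, exactly like the loop above.
        mel_input = fluid.layers.concat(
            [mel_input, postnet_pred[:, -1:, :]], axis=1)
    print(mel_input.shape)  # [1, 4, 80]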
@@ -79,7 +79,9 @@ def main(args):
                 (cfg['train']['warm_up_step'] *
                  (cfg['train']['learning_rate']**2)),
                 cfg['train']['warm_up_step']),
-            parameter_list=model.parameters())
+            parameter_list=model.parameters(),
+            grad_clip=fluid.clip.GradientClipByGlobalNorm(cfg['train'][
+                'grad_clip_thresh']))

         # Load parameters.
         global_step = io.load_parameters(
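The change in both training scripts follows Paddle 1.8's API: gradient clipping is configured once on the optimizer via grad_clip=fluid.clip.GradientClipByGlobalNorm(...) instead of being passed to optimizer.minimize() through the old fluid.dygraph_grad_clip.GradClipByGlobalNorm. A minimal dygraph sketch of the 1.8-style setup (toy model and clip value, not the project's config):

import numpy as np
import paddle.fluid as fluid
import paddle.fluid.dygraph as dg

with dg.guard(fluid.CPUPlace()):
    model = dg.Linear(4, 4)  # toy stand-in for the TTS model
    optimizer = fluid.optimizer.AdamOptimizer(
        learning_rate=1e-3,
        parameter_list=model.parameters(),
        # Paddle 1.8: clipping lives on the optimizer, not on minimize().
        grad_clip=fluid.clip.GradientClipByGlobalNorm(clip_norm=1.0))

    x = dg.to_variable(np.random.randn(2, 4).astype(np.float32))
    loss = fluid.layers.reduce_mean(model(x))
    loss.backward()
    optimizer.minimize(loss)  # no grad_clip argument here any more
    model.clear_gradients()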
@@ -107,21 +109,12 @@ def main(args):
         pbar = tqdm(reader)
         for i, data in enumerate(pbar):
             pbar.set_description('Processing at epoch %d' % epoch)
-            character, mel, mel_input, pos_text, pos_mel, text_length, _, enc_slf_mask, enc_query_mask, dec_slf_mask, enc_dec_mask, dec_query_slf_mask, dec_query_mask = data
+            character, mel, mel_input, pos_text, pos_mel = data

             global_step += 1

             mel_pred, postnet_pred, attn_probs, stop_preds, attn_enc, attn_dec = model(
-                character,
-                mel_input,
-                pos_text,
-                pos_mel,
-                dec_slf_mask=dec_slf_mask,
-                enc_slf_mask=enc_slf_mask,
-                enc_query_mask=enc_query_mask,
-                enc_dec_mask=enc_dec_mask,
-                dec_query_slf_mask=dec_query_slf_mask,
-                dec_query_mask=dec_query_mask)
+                character, mel_input, pos_text, pos_mel)

             mel_loss = layers.mean(
                 layers.abs(layers.elementwise_sub(mel_pred, mel)))
@@ -202,10 +195,7 @@ def main(args):
                 model.apply_collective_grads()
             else:
                 loss.backward()
-            optimizer.minimize(
-                loss,
-                grad_clip=fluid.dygraph_grad_clip.GradClipByGlobalNorm(cfg[
-                    'train']['grad_clip_thresh']))
+            optimizer.minimize(loss)
             model.clear_gradients()

             # save checkpoint
@@ -74,7 +74,9 @@ def main(args):
                 (cfg['train']['warm_up_step'] *
                  (cfg['train']['learning_rate']**2)),
                 cfg['train']['warm_up_step']),
-            parameter_list=model.parameters())
+            parameter_list=model.parameters(),
+            grad_clip=fluid.clip.GradientClipByGlobalNorm(cfg['train'][
+                'grad_clip_thresh']))

         # Load parameters.
         global_step = io.load_parameters(
@@ -117,10 +119,7 @@ def main(args):
                 model.apply_collective_grads()
             else:
                 loss.backward()
-            optimizer.minimize(
-                loss,
-                grad_clip=fluid.dygraph_grad_clip.GradClipByGlobalNorm(cfg[
-                    'train']['grad_clip_thresh']))
+            optimizer.minimize(loss)
             model.clear_gradients()

             if local_rank == 0:
@@ -113,15 +113,7 @@ class Decoder(dg.Layer):
             outputs_per_step=outputs_per_step,
             use_cudnn=True)

-    def forward(self,
-                key,
-                value,
-                query,
-                positional,
-                mask,
-                m_mask=None,
-                m_self_mask=None,
-                zero_mask=None):
+    def forward(self, key, value, query, positional, c_mask):
         """
         Compute decoder outputs.

@@ -132,11 +124,7 @@ class Decoder(dg.Layer):
             query (Variable): shape(B, T_mel, C), dtype float32, the input query of decoder,
                 where T_mel means the timesteps of input spectrum,
             positional (Variable): shape(B, T_mel), dtype int64, the spectrum position.
-            mask (Variable): shape(B, T_mel, T_mel), dtype int64, the mask of decoder self attention.
-            m_mask (Variable, optional): shape(B, T_mel, 1), dtype int64, the query mask of encoder-decoder attention. Defaults to None.
-            m_self_mask (Variable, optional): shape(B, T_mel, 1), dtype int64, the query mask of decoder self attention. Defaults to None.
-            zero_mask (Variable, optional): shape(B, T_mel, T_text), dtype int64, query mask of encoder-decoder attention. Defaults to None.
-
+            c_mask (Variable): shape(B, T_text, 1), dtype float32, query mask returned from encoder.
         Returns:
             mel_out (Variable): shape(B, T_mel, C), the decoder output after mel linear projection.
             out (Variable): shape(B, T_mel, C), the decoder output after post mel network.
@@ -148,15 +136,20 @@ class Decoder(dg.Layer):
         # get decoder mask with triangular matrix

         if fluid.framework._dygraph_tracer()._train_mode:
-            m_mask = layers.expand(m_mask, [self.num_head, 1, key.shape[1]])
-            m_self_mask = layers.expand(m_self_mask,
-                                        [self.num_head, 1, query.shape[1]])
-            mask = layers.expand(mask, [self.num_head, 1, 1])
-            zero_mask = layers.expand(zero_mask, [self.num_head, 1, 1])
+            mask = get_dec_attn_key_pad_mask(positional, self.num_head,
+                                             query.dtype)
+            m_mask = get_non_pad_mask(positional, self.num_head, query.dtype)
+            zero_mask = layers.cast(c_mask == 0, dtype=query.dtype) * -1e30
+            zero_mask = layers.transpose(zero_mask, perm=[0, 2, 1])

         else:
-            mask = layers.expand(mask, [self.num_head, 1, 1])
-            m_mask, m_self_mask, zero_mask = None, None, None
+            len_q = query.shape[1]
+            mask = layers.triu(
+                layers.ones(
+                    shape=[len_q, len_q], dtype=query.dtype),
+                diagonal=1)
+            mask = layers.cast(mask != 0, dtype=query.dtype) * -1e30
+            m_mask, zero_mask = None, None

         # Decoder pre-network
         query = self.decoder_prenet(query)
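In eval mode the decoder now builds its causal self-attention bias in-graph with layers.triu rather than taking a precomputed numpy mask from synthesis. A minimal sketch of what that construction produces (the length 4 is arbitrary):

import paddle.fluid as fluid
import paddle.fluid.dygraph as dg
from paddle.fluid import layers

with dg.guard(fluid.CPUPlace()):
    len_q = 4  # arbitrary number of decoder steps
    # Ones strictly above the main diagonal mark "future" positions...
    mask = layers.triu(
        layers.ones(shape=[len_q, len_q], dtype='float32'), diagonal=1)
    # ...and become a large negative bias added to the attention scores.
    mask = layers.cast(mask != 0, dtype='float32') * -1e30
    print(mask.numpy()[0])  # [0., -1e+30, -1e+30, -1e+30]: step 0 attends only to itself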
@@ -179,7 +172,7 @@ class Decoder(dg.Layer):
         for selfattn, attn, ffn in zip(self.selfattn_layers, self.attn_layers,
                                        self.ffns):
             query, attn_dec = selfattn(
-                query, query, query, mask=mask, query_mask=m_self_mask)
+                query, query, query, mask=mask, query_mask=m_mask)
             query, attn_dot = attn(
                 key, value, query, mask=zero_mask, query_mask=m_mask)
             query = ffn(query)
@@ -64,7 +64,7 @@ class Encoder(dg.Layer):
         for i, layer in enumerate(self.ffns):
             self.add_sublayer("ffns_{}".format(i), layer)

-    def forward(self, x, positional, mask=None, query_mask=None):
+    def forward(self, x, positional):
         """
         Encode text sequence.

@@ -72,24 +72,22 @@ class Encoder(dg.Layer):
             x (Variable): shape(B, T_text), dtype float32, the input character,
                 where T_text means the timesteps of input text,
             positional (Variable): shape(B, T_text), dtype int64, the characters position.
-            mask (Variable, optional): shape(B, T_text, T_text), dtype int64, the mask of encoder self attention. Defaults to None.
-            query_mask (Variable, optional): shape(B, T_text, 1), dtype int64, the query mask of encoder self attention. Defaults to None.

         Returns:
             x (Variable): shape(B, T_text, C), the encoder output.
             attentions (list[Variable]): len(n_layers), the encoder self attention list.
         """

-        if fluid.framework._dygraph_tracer()._train_mode:
-            seq_len_key = x.shape[1]
-            query_mask = layers.expand(query_mask,
-                                       [self.num_head, 1, seq_len_key])
-            mask = layers.expand(mask, [self.num_head, 1, 1])
-        else:
-            query_mask, mask = None, None
         # Encoder pre_network
         x = self.encoder_prenet(x)

+        if fluid.framework._dygraph_tracer()._train_mode:
+            mask = get_attn_key_pad_mask(positional, self.num_head, x.dtype)
+            query_mask = get_non_pad_mask(positional, self.num_head, x.dtype)
+
+        else:
+            query_mask, mask = None, None
+
         # Get positional encoding
         positional = self.pos_emb(positional)

@@ -105,4 +103,4 @@ class Encoder(dg.Layer):
             x = ffn(x)
             attentions.append(attention)

-        return x, attentions
+        return x, attentions, query_mask
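The query_mask returned here is what the decoder receives as c_mask; during training it becomes the encoder-decoder attention bias via cast(c_mask == 0) * -1e30 plus a transpose, mirroring the zero_mask lines in the Decoder hunk above. A small sketch of that conversion with a toy (1, 3, 1) mask whose last text position is padding:

import numpy as np
import paddle.fluid as fluid
import paddle.fluid.dygraph as dg
from paddle.fluid import layers

with dg.guard(fluid.CPUPlace()):
    # Toy encoder query mask: 1 for real characters, 0 for padding.
    c_mask = dg.to_variable(
        np.array([[[1.], [1.], [0.]]], dtype=np.float32))  # (B, T_text, 1)
    zero_mask = layers.cast(c_mask == 0, dtype='float32') * -1e30  # bias out padded keys
    zero_mask = layers.transpose(zero_mask, perm=[0, 2, 1])  # (B, 1, T_text)
    print(zero_mask.numpy())  # [[[0., 0., -1e+30]]]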
@@ -45,17 +45,7 @@ class TransformerTTS(dg.Layer):
         self.decoder = Decoder(num_hidden, n_mels, outputs_per_step,
                                decoder_num_head, decoder_n_layers)

-    def forward(self,
-                characters,
-                mel_input,
-                pos_text,
-                pos_mel,
-                dec_slf_mask,
-                enc_slf_mask=None,
-                enc_query_mask=None,
-                enc_dec_mask=None,
-                dec_query_slf_mask=None,
-                dec_query_mask=None):
+    def forward(self, characters, mel_input, pos_text, pos_mel):
         """
         TransformerTTS network.

@@ -65,13 +55,6 @@ class TransformerTTS(dg.Layer):
             mel_input (Variable): shape(B, T_mel, C), dtype float32, the input query of decoder,
                 where T_mel means the timesteps of input spectrum,
             pos_text (Variable): shape(B, T_text), dtype int64, the characters position.
-            dec_slf_mask (Variable): shape(B, T_mel), dtype int64, the spectrum position.
-            mask (Variable): shape(B, T_mel, T_mel), dtype int64, the mask of decoder self attention.
-            enc_slf_mask (Variable, optional): shape(B, T_text, T_text), dtype int64, the mask of encoder self attention. Defaults to None.
-            enc_query_mask (Variable, optional): shape(B, T_text, 1), dtype int64, the query mask of encoder self attention. Defaults to None.
-            dec_query_mask (Variable, optional): shape(B, T_mel, 1), dtype int64, the query mask of encoder-decoder attention. Defaults to None.
-            dec_query_slf_mask (Variable, optional): shape(B, T_mel, 1), dtype int64, the query mask of decoder self attention. Defaults to None.
-            enc_dec_mask (Variable, optional): shape(B, T_mel, T_text), dtype int64, query mask of encoder-decoder attention. Defaults to None.

         Returns:
             mel_output (Variable): shape(B, T_mel, C), the decoder output after mel linear projection.
@@ -81,16 +64,8 @@ class TransformerTTS(dg.Layer):
             attns_enc (list[Variable]): len(n_layers), the encoder self attention list.
             attns_dec (list[Variable]): len(n_layers), the decoder self attention list.
         """
-        key, attns_enc = self.encoder(
-            characters, pos_text, mask=enc_slf_mask, query_mask=enc_query_mask)
+        key, attns_enc, query_mask = self.encoder(characters, pos_text)

         mel_output, postnet_output, attn_probs, stop_preds, attns_dec = self.decoder(
-            key,
-            key,
-            mel_input,
-            pos_mel,
-            mask=dec_slf_mask,
-            zero_mask=enc_dec_mask,
-            m_self_mask=dec_query_slf_mask,
-            m_mask=dec_query_mask)
+            key, key, mel_input, pos_mel, query_mask)
         return mel_output, postnet_output, attn_probs, stop_preds, attns_enc, attns_dec
@@ -50,41 +50,37 @@ def get_sinusoid_encoding_table(n_position, d_hid, padding_idx=None):
     return sinusoid_table


-def get_non_pad_mask(seq):
-    mask = (seq != 0).astype(np.float32)
-    mask = np.expand_dims(mask, axis=-1)
+def get_non_pad_mask(seq, num_head, dtype):
+    mask = layers.cast(seq != 0, dtype=dtype)
+    mask = layers.unsqueeze(mask, axes=[-1])
+    mask = layers.expand(mask, [num_head, 1, 1])
     return mask


-def get_attn_key_pad_mask(seq_k):
+def get_attn_key_pad_mask(seq_k, num_head, dtype):
     ''' For masking out the padding part of key sequence. '''
     # Expand to fit the shape of key query attention matrix.
-    padding_mask = (seq_k != 0).astype(np.float32)
-    padding_mask = np.expand_dims(padding_mask, axis=1)
-    padding_mask = (
-        padding_mask == 0).astype(np.float32) * -1e30  #* (-2**32 + 1)
+    padding_mask = layers.cast(seq_k == 0, dtype=dtype) * -1e30
+    padding_mask = layers.unsqueeze(padding_mask, axes=[1])
+    padding_mask = layers.expand(padding_mask, [num_head, 1, 1])
     return padding_mask


-def get_dec_attn_key_pad_mask(seq_k, seq_q):
+def get_dec_attn_key_pad_mask(seq_k, num_head, dtype):
     ''' For masking out the padding part of key sequence. '''

     # Expand to fit the shape of key query attention matrix.
-    padding_mask = (seq_k == 0).astype(np.float32)
-    padding_mask = np.expand_dims(padding_mask, axis=1)
-    triu_tensor = get_triu_tensor(seq_q, seq_q)
-    padding_mask = padding_mask + triu_tensor
-    padding_mask = (
-        padding_mask != 0).astype(np.float32) * -1e30  #* (-2**32 + 1)
-    return padding_mask
-
-
-def get_triu_tensor(seq_k, seq_q):
-    ''' For make a triu tensor '''
+    padding_mask = layers.cast(seq_k == 0, dtype=dtype)
+    padding_mask = layers.unsqueeze(padding_mask, axes=[1])
     len_k = seq_k.shape[1]
-    len_q = seq_q.shape[1]
-    triu_tensor = np.triu(np.ones([len_k, len_q]), 1)
-    return triu_tensor
+    triu = layers.triu(
+        layers.ones(
+            shape=[len_k, len_k], dtype=dtype), diagonal=1)
+    padding_mask = padding_mask + triu
+    padding_mask = layers.cast(
+        padding_mask != 0, dtype=dtype) * -1e30  #* (-2**32 + 1)
+    padding_mask = layers.expand(padding_mask, [num_head, 1, 1])
+    return padding_mask


 def guided_attention(N, T, g=0.2):
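With these helpers now taking num_head and dtype and returning fluid Variables, the masks can be built inside forward under the dygraph tracer instead of being precomputed in numpy by the data pipeline. A quick sketch of the get_non_pad_mask-style output for a toy padded position batch with num_head=2:

import numpy as np
import paddle.fluid as fluid
import paddle.fluid.dygraph as dg
from paddle.fluid import layers

with dg.guard(fluid.CPUPlace()):
    # Positions are 1-based; 0 marks padding (as produced by batch_examples).
    pos = dg.to_variable(np.array([[1, 2, 3, 0],
                                   [1, 2, 0, 0]], dtype=np.int64))
    num_head = 2
    mask = layers.cast(pos != 0, dtype='float32')   # (B, T): 1 for real steps
    mask = layers.unsqueeze(mask, axes=[-1])        # (B, T, 1)
    mask = layers.expand(mask, [num_head, 1, 1])    # (num_head * B, T, 1)
    print(mask.shape)  # [4, 4, 1]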