add masking functions

iclementine 2020-10-13 15:20:37 +08:00
parent a8192c79cc
commit f9087ea9a2
5 changed files with 92 additions and 43 deletions


@@ -207,8 +207,8 @@ class TransformerDecoder(nn.LayerList):
            self_attention_weights.append(self_attention_weights_i)
            cross_attention_weights.append(cross_attention_weights_i)
        return x, self_attention_weights, cross_attention_weights

class DecoderPreNet(nn.Layer):
    def __init__(self, d_model, d_hidden, dropout):
        self.lin1 = nn.Linear(d_model, d_hidden)
@@ -248,8 +248,8 @@ class TransformerTTS(nn.Layer):
        self.encoder = TransformerEncoder(d_model, n_heads, d_ffn, encoder_layers, dropout)
        self.decoder_prenet = DecoderPreNet(d_model, d_prenet, dropout)
        self.decoder = TransformerDecoder(d_model, n_heads, d_ffn, decoder_layers, dropout)
-       self.decoder_postnet = nn.Linear(d_model, reduction_factor * d_mel)
-       self.postnet = PostNet(d_mel, d_postnet, d_mel, postnet_kernel_size, postnet_layers)
+       self.final_proj = nn.Linear(d_model, reduction_factor * d_mel)
+       self.decoder_postnet = PostNet(d_mel, d_postnet, d_mel, postnet_kernel_size, postnet_layers)

    def forward(self):
        pass


@@ -0,0 +1,32 @@
import paddle
from paddle.fluid.layers import sequence_mask


def id_mask(input, padding_index=0, dtype="bool"):
    # mask out positions whose id equals the padding index
    return paddle.cast(input != padding_index, dtype)


def feature_mask(input, axis, dtype="bool"):
    # a position is treated as padding when its features are all zero along `axis`
    feature_sum = paddle.sum(paddle.abs(input), axis=axis, keepdim=True)
    return paddle.cast(feature_sum != 0, dtype)


def combine_mask(padding_mask, no_future_mask):
    """
    Combine the padding mask and the no-future mask for a transformer decoder.
    The padding mask masks out padding positions, and the no-future mask
    prevents the decoder from seeing future information.

    Args:
        padding_mask (Tensor): shape(batch_size, time_steps), dtype: float32, float64 or bool, decoder padding mask.
        no_future_mask (Tensor): shape(time_steps, time_steps), dtype: float32, float64 or bool, no-future mask.

    Returns:
        Tensor: shape(batch_size, time_steps, time_steps), the combined mask.
    """
    # TODO: support boolean masks by using logical_and?
    if padding_mask.dtype == paddle.fluid.core.VarDesc.VarType.BOOL:
        return paddle.logical_and(padding_mask, no_future_mask)
    else:
        return padding_mask * no_future_mask


def future_mask(time_steps, dtype="bool"):
    # lower-triangular causal mask: position i may only attend to positions <= i
    mask = paddle.tril(paddle.ones([time_steps, time_steps]))
    return paddle.cast(mask, dtype)
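For reference, a minimal usage sketch of these helpers (illustrative only, assuming the parakeet.modules.masking module path that the transformer module below imports from):

import paddle

from parakeet.modules.masking import combine_mask, future_mask, id_mask

# token ids, with 0 as the padding index
ids = paddle.to_tensor([[1, 2, 3, 0, 0],
                        [4, 5, 6, 7, 0]])
padding_mask = id_mask(ids, padding_index=0, dtype="float32")  # (2, 5)

# lower-triangular (no-future) mask for 5 decoder steps
causal_mask = future_mask(5, dtype="float32")  # (5, 5)

# add a query-step axis so the padding mask broadcasts against the causal mask
combined = combine_mask(padding_mask.unsqueeze(1), causal_mask)
print(combined.shape)  # [2, 5, 5]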


@@ -4,7 +4,7 @@ from paddle import nn
from paddle.nn import functional as F

from parakeet.modules import attention as attn
+from parakeet.modules.masking import combine_mask

class PositionwiseFFN(nn.Layer):
    """
    A faithful implementation of Position-wise Feed-Forward Network
@@ -41,21 +41,6 @@ class PositionwiseFFN(nn.Layer):
        """
        return self.linear2(self.dropout(F.relu(self.linear1(x))))

-def combine_mask(padding_mask, no_future_mask):
-    """
-    Combine the padding mask and no future mask for transformer decoder.
-    Padding mask is used to mask padding positions and no future mask is used
-    to prevent the decoder to see future information.
-    Args:
-        padding_mask (Tensor): shape(batch_size, time_steps), dtype: float32 or float64, decoder padding mask.
-        no_future_mask (Tensor): shape(time_steps, time_steps), dtype: float32 or float64, no future mask.
-    Returns:
-        Tensor: shape(batch_size, time_steps, time_steps), combined mask.
-    """
-    # TODO: to support boolean mask by using logical_and?
-    return paddle.unsqueeze(padding_mask, 1) * no_future_mask

class TransformerEncoderLayer(nn.Layer):
    """
@@ -135,7 +120,7 @@ class TransformerDecoderLayer(nn.Layer):
        """
        tq = q.shape[1]
        no_future_mask = paddle.tril(paddle.ones([tq, tq]))  # (tq, tq)
-       combined_mask = combine_mask(decoder_mask, no_future_mask)
+       combined_mask = combine_mask(decoder_mask.unsqueeze(1), no_future_mask)
        context_vector, self_attn_weights = self.self_mha(q, q, q, combined_mask)
        q = self.layer_norm1(q + context_vector)
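The call-site change works because combine_mask no longer unsqueezes internally; the caller now supplies a (batch_size, 1, time_steps) padding mask. A shape-only sketch of the broadcasting involved, with made-up sizes:

import paddle

from parakeet.modules.masking import combine_mask

batch_size, time_steps = 4, 6  # made-up sizes
decoder_mask = paddle.ones([batch_size, time_steps], dtype="float32")  # padding mask, (batch, time)
no_future_mask = paddle.tril(paddle.ones([time_steps, time_steps]))    # causal mask, (time, time)

# unsqueeze(1) inserts the query-step axis, so the element-wise combination
# broadcasts into a per-example (time, time) attention mask
combined_mask = combine_mask(decoder_mask.unsqueeze(1), no_future_mask)
print(combined_mask.shape)  # [4, 6, 6]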

tests/test_masking.py (new file, 54 lines added)

@@ -0,0 +1,54 @@
import unittest

import numpy as np
import paddle
paddle.set_default_dtype("float64")

from parakeet.modules import masking


def sequence_mask(lengths, max_length=None, dtype="bool"):
    # numpy reference: True for valid positions, False for padding
    max_length = max_length or np.max(lengths)
    ids = np.arange(max_length)
    return (ids < np.expand_dims(lengths, -1)).astype(dtype)


def future_mask(lengths, max_length=None, dtype="bool"):
    # numpy reference: lower-triangular (no-future) mask
    max_length = max_length or np.max(lengths)
    return np.tril(np.ones((max_length, max_length))).astype(dtype)


class TestIDMask(unittest.TestCase):
    def test(self):
        ids = paddle.to_tensor(
            [[1, 2, 3, 0, 0, 0],
             [2, 4, 5, 6, 0, 0],
             [7, 8, 9, 0, 0, 0]]
        )
        mask = masking.id_mask(ids)
        self.assertTupleEqual(mask.numpy().shape, ids.numpy().shape)
        print(mask.numpy())


class TestFeatureMask(unittest.TestCase):
    def test(self):
        features = np.random.randn(3, 16, 8)
        lengths = [16, 14, 12]
        for i, length in enumerate(lengths):
            features[i, length:, :] = 0

        feature_tensor = paddle.to_tensor(features)
        mask = masking.feature_mask(feature_tensor, -1)
        self.assertTupleEqual(mask.numpy().shape, (3, 16, 1))
        print(mask.numpy().squeeze())


class TestCombineMask(unittest.TestCase):
    def test_bool_mask(self):
        lengths = np.array([12, 8, 9, 10])
        padding_mask = sequence_mask(lengths, dtype="bool")
        no_future_mask = future_mask(lengths, dtype="bool")
        combined_mask1 = np.expand_dims(padding_mask, 1) * no_future_mask

        print(paddle.to_tensor(padding_mask).dtype)
        print(paddle.to_tensor(no_future_mask).dtype)
        combined_mask2 = masking.combine_mask(
            paddle.to_tensor(padding_mask).unsqueeze(1), paddle.to_tensor(no_future_mask)
        )
        np.testing.assert_allclose(combined_mask2.numpy(), combined_mask1)
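A further check one could add here (hypothetical, not in the test file above): the paddle helper masking.future_mask should agree with a plain numpy lower-triangular reference.

import numpy as np
import paddle

from parakeet.modules import masking

time_steps = 6
np_reference = np.tril(np.ones((time_steps, time_steps)))
paddle_mask = masking.future_mask(time_steps, dtype="float64")
np.testing.assert_allclose(paddle_mask.numpy(), np_reference)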


@@ -6,15 +6,6 @@ paddle.disable_static(paddle.CPUPlace())
from parakeet.modules import transformer

-def sequence_mask(lengths, max_length=None, dtype="bool"):
-    max_length = max_length or np.max(lengths)
-    ids = np.arange(max_length)
-    return (ids < np.expand_dims(lengths, -1)).astype(dtype)
-def future_mask(lengths, max_length=None, dtype="bool"):
-    max_length = max_length or np.max(lengths)
-    return np.tril(np.tril(np.ones(max_length)))

class TestPositionwiseFFN(unittest.TestCase):
    def test_io(self):
        net = transformer.PositionwiseFFN(8, 12)
@@ -23,19 +14,6 @@ class TestPositionwiseFFN(unittest.TestCase):
        self.assertTupleEqual(y.numpy().shape, (2, 3, 4, 8))

-class TestCombineMask(unittest.TestCase):
-    def test_equality(self):
-        lengths = np.array([12, 8, 9, 10])
-        padding_mask = sequence_mask(lengths, dtype="float64")
-        no_future_mask = future_mask(lengths, dtype="float64")
-        combined_mask1 = np.expand_dims(padding_mask, 1) * no_future_mask
-        combined_mask2 = transformer.combine_mask(
-            paddle.to_tensor(padding_mask), paddle.to_tensor(no_future_mask)
-        )
-        np.testing.assert_allclose(combined_mask2.numpy(), combined_mask1)

class TestTransformerEncoderLayer(unittest.TestCase):
    def test_io(self):
        net = transformer.TransformerEncoderLayer(64, 8, 128, 0.5)