add masking functions
parent a8192c79cc
commit f9087ea9a2
@@ -248,8 +248,8 @@ class TransformerTTS(nn.Layer):
         self.encoder = TransformerEncoder(d_model, n_heads, d_ffn, encoder_layers, dropout)
         self.decoder_prenet = DecoderPreNet(d_model, d_prenet, dropout)
         self.decoder = TransformerDecoder(d_model, n_heads, d_ffn, decoder_layers, dropout)
-        self.decoder_postnet = nn.Linear(d_model, reduction_factor * d_mel)
-        self.postnet = PostNet(d_mel, d_postnet, d_mel, postnet_kernel_size, postnet_layers)
+        self.final_proj = nn.Linear(d_model, reduction_factor * d_mel)
+        self.decoder_postnet = PostNet(d_mel, d_postnet, d_mel, postnet_kernel_size, postnet_layers)
 
     def forward(self):
         pass
@@ -0,0 +1,32 @@
+import paddle
+from paddle.fluid.layers import sequence_mask
+
+def id_mask(input, padding_index=0, dtype="bool"):
+    return paddle.cast(input != padding_index, dtype)
+
+def feature_mask(input, axis, dtype="bool"):
+    feature_sum = paddle.sum(paddle.abs(input), axis=axis, keepdim=True)
+    return paddle.cast(feature_sum != 0, dtype)
+
+def combine_mask(padding_mask, no_future_mask):
+    """
+    Combine the padding mask and no future mask for transformer decoder.
+    Padding mask is used to mask padding positions and no future mask is used
+    to prevent the decoder from seeing future information.
+
+    Args:
+        padding_mask (Tensor): shape(batch_size, time_steps), dtype: float32 or float64, decoder padding mask.
+        no_future_mask (Tensor): shape(time_steps, time_steps), dtype: float32 or float64, no future mask.
+
+    Returns:
+        Tensor: shape(batch_size, time_steps, time_steps), combined mask.
+    """
+    # TODO: to support boolean mask by using logical_and?
+    if padding_mask.dtype == paddle.fluid.core.VarDesc.VarType.BOOL:
+        return paddle.logical_and(padding_mask, no_future_mask)
+    else:
+        return padding_mask * no_future_mask
+
+def future_mask(time_steps, dtype="bool"):
+    mask = paddle.tril(paddle.ones([time_steps, time_steps]))
+    return paddle.cast(mask, dtype)
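Not part of the diff: a minimal usage sketch of the new helpers, assuming eager (dygraph) mode and that the module is importable as parakeet.modules.masking; the token ids below are made up.

    import paddle
    from parakeet.modules import masking

    # token ids with 0 as the padding index: (batch=2, time=5)
    ids = paddle.to_tensor([[5, 2, 7, 0, 0],
                            [3, 1, 4, 6, 0]])

    padding_mask = masking.id_mask(ids, dtype="float64")  # (2, 5), 1.0 at real tokens
    no_future = masking.future_mask(5, dtype="float64")   # (5, 5), lower triangular
    # (2, 1, 5) broadcasts against (5, 5) -> a (2, 5, 5) decoder self-attention mask
    decoder_mask = masking.combine_mask(padding_mask.unsqueeze(1), no_future)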
@@ -4,7 +4,7 @@ from paddle import nn
 from paddle.nn import functional as F
 
 from parakeet.modules import attention as attn
-
+from parakeet.modules.masking import combine_mask
 class PositionwiseFFN(nn.Layer):
     """
     A faithful implementation of Position-wise Feed-Forward Network
@@ -41,21 +41,6 @@ class PositionwiseFFN(nn.Layer):
         """
         return self.linear2(self.dropout(F.relu(self.linear1(x))))
 
-def combine_mask(padding_mask, no_future_mask):
-    """
-    Combine the padding mask and no future mask for transformer decoder.
-    Padding mask is used to mask padding positions and no future mask is used
-    to prevent the decoder to see future information.
-
-    Args:
-        padding_mask (Tensor): shape(batch_size, time_steps), dtype: float32 or float64, decoder padding mask.
-        no_future_mask (Tensor): shape(time_steps, time_steps), dtype: float32 or float64, no future mask.
-
-    Returns:
-        Tensor: shape(batch_size, time_steps, time_steps), combined mask.
-    """
-    # TODO: to support boolean mask by using logical_and?
-    return paddle.unsqueeze(padding_mask, 1) * no_future_mask
 
 class TransformerEncoderLayer(nn.Layer):
     """
@@ -135,7 +120,7 @@ class TransformerDecoderLayer(nn.Layer):
         """
         tq = q.shape[1]
         no_future_mask = paddle.tril(paddle.ones([tq, tq])) #(tq, tq)
-        combined_mask = combine_mask(decoder_mask, no_future_mask)
+        combined_mask = combine_mask(decoder_mask.unsqueeze(1), no_future_mask)
         context_vector, self_attn_weights = self.self_mha(q, q, q, combined_mask)
         q = self.layer_norm1(q + context_vector)
 
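Illustration only (not from the commit): the reason decoder_mask now gets unsqueeze(1) is broadcasting; a (batch, 1, time) padding mask times a (time, time) lower-triangular mask yields one (time, time) mask per batch item, mirroring the reference computation in the tests below.

    import numpy as np

    padding = np.array([[1., 1., 1., 0.]])      # (1, 4): last step is padding
    no_future = np.tril(np.ones((4, 4)))        # (4, 4): no-future (causal) mask
    combined = padding[:, None, :] * no_future  # (1, 4, 4)
    assert combined.shape == (1, 4, 4)
    assert combined[0, 1, 3] == 0               # padded and future positions stay masked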
@@ -0,0 +1,54 @@
+import unittest
+import numpy as np
+import paddle
+paddle.set_default_dtype("float64")
+
+from parakeet.modules import masking
+
+
+def sequence_mask(lengths, max_length=None, dtype="bool"):
+    max_length = max_length or np.max(lengths)
+    ids = np.arange(max_length)
+    return (ids < np.expand_dims(lengths, -1)).astype(dtype)
+
+def future_mask(lengths, max_length=None, dtype="bool"):
+    max_length = max_length or np.max(lengths)
+    return np.tril(np.tril(np.ones(max_length))).astype(dtype)
+
+class TestIDMask(unittest.TestCase):
+    def test(self):
+        ids = paddle.to_tensor(
+            [[1, 2, 3, 0, 0, 0],
+             [2, 4, 5, 6, 0, 0],
+             [7, 8, 9, 0, 0, 0]]
+        )
+        mask = masking.id_mask(ids)
+        self.assertTupleEqual(mask.numpy().shape, ids.numpy().shape)
+        print(mask.numpy())
+
+class TestFeatureMask(unittest.TestCase):
+    def test(self):
+        features = np.random.randn(3, 16, 8)
+        lengths = [16, 14, 12]
+        for i, length in enumerate(lengths):
+            features[i, length:, :] = 0
+
+        feature_tensor = paddle.to_tensor(features)
+        mask = masking.feature_mask(feature_tensor, -1)
+        self.assertTupleEqual(mask.numpy().shape, (3, 16, 1))
+        print(mask.numpy().squeeze())
+
+
+class TestCombineMask(unittest.TestCase):
+    def test_bool_mask(self):
+        lengths = np.array([12, 8, 9, 10])
+        padding_mask = sequence_mask(lengths, dtype="bool")
+        no_future_mask = future_mask(lengths, dtype="bool")
+        combined_mask1 = np.expand_dims(padding_mask, 1) * no_future_mask
+
+        print(paddle.to_tensor(padding_mask).dtype)
+        print(paddle.to_tensor(no_future_mask).dtype)
+        combined_mask2 = masking.combine_mask(
+            paddle.to_tensor(padding_mask).unsqueeze(1), paddle.to_tensor(no_future_mask)
+        )
+        np.testing.assert_allclose(combined_mask2.numpy(), combined_mask1)
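A quick check, not in the commit, of combine_mask's two code paths: boolean masks go through paddle.logical_and (what test_bool_mask above exercises), while float masks fall back to an element-wise product (what the old test_equality, removed below, exercised); the two should agree.

    import numpy as np
    import paddle
    from parakeet.modules import masking

    pad = np.array([[True, True, False]])            # (1, 3) padding mask
    fut = np.tril(np.ones((3, 3), dtype="bool"))      # (3, 3) no-future mask

    out_bool = masking.combine_mask(
        paddle.to_tensor(pad).unsqueeze(1), paddle.to_tensor(fut))
    out_float = masking.combine_mask(
        paddle.to_tensor(pad.astype("float64")).unsqueeze(1),
        paddle.to_tensor(fut.astype("float64")))
    np.testing.assert_allclose(
        out_bool.numpy().astype("float64"), out_float.numpy())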
@@ -6,15 +6,6 @@ paddle.disable_static(paddle.CPUPlace())
 
 from parakeet.modules import transformer
 
-def sequence_mask(lengths, max_length=None, dtype="bool"):
-    max_length = max_length or np.max(lengths)
-    ids = np.arange(max_length)
-    return (ids < np.expand_dims(lengths, -1)).astype(dtype)
-
-def future_mask(lengths, max_length=None, dtype="bool"):
-    max_length = max_length or np.max(lengths)
-    return np.tril(np.tril(np.ones(max_length)))
-
 class TestPositionwiseFFN(unittest.TestCase):
     def test_io(self):
         net = transformer.PositionwiseFFN(8, 12)
@@ -23,19 +14,6 @@ class TestPositionwiseFFN(unittest.TestCase):
         self.assertTupleEqual(y.numpy().shape, (2, 3, 4, 8))
 
 
-class TestCombineMask(unittest.TestCase):
-    def test_equality(self):
-        lengths = np.array([12, 8, 9, 10])
-        padding_mask = sequence_mask(lengths, dtype="float64")
-        no_future_mask = future_mask(lengths, dtype="float64")
-        combined_mask1 = np.expand_dims(padding_mask, 1) * no_future_mask
-
-        combined_mask2 = transformer.combine_mask(
-            paddle.to_tensor(padding_mask), paddle.to_tensor(no_future_mask)
-        )
-        np.testing.assert_allclose(combined_mask2.numpy(), combined_mask1)
-
-
 class TestTransformerEncoderLayer(unittest.TestCase):
     def test_io(self):
         net = transformer.TransformerEncoderLayer(64, 8, 128, 0.5)