tested io for TransformerTTS

iclementine 2020-10-15 16:49:14 +08:00
parent 40457227e6
commit 5270774bb0
3 changed files with 90 additions and 35 deletions

View File

@@ -152,22 +152,18 @@ class TransformerDecoderLayer(nn.Layer):
k (Tensor): shape(batch_size, time_steps_k, d_model), keys.
v (Tensor): shape(batch_size, time_steps_k, d_model), values.
encoder_mask (Tensor): shape(batch_size, time_steps_k) encoder padding mask.
decoder_mask (Tensor): shape(batch_size, time_steps_q) decoder padding mask.
decoder_mask (Tensor): shape(batch_size, time_steps_q, time_steps_q) or broadcastable shape, decoder padding mask.
Returns:
(q, self_attn_weights, cross_attn_weights)
q (Tensor): shape(batch_size, time_steps_q, d_model), the decoded output.
self_attn_weights (Tensor): shape(batch_size, n_heads, time_steps_q, time_steps_q), decoder self attention.
cross_attn_weights (Tensor): shape(batch_size, n_heads, time_steps_q, time_steps_k), decoder-encoder cross attention.
"""
tq = q.shape[1]
no_future_mask = paddle.tril(paddle.ones([tq, tq])) #(tq, tq)
combined_mask = masking.combine_mask(decoder_mask.unsqueeze(1), no_future_mask)
"""
# pre norm
q_in = q
q = self.layer_norm1(q)
context_vector, self_attn_weights = self.self_mha(q, q, q, combined_mask)
context_vector, self_attn_weights = self.self_mha(q, q, q, decoder_mask)
q = q_in + context_vector
# pre norm
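Note on the hunk above: the layer no longer builds the combined mask itself; callers now pass in a mask broadcastable to (batch_size, time_steps_q, time_steps_q). A minimal standalone sketch of precomputing such a mask with plain paddle ops (the helper name and the 1-keep/0-drop convention are illustrative assumptions, not the project's masking module):

import paddle

def make_decoder_mask(decoder_padding_mask):
    # decoder_padding_mask: (batch_size, time_steps_q), 1.0 for real steps, 0.0 for padding
    tq = decoder_padding_mask.shape[1]
    no_future_mask = paddle.tril(
        paddle.ones([tq, tq], dtype=decoder_padding_mask.dtype))
    # (batch_size, time_steps_q, 1) * (time_steps_q, time_steps_q)
    # broadcasts to (batch_size, time_steps_q, time_steps_q)
    return decoder_padding_mask.unsqueeze(-1) * no_future_mask

padding_mask = paddle.to_tensor([[1., 1., 1., 0.], [1., 1., 0., 0.]])
print(make_decoder_mask(padding_mask).shape)  # [2, 4, 4]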
@@ -201,21 +197,22 @@ class TransformerDecoder(nn.LayerList):
for _ in range(n_layers):
self.append(TransformerDecoderLayer(d_model, n_heads, d_ffn, dropout))
def forward(self, x, mask):
def forward(self, q, k, v, encoder_mask, decoder_mask):
self_attention_weights = []
cross_attention_weights = []
for layer in self:
x, self_attention_weights_i, cross_attention_weights_i = layer(x, mask)
q, self_attention_weights_i, cross_attention_weights_i = layer(q, k, v, encoder_mask, decoder_mask)
self_attention_weights.append(self_attention_weights_i)
cross_attention_weights.append(cross_attention_weights_i)
return x, self_attention_weights, cross_attention_weights
return q, self_attention_weights, cross_attention_weights
class DecoderPreNet(nn.Layer):
def __init__(self, d_model, d_hidden, dropout):
self.lin1 = nn.Linear(d_model, d_hidden)
class MLPPreNet(nn.Layer):
def __init__(self, d_input, d_hidden, d_output, dropout):
super(MLPPreNet, self).__init__()
self.lin1 = nn.Linear(d_input, d_hidden)
self.dropout1 = nn.Dropout(dropout)
self.lin2 = nn.Linear(d_hidden, d_model)
self.lin2 = nn.Linear(d_hidden, d_output)
self.dropout2 = nn.Dropout(dropout)
def forward(self, x):
@@ -223,10 +220,11 @@ class DecoderPreNet(nn.Layer):
return self.dropout2(F.relu(self.lin2(self.dropout1(F.relu(self.lin1(x))))))
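For readability, the one-liner above is two linear -> ReLU -> dropout stages chained together; an expanded restatement of the same computation (not a proposed change):

def forward(self, x):
    # stage 1: linear -> ReLU -> dropout
    h = self.dropout1(F.relu(self.lin1(x)))
    # stage 2: linear -> ReLU -> dropout
    return self.dropout2(F.relu(self.lin2(h)))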
class PostNet(nn.Layer):
class CNNPostNet(nn.Layer):
def __init__(self, d_input, d_hidden, d_output, kernel_size, n_layers):
super(CNNPostNet, self).__init__()
self.convs = nn.LayerList()
kernel_size = kernel_size if isinstance(tuple, kernel_size) else (kernel_size, )
kernel_size = kernel_size if isinstance(kernel_size, (tuple, list)) else (kernel_size, )
padding = (kernel_size[0] - 1, 0)
for i in range(n_layers):
c_in = d_input if i == 0 else d_hidden
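A quick standalone check of the corrected kernel_size normalization and the asymmetric padding derived from it (a sketch, not the project's code):

def normalize_kernel_size(kernel_size):
    # a tuple or list passes through; a bare int becomes a 1-tuple
    return kernel_size if isinstance(kernel_size, (tuple, list)) else (kernel_size, )

for ks in (5, (5, ), [3]):
    k = normalize_kernel_size(ks)
    padding = (k[0] - 1, 0)  # left-only padding keeps the output length at T for stride 1
    print(k, padding)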
@@ -247,15 +245,16 @@ class TransformerTTS(nn.Layer):
def __init__(self, vocab_size, padding_idx, d_model, d_mel, n_heads, d_ffn, positional_encoding_scalar,
encoder_layers, decoder_layers, d_prenet, d_postnet, postnet_layers,
postnet_kernel_size, max_reduction_factor, dropout):
super(TransformerTTS, self).__init__()
self.encoder_prenet = nn.Embedding(vocab_size, d_model, padding_idx)
self.encoder_pe = pe.positional_encoding(0, 1000, d_model) # it may be extended later
self.encoder = TransformerEncoder(d_model, n_heads, d_ffn, encoder_layers, dropout)
self.decoder_prenet = DecoderPreNet(d_model, d_prenet, dropout)
self.decoder_prenet = MLPPreNet(d_mel, d_prenet, d_model, dropout)
self.decoder_pe = pe.positional_encoding(0, 1000, d_model) # it may be extended later
self.decoder = TransformerDecoder(d_model, n_heads, d_ffn, decoder_layers, dropout)
self.final_proj = nn.Linear(d_model, max_reduction_factor * d_mel)
self.decoder_postnet = PostNet(d_mel, d_postnet, d_mel, postnet_kernel_size, postnet_layers)
self.decoder_postnet = CNNPostNet(d_mel, d_postnet, d_mel, postnet_kernel_size, postnet_layers)
# specs
self.padding_idx = padding_idx
@@ -263,30 +262,41 @@ class TransformerTTS(nn.Layer):
self.pe_scalar = positional_encoding_scalar
def forward(self, text, mel, stop):
pass
def encode(self, text):
T_enc = text.shape[-1]
embed = self.encoder_prenet(text)
pe = self.encoder_pe[:T_enc, :] # (T, C)
x = embed.scale(math.sqrt(self.d_model)) + pe.scale(self.pe_scalar)
encoder_padding_mask = masking.id_mask(text, self.padding_idx)
encoder_padding_mask = masking.id_mask(text, self.padding_idx, dtype=x.dtype)
x = F.dropout(x, training=self.training)
x, attention_weights = self.encoder(x, encoder_padding_mask)
return x, attention_weights, encoder_padding_mask
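A toy walk-through of the scale-and-add in encode() above (illustrative sizes; a random tensor stands in for the positional encoding table self.encoder_pe):

import math
import paddle

d_model, pe_scalar = 64, 0.5
embed = paddle.randn([4, 176, d_model])          # encoder prenet output
pe = paddle.randn([176, d_model])                # stand-in for encoder_pe[:T_enc, :]
x = embed * math.sqrt(d_model) + pe * pe_scalar  # pe broadcasts over the batch axis
print(x.shape)                                   # [4, 176, 64]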
def decode(self, ):
pass
def decode(self, encoder_output, input, encoder_padding_mask):
batch_size, T_dec, mel_dim = input.shape
no_future_mask = masking.future_mask(T_dec, dtype=input.dtype)
decoder_padding_mask = masking.feature_mask(input, axis=-1, dtype=input.dtype)
decoder_mask = masking.combine_mask(decoder_padding_mask.unsqueeze(-1), no_future_mask)
decoder_input = self.decoder_prenet(input)
decoder_output, _, cross_attention_weights = self.decoder(
decoder_input,
encoder_output,
encoder_output,
encoder_padding_mask,
decoder_mask)
output_proj = self.final_proj(decoder_output)
mel_intermediate = paddle.reshape(output_proj, [batch_size, -1, mel_dim])
mel_channel_first = paddle.transpose(mel_intermediate, [0, 2, 1])
mel_output = self.decoder_postnet(mel_channel_first)
mel_output = paddle.transpose(mel_output, [0, 2, 1])
return mel_output, mel_intermediate, cross_attention_weights
def infer(self):
pass
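Shape bookkeeping for the new decode() above, showing why final_proj's output is reshaped before the postnet (a sketch with illustrative numbers; r stands for max_reduction_factor):

import paddle

batch_size, T_dec, r, d_mel = 4, 50, 2, 80
output_proj = paddle.randn([batch_size, T_dec, r * d_mel])               # final_proj output
mel_intermediate = paddle.reshape(output_proj, [batch_size, -1, d_mel])  # (4, 100, 80): r frames per decoder step
mel_channel_first = paddle.transpose(mel_intermediate, [0, 2, 1])        # (4, 80, 100): channel-first for the Conv1D postnet
print(mel_intermediate.shape, mel_channel_first.shape)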

View File

@@ -5,7 +5,7 @@ def id_mask(input, padding_index=0, dtype="bool"):
return paddle.cast(input != padding_index, dtype)
def feature_mask(input, axis, dtype="bool"):
feature_sum = paddle.sum(paddle.abs(input), axis=axis, keepdim=True)
feature_sum = paddle.sum(paddle.abs(input), axis)
return paddle.cast(feature_sum != 0, dtype)
def combine_mask(padding_mask, no_future_mask):
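Side note on the feature_mask change above: without keepdim the mask comes out as (batch_size, time_steps), and the caller (decode() in transformer_tts.py) unsqueezes the last axis itself. A small standalone sketch of that behavior:

import paddle

real = paddle.randn([2, 3, 4])
pad = paddle.zeros([2, 2, 4])
mel = paddle.concat([real, pad], axis=1)         # the last two frames are all-zero padding
feature_sum = paddle.sum(paddle.abs(mel), -1)    # (2, 5), no keepdim
mask = paddle.cast(feature_sum != 0, mel.dtype)  # 1.0 for real frames, 0.0 for padding
print(mask.shape, mask.unsqueeze(-1).shape)      # [2, 5] [2, 5, 1]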

View File

@@ -5,6 +5,8 @@ paddle.set_default_dtype("float64")
paddle.disable_static(paddle.CPUPlace())
from parakeet.models import transformer_tts as tts
from parakeet.modules import masking
from pprint import pprint
class TestMultiheadAttention(unittest.TestCase):
def test_io_same_qk(self):
@@ -47,8 +49,10 @@ class TestTransformerDecoderLayer(unittest.TestCase):
v = paddle.randn([4, 24, 64])
enc_lengths = paddle.to_tensor([24, 18, 20, 22])
dec_lengths = paddle.to_tensor([32, 28, 30, 31])
enc_mask = paddle.fluid.layers.sequence_mask(enc_lengths, dtype=k.dtype)
dec_mask = paddle.fluid.layers.sequence_mask(dec_lengths, dtype=q.dtype)
enc_mask = masking.sequence_mask(enc_lengths, dtype=k.dtype)
dec_padding_mask = masking.sequence_mask(dec_lengths, dtype=q.dtype)
no_future_mask = masking.future_mask(32, dtype=q.dtype)
dec_mask = masking.combine_mask(dec_padding_mask.unsqueeze(-1), no_future_mask)
y, self_attn_weights, cross_attn_weights = net(q, k, v, enc_mask, dec_mask)
self.assertTupleEqual(y.numpy().shape, (4, 32, 64))
@@ -57,8 +61,49 @@ class TestTransformerDecoderLayer(unittest.TestCase):
class TestTransformerTTS(unittest.TestCase):
def test_io(self):
return
net = tts.TransformerTTS(vocab_size, padding_idx, d_model, d_mel, n_heads, d_ffn, positional_encoding_scalar,
encoder_layers, decoder_layers, d_prenet, d_postnet, postnet_layers,
postnet_kernel_size, max_reduction_factor, dropout)
def setUp(self):
net = tts.TransformerTTS(
128, 0, 64, 80, 4, 128,
0.5,
6, 6, 128, 128, 4,
3, 10, 0.5)
self.net = net
def test_encode_io(self):
net = self.net
text = paddle.randint(0, 128, [4, 176])
lengths = paddle.to_tensor([176, 156, 174, 168])
mask = masking.sequence_mask(lengths, dtype=text.dtype)
text = text * mask
encoded, attention_weights, encoder_mask = net.encode(text)
print(encoded.numpy().shape)
print([item.shape for item in attention_weights])
print(encoder_mask.numpy().shape)
def test_all_io(self):
net = self.net
text = paddle.randint(0, 128, [4, 176])
lengths = paddle.to_tensor([176, 156, 174, 168])
mask = masking.sequence_mask(lengths, dtype=text.dtype)
text = text * mask
mel = paddle.randn([4, 189, 80])
frames = paddle.to_tensor([189, 186, 179, 174])
mask = masking.sequence_mask(frames, dtype=frames.dtype)
mel = mel * mask.unsqueeze(-1)
encoded, encoder_attention_weights, encoder_mask = net.encode(text)
mel_output, mel_intermediate, cross_attention_weights = net.decode(encoded, mel, encoder_mask)
print("output shapes:")
print("encoder_output:", encoded.numpy().shape)
print("encoder_attentions:", [item.shape for item in encoder_attention_weights])
print("encoder_mask:", encoder_mask.numpy().shape)
print("mel_output: ", mel_output.numpy().shape)
print("mel_intermediate: ", mel_intermediate.numpy().shape)
print("decoder_attentions:", [item.shape for item in cross_attention_weights])