tested io for TransformerTTS
parent 40457227e6
commit 5270774bb0
@@ -152,22 +152,18 @@ class TransformerDecoderLayer(nn.Layer):
             k (Tensor): shape(batch_size, time_steps_k, d_model), keys.
             v (Tensor): shape(batch_size, time_steps_k, d_model), values.
             encoder_mask (Tensor): shape(batch_size, time_steps_k), encoder padding mask.
-            decoder_mask (Tensor): shape(batch_size, time_steps_q), decoder padding mask.
+            decoder_mask (Tensor): shape(batch_size, time_steps_q, time_steps_q) or broadcastable shape, decoder padding mask.
 
         Returns:
             (q, self_attn_weights, cross_attn_weights)
             q (Tensor): shape(batch_size, time_steps_q, d_model), the decoded output.
             self_attn_weights (Tensor): shape(batch_size, n_heads, time_steps_q, time_steps_q), decoder self attention.
             cross_attn_weights (Tensor): shape(batch_size, n_heads, time_steps_q, time_steps_k), decoder-encoder cross attention.
         """
-        tq = q.shape[1]
-        no_future_mask = paddle.tril(paddle.ones([tq, tq]))  # (tq, tq)
-        combined_mask = masking.combine_mask(decoder_mask.unsqueeze(1), no_future_mask)
-
         # pre norm
         q_in = q
         q = self.layer_norm1(q)
-        context_vector, self_attn_weights = self.self_mha(q, q, q, combined_mask)
+        context_vector, self_attn_weights = self.self_mha(q, q, q, decoder_mask)
        q = q_in + context_vector
 
         # pre norm
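The no-future mask is no longer built inside the layer; the caller is now expected to pass a combined (batch, T, T) decoder mask. Below is a minimal sketch of what that mask looks like, assuming `combine_mask` is an elementwise AND of the padding mask and the lower-triangular no-future mask (its body is not shown in this diff):

```python
import paddle

tq = 6  # padded decoder length
# padding mask: (batch, T), 1 for valid frames, 0 for padding (hand-written here)
padding_mask = paddle.to_tensor([[1., 1., 1., 1., 1., 0.],
                                 [1., 1., 1., 0., 0., 0.]])
# no-future mask: (T, T) lower-triangular matrix
no_future_mask = paddle.tril(paddle.ones([tq, tq], dtype="float32"))
# combined mask: (batch, T, T), broadcastable against (batch, n_heads, T, T) attention scores
decoder_mask = padding_mask.unsqueeze(-1) * no_future_mask
print(decoder_mask.shape)  # [2, 6, 6]
```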
@@ -201,21 +197,22 @@ class TransformerDecoder(nn.LayerList):
         for _ in range(n_layers):
             self.append(TransformerDecoderLayer(d_model, n_heads, d_ffn, dropout))
 
-    def forward(self, x, mask):
+    def forward(self, q, k, v, encoder_mask, decoder_mask):
         self_attention_weights = []
         cross_attention_weights = []
         for layer in self:
-            x, self_attention_weights_i, cross_attention_weights_i = layer(x, mask)
+            q, self_attention_weights_i, cross_attention_weights_i = layer(q, k, v, encoder_mask, decoder_mask)
             self_attention_weights.append(self_attention_weights_i)
             cross_attention_weights.append(cross_attention_weights_i)
-        return x, self_attention_weights, cross_attention_weights
+        return q, self_attention_weights, cross_attention_weights
 
 
-class DecoderPreNet(nn.Layer):
-    def __init__(self, d_model, d_hidden, dropout):
-        self.lin1 = nn.Linear(d_model, d_hidden)
+class MLPPreNet(nn.Layer):
+    def __init__(self, d_input, d_hidden, d_output, dropout):
+        super(MLPPreNet, self).__init__()
+        self.lin1 = nn.Linear(d_input, d_hidden)
         self.dropout1 = nn.Dropout(dropout)
-        self.lin2 = nn.Linear(d_hidden, d_model)
+        self.lin2 = nn.Linear(d_hidden, d_output)
         self.dropout2 = nn.Dropout(dropout)
 
     def forward(self, x):
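Since the decoder prenet now consumes mel frames rather than model-sized features, here is a hedged usage sketch of the renamed MLPPreNet, with hyperparameters borrowed from the new test's setUp:

```python
import paddle
from parakeet.models import transformer_tts as tts

# MLPPreNet(d_input, d_hidden, d_output, dropout): two linear layers with ReLU and dropout
prenet = tts.MLPPreNet(80, 128, 64, 0.5)   # d_mel=80 -> d_prenet=128 -> d_model=64
mel = paddle.randn([4, 189, 80])           # (batch, frames, d_mel)
out = prenet(mel)                          # expected shape: (4, 189, 64)
print(out.numpy().shape)
```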
@@ -223,10 +220,11 @@ class DecoderPreNet(nn.Layer):
         return self.dropout2(F.relu(self.lin2(self.dropout1(F.relu(self.lin1(x))))))
 
 
-class PostNet(nn.Layer):
+class CNNPostNet(nn.Layer):
     def __init__(self, d_input, d_hidden, d_output, kernel_size, n_layers):
+        super(CNNPostNet, self).__init__()
         self.convs = nn.LayerList()
-        kernel_size = kernel_size if isinstance(tuple, kernel_size) else (kernel_size, )
+        kernel_size = kernel_size if isinstance(kernel_size, (tuple, list)) else (kernel_size, )
         padding = (kernel_size[0] - 1, 0)
         for i in range(n_layers):
             c_in = d_input if i == 0 else d_hidden
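The old isinstance call had its arguments reversed and would raise a TypeError for any integer kernel_size. A standalone illustration of the corrected normalization (plain Python, not part of the commit; the helper name is only for illustration):

```python
def normalize_kernel_size(kernel_size):
    # corrected check: accept an int, a tuple, or a list
    return kernel_size if isinstance(kernel_size, (tuple, list)) else (kernel_size, )

print(normalize_kernel_size(3))              # (3,)
print(normalize_kernel_size([5]))            # [5]
print((normalize_kernel_size(3)[0] - 1, 0))  # (2, 0), the padding computed in __init__
```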
@@ -247,15 +245,16 @@ class TransformerTTS(nn.Layer):
     def __init__(self, vocab_size, padding_idx, d_model, d_mel, n_heads, d_ffn, positional_encoding_scalar,
                  encoder_layers, decoder_layers, d_prenet, d_postnet, postnet_layers,
                  postnet_kernel_size, max_reduction_factor, dropout):
         super(TransformerTTS, self).__init__()
         self.encoder_prenet = nn.Embedding(vocab_size, d_model, padding_idx)
         self.encoder_pe = pe.positional_encoding(0, 1000, d_model)  # it may be extended later
         self.encoder = TransformerEncoder(d_model, n_heads, d_ffn, encoder_layers, dropout)
 
-        self.decoder_prenet = DecoderPreNet(d_model, d_prenet, dropout)
+        self.decoder_prenet = MLPPreNet(d_mel, d_prenet, d_model, dropout)
         self.decoder_pe = pe.positional_encoding(0, 1000, d_model)  # it may be extended later
         self.decoder = TransformerDecoder(d_model, n_heads, d_ffn, decoder_layers, dropout)
         self.final_proj = nn.Linear(d_model, max_reduction_factor * d_mel)
-        self.decoder_postnet = PostNet(d_mel, d_postnet, d_mel, postnet_kernel_size, postnet_layers)
+        self.decoder_postnet = CNNPostNet(d_mel, d_postnet, d_mel, postnet_kernel_size, postnet_layers)
 
         # specs
         self.padding_idx = padding_idx
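For readability, a hedged mapping of the positional arguments used in the new test's setUp onto the parameter names in this signature (the values are the test's, not recommended settings):

```python
from parakeet.models import transformer_tts as tts

net = tts.TransformerTTS(
    vocab_size=128, padding_idx=0, d_model=64, d_mel=80, n_heads=4, d_ffn=128,
    positional_encoding_scalar=0.5,
    encoder_layers=6, decoder_layers=6, d_prenet=128, d_postnet=128, postnet_layers=4,
    postnet_kernel_size=3, max_reduction_factor=10, dropout=0.5)
```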
@@ -263,30 +262,41 @@ class TransformerTTS(nn.Layer):
         self.pe_scalar = positional_encoding_scalar
 
     def forward(self, text, mel, stop):
         pass
 
     def encode(self, text):
         T_enc = text.shape[-1]
         embed = self.encoder_prenet(text)
         pe = self.encoder_pe[:T_enc, :]  # (T, C)
         x = embed.scale(math.sqrt(self.d_model)) + pe.scale(self.pe_scalar)
-        encoder_padding_mask = masking.id_mask(text, self.padding_idx)
+        encoder_padding_mask = masking.id_mask(text, self.padding_idx, dtype=x.dtype)
 
         x = F.dropout(x, training=self.training)
         x, attention_weights = self.encoder(x, encoder_padding_mask)
         return x, attention_weights, encoder_padding_mask
 
-    def decode(self, ):
-        pass
+    def decode(self, encoder_output, input, encoder_padding_mask):
+        batch_size, T_dec, mel_dim = input.shape
+        no_future_mask = masking.future_mask(T_dec, dtype=input.dtype)
+        decoder_padding_mask = masking.feature_mask(input, axis=-1, dtype=input.dtype)
+        decoder_mask = masking.combine_mask(decoder_padding_mask.unsqueeze(-1), no_future_mask)
+
+        decoder_input = self.decoder_prenet(input)
+        decoder_output, _, cross_attention_weights = self.decoder(
+            decoder_input,
+            encoder_output,
+            encoder_output,
+            encoder_padding_mask,
+            decoder_mask)
+
+        output_proj = self.final_proj(decoder_output)
+        mel_intermediate = paddle.reshape(output_proj, [batch_size, -1, mel_dim])
+
+        mel_channel_first = paddle.transpose(mel_intermediate, [0, 2, 1])
+        mel_output = self.decoder_postnet(mel_channel_first)
+        mel_output = paddle.transpose(mel_output, [0, 2, 1])
+        return mel_output, mel_intermediate, cross_attention_weights
 
     def infer(self):
         pass
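A short shape walk-through of the new decode(), using the hyperparameters from the test below (d_mel=80, max_reduction_factor=10) and a mel batch of (4, 189, 80); this is a sketch of the reshape arithmetic, not output of the code:

```python
batch_size, T_dec, d_mel = 4, 189, 80
max_reduction_factor, d_model = 10, 64

# final_proj: (B, T_dec, d_model) -> (B, T_dec, max_reduction_factor * d_mel)
proj_shape = (batch_size, T_dec, max_reduction_factor * d_mel)              # (4, 189, 800)
# the reshape to [batch_size, -1, mel_dim] unfolds each decoder step into
# max_reduction_factor mel frames:
mel_intermediate_shape = (batch_size, T_dec * max_reduction_factor, d_mel)  # (4, 1890, 80)
print(proj_shape, mel_intermediate_shape)
```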
@@ -5,7 +5,7 @@ def id_mask(input, padding_index=0, dtype="bool"):
     return paddle.cast(input != padding_index, dtype)
 
 def feature_mask(input, axis, dtype="bool"):
-    feature_sum = paddle.sum(paddle.abs(input), axis=axis, keepdim=True)
+    feature_sum = paddle.sum(paddle.abs(input), axis)
     return paddle.cast(feature_sum != 0, dtype)
 
 def combine_mask(padding_mask, no_future_mask):
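feature_mask now returns a (batch, T) mask (keepdim dropped), which is why decode() unsqueezes it before combining with the (T, T) no-future mask. A small illustration of what it computes, mirroring the function body above:

```python
import paddle

mel = paddle.to_tensor([[[0.5, -0.2], [0.0, 0.0]],
                        [[1.0,  1.0], [2.0, 0.0]]])   # (batch=2, T=2, C=2)
feature_sum = paddle.sum(paddle.abs(mel), -1)          # (2, 2)
mask = paddle.cast(feature_sum != 0, "float32")        # [[1., 0.], [1., 1.]]: all-zero frames masked
print(mask.numpy())
```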
@@ -5,6 +5,8 @@ paddle.set_default_dtype("float64")
 paddle.disable_static(paddle.CPUPlace())
 
 from parakeet.models import transformer_tts as tts
+from parakeet.modules import masking
+from pprint import pprint
 
 class TestMultiheadAttention(unittest.TestCase):
     def test_io_same_qk(self):
@@ -47,8 +49,10 @@ class TestTransformerDecoderLayer(unittest.TestCase):
         v = paddle.randn([4, 24, 64])
         enc_lengths = paddle.to_tensor([24, 18, 20, 22])
         dec_lengths = paddle.to_tensor([32, 28, 30, 31])
-        enc_mask = paddle.fluid.layers.sequence_mask(enc_lengths, dtype=k.dtype)
-        dec_mask = paddle.fluid.layers.sequence_mask(dec_lengths, dtype=q.dtype)
+        enc_mask = masking.sequence_mask(enc_lengths, dtype=k.dtype)
+        dec_padding_mask = masking.sequence_mask(dec_lengths, dtype=q.dtype)
+        no_future_mask = masking.future_mask(32, dtype=q.dtype)
+        dec_mask = masking.combine_mask(dec_padding_mask.unsqueeze(-1), no_future_mask)
         y, self_attn_weights, cross_attn_weights = net(q, k, v, enc_mask, dec_mask)
 
         self.assertTupleEqual(y.numpy().shape, (4, 32, 64))
@@ -57,8 +61,49 @@ class TestTransformerDecoderLayer(unittest.TestCase):
 
 
 class TestTransformerTTS(unittest.TestCase):
-    def test_io(self):
-        return
-        net = tts.TransformerTTS(vocab_size, padding_idx, d_model, d_mel, n_heads, d_ffn, positional_encoding_scalar,
-                                 encoder_layers, decoder_layers, d_prenet, d_postnet, postnet_layers,
-                                 postnet_kernel_size, max_reduction_factor, dropout)
+    def setUp(self):
+        net = tts.TransformerTTS(
+            128, 0, 64, 80, 4, 128,
+            0.5,
+            6, 6, 128, 128, 4,
+            3, 10, 0.5)
+        self.net = net
+
+    def test_encode_io(self):
+        net = self.net
+
+        text = paddle.randint(0, 128, [4, 176])
+        lengths = paddle.to_tensor([176, 156, 174, 168])
+        mask = masking.sequence_mask(lengths, dtype=text.dtype)
+        text = text * mask
+
+        encoded, attention_weights, encoder_mask = net.encode(text)
+        print(encoded.numpy().shape)
+        print([item.shape for item in attention_weights])
+        print(encoder_mask.numpy().shape)
+
+    def test_all_io(self):
+        net = self.net
+
+        text = paddle.randint(0, 128, [4, 176])
+        lengths = paddle.to_tensor([176, 156, 174, 168])
+        mask = masking.sequence_mask(lengths, dtype=text.dtype)
+        text = text * mask
+
+        mel = paddle.randn([4, 189, 80])
+        frames = paddle.to_tensor([189, 186, 179, 174])
+        mask = masking.sequence_mask(frames, dtype=frames.dtype)
+        mel = mel * mask.unsqueeze(-1)
+
+        encoded, encoder_attention_weights, encoder_mask = net.encode(text)
+        mel_output, mel_intermediate, cross_attention_weights = net.decode(encoded, mel, encoder_mask)
+
+        print("output shapes:")
+        print("encoder_output:", encoded.numpy().shape)
+        print("encoder_attentions:", [item.shape for item in encoder_attention_weights])
+        print("encoder_mask:", encoder_mask.numpy().shape)
+        print("mel_output: ", mel_output.numpy().shape)
+        print("mel_intermediate: ", mel_intermediate.numpy().shape)
+        print("decoder_attentions:", [item.shape for item in cross_attention_weights])