minor fixes to TransformerTTS

This commit is contained in:
chenfeiyu 2020-10-28 11:05:47 +08:00
parent c43216ae9b
commit 36cc543348
2 changed files with 164 additions and 52 deletions

View File

@ -9,8 +9,9 @@ from parakeet.modules.transformer import PositionwiseFFN
from parakeet.modules import masking
from parakeet.modules.conv import Conv1dBatchNorm
from parakeet.modules import positional_encoding as pe
from parakeet.modules import losses as L
__all__ = ["TransformerTTS"]
__all__ = ["TransformerTTS", "TransformerTTSLoss"]
# Transformer TTS's own implementation of transformer
class MultiheadAttention(nn.Layer):
@ -101,11 +102,29 @@ class TransformerEncoderLayer(nn.Layer):
super(TransformerEncoderLayer, self).__init__()
self.self_mha = MultiheadAttention(d_model, n_heads)
self.layer_norm1 = nn.LayerNorm([d_model], epsilon=1e-6)
self.ffn = PositionwiseFFN(d_model, d_ffn, dropout)
self.layer_norm2 = nn.LayerNorm([d_model], epsilon=1e-6)
self.dropout = dropout
def _forward_mha(self, x, mask, drop_n_heads):
# PreLN scheme: Norm -> SubLayer -> Dropout -> Residual
x_in = x
x = self.layer_norm1(x)
context_vector, attn_weights = self.self_mha(x, x, x, paddle.unsqueeze(mask, 1), drop_n_heads)
context_vector = x_in + F.dropout(context_vector, self.dropout, training=self.training)
return context_vector, attn_weights
def _forward_ffn(self, x):
# PreLN scheme: Norm -> SubLayer -> Dropout -> Residual
x_in = x
x = self.layer_norm2(x)
x = self.ffn(x)
out= x_in + F.dropout(x, self.dropout, training=self.training)
return out
def forward(self, x, mask):
def forward(self, x, mask, drop_n_heads=0):
"""
Args:
x (Tensor): shape(batch_size, time_steps, d_model), the decoder input.
@ -116,14 +135,16 @@ class TransformerEncoderLayer(nn.Layer):
x (Tensor): shape(batch_size, time_steps, d_model), the decoded.
attn_weights (Tensor), shape(batch_size, n_heads, time_steps, time_steps), self attention.
"""
# pre norm
x_in = x
x = self.layer_norm1(x)
context_vector, attn_weights = self.self_mha(x, x, x, paddle.unsqueeze(mask, 1))
x = x_in + context_vector # here, the order can be tuned
# # pre norm
# x_in = x
# x = self.layer_norm1(x)
# context_vector, attn_weights = self.self_mha(x, x, x, paddle.unsqueeze(mask, 1), drop_n_heads)
# x = x_in + context_vector # here, the order can be tuned
# pre norm
x = x + self.ffn(self.layer_norm2(x))
# # pre norm
# x = x + self.ffn(self.layer_norm2(x))
x, attn_weights = self._forward_mha(x, mask, drop_n_heads)
x = self._forward_ffn(x)
return x, attn_weights
@ -149,8 +170,34 @@ class TransformerDecoderLayer(nn.Layer):
self.ffn = PositionwiseFFN(d_model, d_ffn, dropout)
self.layer_norm3 = nn.LayerNorm([d_model], epsilon=1e-6)
self.dropout = dropout
def forward(self, q, k, v, encoder_mask, decoder_mask):
def _forward_self_mha(self, x, mask, drop_n_heads):
# PreLN scheme: Norm -> SubLayer -> Dropout -> Residual
x_in = x
x = self.layer_norm1(x)
context_vector, attn_weights = self.self_mha(x, x, x, mask, drop_n_heads)
context_vector = x_in + F.dropout(context_vector, self.dropout, training=self.training)
return context_vector, attn_weights
def _forward_cross_mha(self, q, k, v, mask, drop_n_heads):
# PreLN scheme: Norm -> SubLayer -> Dropout -> Residual
q_in = q
q = self.layer_norm2(q)
context_vector, attn_weights = self.cross_mha(q, k, v, paddle.unsqueeze(mask, 1), drop_n_heads)
context_vector = q_in + F.dropout(context_vector, self.dropout, training=self.training)
return context_vector, attn_weights
def _forward_ffn(self, x):
# PreLN scheme: Norm -> SubLayer -> Dropout -> Residual
x_in = x
x = self.layer_norm3(x)
x = self.ffn(x)
out= x_in + F.dropout(x, self.dropout, training=self.training)
return out
def forward(self, q, k, v, encoder_mask, decoder_mask, drop_n_heads=0):
"""
Args:
q (Tensor): shape(batch_size, time_steps_q, d_model), the decoder input.
@ -165,20 +212,23 @@ class TransformerDecoderLayer(nn.Layer):
self_attn_weights (Tensor), shape(batch_size, n_heads, time_steps_q, time_steps_q), decoder self attention.
cross_attn_weights (Tensor), shape(batch_size, n_heads, time_steps_q, time_steps_k), decoder-encoder cross attention.
"""
# pre norm
q_in = q
q = self.layer_norm1(q)
context_vector, self_attn_weights = self.self_mha(q, q, q, decoder_mask)
q = q_in + context_vector
# # pre norm
# q_in = q
# q = self.layer_norm1(q)
# context_vector, self_attn_weights = self.self_mha(q, q, q, decoder_mask, drop_n_heads)
# q = q_in + context_vector
# pre norm
q_in = q
q = self.layer_norm2(q)
context_vector, cross_attn_weights = self.cross_mha(q, k, v, paddle.unsqueeze(encoder_mask, 1))
q = q_in + context_vector
# # pre norm
# q_in = q
# q = self.layer_norm2(q)
# context_vector, cross_attn_weights = self.cross_mha(q, k, v, paddle.unsqueeze(encoder_mask, 1), drop_n_heads)
# q = q_in + context_vector
# pre norm
q = q + self.ffn(self.layer_norm3(q))
# # pre norm
# q = q + self.ffn(self.layer_norm3(q))
q, self_attn_weights = self._forward_self_mha(q, decoder_mask, drop_n_heads)
q, cross_attn_weights = self._forward_cross_mha(q, k, v, encoder_mask, drop_n_heads)
q = self._forward_ffn(q)
return q, self_attn_weights, cross_attn_weights
@ -188,10 +238,21 @@ class TransformerEncoder(nn.LayerList):
for _ in range(n_layers):
self.append(TransformerEncoderLayer(d_model, n_heads, d_ffn, dropout))
def forward(self, x, mask):
def forward(self, x, mask, drop_n_heads=0):
"""
Args:
x (Tensor): shape(batch_size, time_steps, feature_size), the input tensor.
mask (Tensor): shape(batch_size, time_steps), the mask.
drop_n_heads (int, optional): how many heads to drop. Defaults to 0.
Returns:
x (Tensor): shape(batch_size, time_steps, feature_size), the context vector.
attention_weights(list), list of tensors, each of shape
(batch_size, n_heads, time_steps, time_steps), the attention weights.
"""
attention_weights = []
for layer in self:
x, attention_weights_i = layer(x, mask)
x, attention_weights_i = layer(x, mask, drop_n_heads)
attention_weights.append(attention_weights_i)
return x, attention_weights
@ -202,27 +263,40 @@ class TransformerDecoder(nn.LayerList):
for _ in range(n_layers):
self.append(TransformerDecoderLayer(d_model, n_heads, d_ffn, dropout, d_encoder=d_encoder))
def forward(self, q, k, v, encoder_mask, decoder_mask):
def forward(self, q, k, v, encoder_mask, decoder_mask, drop_n_heads=0):
"""[summary]
Args:
q (Tensor): shape(batch_size, time_steps_q, d_model)
k (Tensor): shape(batch_size, time_steps_k, d_encoder)
v (Tensor): shape(batch_size, time_steps_k, k_encoder)
encoder_mask (Tensor): shape(batch_size, time_steps_k)
decoder_mask (Tensor): shape(batch_size, time_steps_q, time_steps_q)
drop_n_heads (int, optional): [description]. Defaults to 0.
Returns:
[type]: [description]
"""
self_attention_weights = []
cross_attention_weights = []
for layer in self:
q, self_attention_weights_i, cross_attention_weights_i = layer(q, k, v, encoder_mask, decoder_mask)
q, self_attention_weights_i, cross_attention_weights_i = layer(q, k, v, encoder_mask, decoder_mask, drop_n_heads)
self_attention_weights.append(self_attention_weights_i)
cross_attention_weights.append(cross_attention_weights_i)
return q, self_attention_weights, cross_attention_weights
class MLPPreNet(nn.Layer):
def __init__(self, d_input, d_hidden, d_output, dropout):
def __init__(self, d_input, d_hidden, d_output):
super(MLPPreNet, self).__init__()
self.lin1 = nn.Linear(d_input, d_hidden)
self.dropout1 = nn.Dropout(dropout)
self.lin2 = nn.Linear(d_hidden, d_output)
self.dropout2 = nn.Dropout(dropout)
def forward(self, x):
def forward(self, x, dropout):
# the original code said also use dropout in inference
return self.dropout2(F.relu(self.lin2(self.dropout1(F.relu(self.lin1(x))))))
l1 = F.dropout(F.relu(self.lin1(x)), dropout, training=self.training)
l2 = F.dropout(F.relu(self.lin2(l1)), dropout, training=self.training)
return l2
class CNNPostNet(nn.Layer):
@ -251,17 +325,16 @@ class TransformerTTS(nn.Layer):
encoder_layers, decoder_layers, d_prenet, d_postnet, postnet_layers,
postnet_kernel_size, max_reduction_factor, dropout):
super(TransformerTTS, self).__init__()
# initial pe scalar is 1, though it is trainable
self.pe_scalar = self.create_parameter([1], attr=I.Constant(1.))
# encoder
self.encoder_prenet = nn.Embedding(vocab_size, d_encoder, padding_idx)
self.encoder_pe = pe.positional_encoding(0, 1000, d_encoder) # it may be extended later
self.encoder_pe_scalar = self.create_parameter([1], attr=I.Constant(1.))
self.encoder = TransformerEncoder(d_encoder, n_heads, d_ffn, encoder_layers, dropout)
# decoder
self.decoder_prenet = MLPPreNet(d_mel, d_prenet, d_decoder, dropout)
self.decoder_prenet = MLPPreNet(d_mel, d_prenet, d_decoder)
self.decoder_pe = pe.positional_encoding(0, 1000, d_decoder) # it may be extended later
self.decoder_pe_scalar = self.create_parameter([1], attr=I.Constant(1.))
self.decoder = TransformerDecoder(d_decoder, n_heads, d_ffn, decoder_layers, dropout, d_encoder=d_encoder)
self.final_proj = nn.Linear(d_decoder, max_reduction_factor * d_mel)
self.decoder_postnet = CNNPostNet(d_mel, d_postnet, d_mel, postnet_kernel_size, postnet_layers)
@ -271,51 +344,69 @@ class TransformerTTS(nn.Layer):
self.padding_idx = padding_idx
self.d_encoder = d_encoder
self.d_decoder = d_decoder
self.max_r = max_reduction_factor
self.dropout = dropout
# start and end: though it is only used in predict
# it can also be used in training
dtype = paddle.get_default_dtype()
self.start_vec = paddle.full([1, d_mel], 0, dtype=dtype)
self.end_vec = paddle.full([1, d_mel], 0, dtype=dtype)
self.start_vec = paddle.full([1, d_mel], 0.5, dtype=dtype)
self.end_vec = paddle.full([1, d_mel], -0.5, dtype=dtype)
self.stop_prob_index = 2
self.max_r = max_reduction_factor
# mutables
self.r = max_reduction_factor # set it every call
self.decoder_prenet_dropout = 0.0
self.drop_n_heads = 0
def forward(self, text, mel, stop):
encoded, encoder_attention_weights, encoder_mask = self.encode(text)
mel_output, mel_intermediate, cross_attention_weights, stop_logits = self.decode(encoded, mel, encoder_mask)
return mel_output, mel_intermediate, encoder_attention_weights, cross_attention_weights
return mel_output, mel_intermediate, encoder_attention_weights, cross_attention_weights, stop_logits
def encode(self, text):
T_enc = text.shape[-1]
embed = self.encoder_prenet(text)
pe = self.encoder_pe[:T_enc, :] # (T, C)
x = embed.scale(math.sqrt(self.d_encoder)) + pe * self.pe_scalar
if embed.shape[1] > self.encoder_pe.shape[0]:
new_T = max(embed.shape[1], self.encoder_pe.shape[0] * 2)
self.encoder_pe = pe.positional_encoding(0, new_T, self.d_encoder)
pos_enc = self.encoder_pe[:T_enc, :] # (T, C)
x = embed.scale(math.sqrt(self.d_encoder)) + pos_enc * self.encoder_pe_scalar
x = F.dropout(x, self.dropout, training=self.training)
# TODO(chenfeiyu): unsqueeze a decoder_time_steps=1 for the mask
encoder_padding_mask = masking.id_mask(text, self.padding_idx, dtype=x.dtype)
x = F.dropout(x, training=self.training)
x, attention_weights = self.encoder(x, encoder_padding_mask)
x, attention_weights = self.encoder(x, encoder_padding_mask, self.drop_n_heads)
return x, attention_weights, encoder_padding_mask
def decode(self, encoder_output, input, encoder_padding_mask):
batch_size, T_dec, mel_dim = input.shape
x = self.decoder_prenet(input, self.decoder_prenet_dropout)
# twice its length if needed
if x.shape[1] * self.r > self.decoder_pe.shape[0]:
new_T = max(x.shape[1] * self.r, self.decoder_pe.shape[0] * 2)
self.decoder_pe = pe.positional_encoding(0, new_T, self.d_decoder)
pos_enc = self.decoder_pe[:T_dec*self.r:self.r, :]
x = x.scale(math.sqrt(self.d_decoder)) + pos_enc * self.decoder_pe_scalar
x = F.dropout(x, self.dropout, training=self.training)
no_future_mask = masking.future_mask(T_dec, dtype=input.dtype)
decoder_padding_mask = masking.feature_mask(input, axis=-1, dtype=input.dtype)
decoder_mask = masking.combine_mask(decoder_padding_mask.unsqueeze(-1), no_future_mask)
decoder_input = self.decoder_prenet(input)
decoder_output, _, cross_attention_weights = self.decoder(
decoder_input,
x,
encoder_output,
encoder_output,
encoder_padding_mask,
decoder_mask)
decoder_mask,
self.drop_n_heads)
output_proj = self.final_proj(decoder_output)
mel_intermediate = paddle.reshape(output_proj, [batch_size, -1, mel_dim])
stop_logits = self.stop_conditioner(mel_intermediate)
# cnn postnet
mel_channel_first = paddle.transpose(mel_intermediate, [0, 2, 1])
mel_output = self.decoder_postnet(mel_channel_first)
mel_output = paddle.transpose(mel_output, [0, 2, 1])
@ -352,12 +443,33 @@ class TransformerTTS(nn.Layer):
return decoder_output[:, 1:, :], encoder_attentions, cross_attention_weights
def set_constants(self, reduction_factor, drop_n_heads, decoder_prenet_dropout):
# TODO(chenfeiyu): make a good design for these hyperparameter settings
self.r = reduction_factor
self.drop_n_heads = drop_n_heads
self.decoder_prenet_dropout = decoder_prenet_dropout
class TransformerTTSLoss(nn.Layer):
def __init__(self, stop_loss_scale):
super(TransformerTTSLoss, self).__init__()
self.stop_loss_scale = stop_loss_scale
def forward(self, ):
def forward(self, mel_output, mel_intermediate, mel_target, stop_logits, stop_probs):
mask = masking.feature_mask(mel_target, axis=-1, dtype=mel_target.dtype)
mask1 = paddle.unsqueeze(mask, -1)
mel_loss1 = L.masked_l1_loss(mel_output, mel_target, mask1)
mel_loss2 = L.masked_l1_loss(mel_intermediate, mel_target, mask1)
mel_len = mask.shape[-1]
last_position = F.one_hot(mask.sum(-1).astype("int64") - 1, num_classes=mel_len)
mask2 = mask + last_position.scale(self.stop_loss_scale - 1).astype(mask.dtype)
stop_loss = L.masked_softmax_with_cross_entropy(stop_logits, stop_probs.unsqueeze(-1), mask2.unsqueeze(-1))
loss = mel_loss1 + mel_loss2 + stop_loss
details = dict(
mel_loss1=mel_loss1, # ouput mel loss
mel_loss2=mel_loss2, # intermediate mel loss
stop_loss=stop_loss # stop prob loss
)
return loss, details

View File

@ -65,7 +65,7 @@ class TestTransformerTTS(unittest.TestCase):
net = tts.TransformerTTS(
128, 0, 64, 128, 80, 4, 128,
6, 6, 128, 128, 4,
3, 10, 0.5)
3, 10, 0.1)
self.net = net
def test_encode_io(self):