minor fixes to TransformerTTS
parent c43216ae9b
commit 36cc543348
@@ -9,8 +9,9 @@ from parakeet.modules.transformer import PositionwiseFFN
 from parakeet.modules import masking
 from parakeet.modules.conv import Conv1dBatchNorm
 from parakeet.modules import positional_encoding as pe
+from parakeet.modules import losses as L
 
-__all__ = ["TransformerTTS"]
+__all__ = ["TransformerTTS", "TransformerTTSLoss"]
 
 # Transformer TTS's own implementation of transformer
 class MultiheadAttention(nn.Layer):
@@ -101,11 +102,29 @@ class TransformerEncoderLayer(nn.Layer):
         super(TransformerEncoderLayer, self).__init__()
         self.self_mha = MultiheadAttention(d_model, n_heads)
         self.layer_norm1 = nn.LayerNorm([d_model], epsilon=1e-6)
 
         self.ffn = PositionwiseFFN(d_model, d_ffn, dropout)
         self.layer_norm2 = nn.LayerNorm([d_model], epsilon=1e-6)
+        self.dropout = dropout
 
+    def _forward_mha(self, x, mask, drop_n_heads):
+        # PreLN scheme: Norm -> SubLayer -> Dropout -> Residual
+        x_in = x
+        x = self.layer_norm1(x)
+        context_vector, attn_weights = self.self_mha(x, x, x, paddle.unsqueeze(mask, 1), drop_n_heads)
+        context_vector = x_in + F.dropout(context_vector, self.dropout, training=self.training)
+        return context_vector, attn_weights
+
+    def _forward_ffn(self, x):
+        # PreLN scheme: Norm -> SubLayer -> Dropout -> Residual
+        x_in = x
+        x = self.layer_norm2(x)
+        x = self.ffn(x)
+        out = x_in + F.dropout(x, self.dropout, training=self.training)
+        return out
 
-    def forward(self, x, mask):
+    def forward(self, x, mask, drop_n_heads=0):
         """
         Args:
             x (Tensor): shape(batch_size, time_steps, d_model), the encoder input.
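Both helpers apply the same pre-LN sublayer pattern: normalize the input, run the sublayer, apply dropout, then add the residual. A minimal standalone sketch of that pattern (the `sublayer` callable and parameter names are illustrative, not part of this commit):

import paddle.nn.functional as F

def pre_ln_block(x, layer_norm, sublayer, dropout_rate, training):
    # PreLN scheme: Norm -> SubLayer -> Dropout -> Residual
    residual = x
    y = sublayer(layer_norm(x))
    return residual + F.dropout(y, dropout_rate, training=training)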
@@ -116,14 +135,16 @@ class TransformerEncoderLayer(nn.Layer):
             x (Tensor): shape(batch_size, time_steps, d_model), the encoder output.
             attn_weights (Tensor): shape(batch_size, n_heads, time_steps, time_steps), self attention weights.
         """
-        # pre norm
-        x_in = x
-        x = self.layer_norm1(x)
-        context_vector, attn_weights = self.self_mha(x, x, x, paddle.unsqueeze(mask, 1))
-        x = x_in + context_vector  # here, the order can be tuned
-
-        # pre norm
-        x = x + self.ffn(self.layer_norm2(x))
+        x, attn_weights = self._forward_mha(x, mask, drop_n_heads)
+        x = self._forward_ffn(x)
         return x, attn_weights
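The new drop_n_heads argument is only threaded through to MultiheadAttention here; the attention implementation itself is not part of this diff. As a hedged illustration of what head dropping usually means (an assumption about the semantics, not Parakeet's actual code), one could zero out a few randomly chosen heads of the per-head context during training:

import numpy as np
import paddle

def drop_heads(context, drop_n_heads, training):
    # context: (batch, n_heads, time, d_head); illustrative sketch only
    if not training or drop_n_heads == 0:
        return context
    batch_size, n_heads = context.shape[0], context.shape[1]
    keep = np.ones([batch_size, n_heads, 1, 1], dtype="float32")
    for b in range(batch_size):
        dropped = np.random.choice(n_heads, drop_n_heads, replace=False)
        keep[b, dropped] = 0.0
    return context * paddle.to_tensor(keep)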
@@ -149,8 +170,34 @@ class TransformerDecoderLayer(nn.Layer):
 
         self.ffn = PositionwiseFFN(d_model, d_ffn, dropout)
         self.layer_norm3 = nn.LayerNorm([d_model], epsilon=1e-6)
+        self.dropout = dropout
 
-    def forward(self, q, k, v, encoder_mask, decoder_mask):
+    def _forward_self_mha(self, x, mask, drop_n_heads):
+        # PreLN scheme: Norm -> SubLayer -> Dropout -> Residual
+        x_in = x
+        x = self.layer_norm1(x)
+        context_vector, attn_weights = self.self_mha(x, x, x, mask, drop_n_heads)
+        context_vector = x_in + F.dropout(context_vector, self.dropout, training=self.training)
+        return context_vector, attn_weights
+
+    def _forward_cross_mha(self, q, k, v, mask, drop_n_heads):
+        # PreLN scheme: Norm -> SubLayer -> Dropout -> Residual
+        q_in = q
+        q = self.layer_norm2(q)
+        context_vector, attn_weights = self.cross_mha(q, k, v, paddle.unsqueeze(mask, 1), drop_n_heads)
+        context_vector = q_in + F.dropout(context_vector, self.dropout, training=self.training)
+        return context_vector, attn_weights
+
+    def _forward_ffn(self, x):
+        # PreLN scheme: Norm -> SubLayer -> Dropout -> Residual
+        x_in = x
+        x = self.layer_norm3(x)
+        x = self.ffn(x)
+        out = x_in + F.dropout(x, self.dropout, training=self.training)
+        return out
+
+    def forward(self, q, k, v, encoder_mask, decoder_mask, drop_n_heads=0):
         """
         Args:
             q (Tensor): shape(batch_size, time_steps_q, d_model), the decoder input.
@@ -165,20 +212,23 @@ class TransformerDecoderLayer(nn.Layer):
             self_attn_weights (Tensor): shape(batch_size, n_heads, time_steps_q, time_steps_q), decoder self attention.
             cross_attn_weights (Tensor): shape(batch_size, n_heads, time_steps_q, time_steps_k), decoder-encoder cross attention.
         """
-        # pre norm
-        q_in = q
-        q = self.layer_norm1(q)
-        context_vector, self_attn_weights = self.self_mha(q, q, q, decoder_mask)
-        q = q_in + context_vector
-
-        # pre norm
-        q_in = q
-        q = self.layer_norm2(q)
-        context_vector, cross_attn_weights = self.cross_mha(q, k, v, paddle.unsqueeze(encoder_mask, 1))
-        q = q_in + context_vector
-
-        # pre norm
-        q = q + self.ffn(self.layer_norm3(q))
+        q, self_attn_weights = self._forward_self_mha(q, decoder_mask, drop_n_heads)
+        q, cross_attn_weights = self._forward_cross_mha(q, k, v, encoder_mask, drop_n_heads)
+        q = self._forward_ffn(q)
         return q, self_attn_weights, cross_attn_weights
@@ -188,10 +238,21 @@ class TransformerEncoder(nn.LayerList):
         for _ in range(n_layers):
             self.append(TransformerEncoderLayer(d_model, n_heads, d_ffn, dropout))
 
-    def forward(self, x, mask):
+    def forward(self, x, mask, drop_n_heads=0):
+        """
+        Args:
+            x (Tensor): shape(batch_size, time_steps, feature_size), the input tensor.
+            mask (Tensor): shape(batch_size, time_steps), the mask.
+            drop_n_heads (int, optional): how many attention heads to drop. Defaults to 0.
+
+        Returns:
+            x (Tensor): shape(batch_size, time_steps, feature_size), the context vector.
+            attention_weights (list): list of tensors, each of shape
+                (batch_size, n_heads, time_steps, time_steps), the attention weights.
+        """
         attention_weights = []
         for layer in self:
-            x, attention_weights_i = layer(x, mask)
+            x, attention_weights_i = layer(x, mask, drop_n_heads)
             attention_weights.append(attention_weights_i)
         return x, attention_weights
@@ -202,27 +263,40 @@ class TransformerDecoder(nn.LayerList):
         for _ in range(n_layers):
             self.append(TransformerDecoderLayer(d_model, n_heads, d_ffn, dropout, d_encoder=d_encoder))
 
-    def forward(self, q, k, v, encoder_mask, decoder_mask):
+    def forward(self, q, k, v, encoder_mask, decoder_mask, drop_n_heads=0):
+        """
+        Args:
+            q (Tensor): shape(batch_size, time_steps_q, d_model)
+            k (Tensor): shape(batch_size, time_steps_k, d_encoder)
+            v (Tensor): shape(batch_size, time_steps_k, d_encoder)
+            encoder_mask (Tensor): shape(batch_size, time_steps_k)
+            decoder_mask (Tensor): shape(batch_size, time_steps_q, time_steps_q)
+            drop_n_heads (int, optional): how many attention heads to drop. Defaults to 0.
+
+        Returns:
+            q (Tensor): shape(batch_size, time_steps_q, d_model), the decoder output.
+            self_attention_weights (list): self attention weights of each layer.
+            cross_attention_weights (list): cross attention weights of each layer.
+        """
         self_attention_weights = []
         cross_attention_weights = []
         for layer in self:
-            q, self_attention_weights_i, cross_attention_weights_i = layer(q, k, v, encoder_mask, decoder_mask)
+            q, self_attention_weights_i, cross_attention_weights_i = layer(q, k, v, encoder_mask, decoder_mask, drop_n_heads)
             self_attention_weights.append(self_attention_weights_i)
             cross_attention_weights.append(cross_attention_weights_i)
         return q, self_attention_weights, cross_attention_weights
 
 
 class MLPPreNet(nn.Layer):
-    def __init__(self, d_input, d_hidden, d_output, dropout):
+    def __init__(self, d_input, d_hidden, d_output):
         super(MLPPreNet, self).__init__()
         self.lin1 = nn.Linear(d_input, d_hidden)
-        self.dropout1 = nn.Dropout(dropout)
         self.lin2 = nn.Linear(d_hidden, d_output)
-        self.dropout2 = nn.Dropout(dropout)
 
-    def forward(self, x):
+    def forward(self, x, dropout):
         # the original implementation also uses dropout at inference time
-        return self.dropout2(F.relu(self.lin2(self.dropout1(F.relu(self.lin1(x))))))
+        l1 = F.dropout(F.relu(self.lin1(x)), dropout, training=self.training)
+        l2 = F.dropout(F.relu(self.lin2(l1)), dropout, training=self.training)
+        return l2
 
 
 class CNNPostNet(nn.Layer):
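With the nn.Dropout modules removed, MLPPreNet takes the dropout rate as a call-time argument, so the caller can change it during training (decode passes self.decoder_prenet_dropout, which set_constants updates later in this commit). A small usage sketch with made-up shapes and rate:

import paddle

prenet = MLPPreNet(80, 128, 128)            # (d_mel, d_prenet, d_decoder)
mel_frames = paddle.randn([2, 50, 80])      # (batch, time, d_mel)
out = prenet(mel_frames, dropout=0.2)       # rate supplied by the caller
print(out.shape)                            # [2, 50, 128]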
@@ -251,17 +325,16 @@ class TransformerTTS(nn.Layer):
                  encoder_layers, decoder_layers, d_prenet, d_postnet, postnet_layers,
                  postnet_kernel_size, max_reduction_factor, dropout):
         super(TransformerTTS, self).__init__()
-        # initial pe scalar is 1, though it is trainable
-        self.pe_scalar = self.create_parameter([1], attr=I.Constant(1.))
 
         # encoder
         self.encoder_prenet = nn.Embedding(vocab_size, d_encoder, padding_idx)
         self.encoder_pe = pe.positional_encoding(0, 1000, d_encoder)  # it may be extended later
+        self.encoder_pe_scalar = self.create_parameter([1], attr=I.Constant(1.))
         self.encoder = TransformerEncoder(d_encoder, n_heads, d_ffn, encoder_layers, dropout)
 
         # decoder
-        self.decoder_prenet = MLPPreNet(d_mel, d_prenet, d_decoder, dropout)
+        self.decoder_prenet = MLPPreNet(d_mel, d_prenet, d_decoder)
         self.decoder_pe = pe.positional_encoding(0, 1000, d_decoder)  # it may be extended later
+        self.decoder_pe_scalar = self.create_parameter([1], attr=I.Constant(1.))
         self.decoder = TransformerDecoder(d_decoder, n_heads, d_ffn, decoder_layers, dropout, d_encoder=d_encoder)
         self.final_proj = nn.Linear(d_decoder, max_reduction_factor * d_mel)
         self.decoder_postnet = CNNPostNet(d_mel, d_postnet, d_mel, postnet_kernel_size, postnet_layers)
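The shared pe_scalar is split into encoder_pe_scalar and decoder_pe_scalar, so each side learns its own weight for the positional encoding it adds to the scaled embeddings. A toy sketch of that combination (random tensors stand in for the real embedding and sinusoid table):

import math
import paddle

d_model = 8
embed = paddle.randn([2, 5, d_model])       # (batch, time, d_model)
pos_table = paddle.randn([1000, d_model])   # stand-in for pe.positional_encoding(0, 1000, d_model)
pe_scalar = paddle.ones([1])                # stand-in for the trainable encoder_pe_scalar
x = embed * math.sqrt(d_model) + pos_table[:5, :] * pe_scalar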
@@ -271,51 +344,69 @@ class TransformerTTS(nn.Layer):
         self.padding_idx = padding_idx
         self.d_encoder = d_encoder
         self.d_decoder = d_decoder
+        self.max_r = max_reduction_factor
+        self.dropout = dropout
 
         # start and end vectors: though they are only used in prediction,
         # they can also be used in training
         dtype = paddle.get_default_dtype()
-        self.start_vec = paddle.full([1, d_mel], 0, dtype=dtype)
-        self.end_vec = paddle.full([1, d_mel], 0, dtype=dtype)
+        self.start_vec = paddle.full([1, d_mel], 0.5, dtype=dtype)
+        self.end_vec = paddle.full([1, d_mel], -0.5, dtype=dtype)
         self.stop_prob_index = 2
 
-        self.max_r = max_reduction_factor
+        # mutables
         self.r = max_reduction_factor  # set it every call
+        self.decoder_prenet_dropout = 0.0
+        self.drop_n_heads = 0
 
     def forward(self, text, mel, stop):
         encoded, encoder_attention_weights, encoder_mask = self.encode(text)
         mel_output, mel_intermediate, cross_attention_weights, stop_logits = self.decode(encoded, mel, encoder_mask)
-        return mel_output, mel_intermediate, encoder_attention_weights, cross_attention_weights
+        return mel_output, mel_intermediate, encoder_attention_weights, cross_attention_weights, stop_logits
 
     def encode(self, text):
        T_enc = text.shape[-1]
         embed = self.encoder_prenet(text)
-        pe = self.encoder_pe[:T_enc, :]  # (T, C)
-        x = embed.scale(math.sqrt(self.d_encoder)) + pe * self.pe_scalar
+        if embed.shape[1] > self.encoder_pe.shape[0]:
+            new_T = max(embed.shape[1], self.encoder_pe.shape[0] * 2)
+            self.encoder_pe = pe.positional_encoding(0, new_T, self.d_encoder)
+        pos_enc = self.encoder_pe[:T_enc, :]  # (T, C)
+        x = embed.scale(math.sqrt(self.d_encoder)) + pos_enc * self.encoder_pe_scalar
+        x = F.dropout(x, self.dropout, training=self.training)
 
+        # TODO(chenfeiyu): unsqueeze a decoder_time_steps=1 for the mask
         encoder_padding_mask = masking.id_mask(text, self.padding_idx, dtype=x.dtype)
-        x = F.dropout(x, training=self.training)
-        x, attention_weights = self.encoder(x, encoder_padding_mask)
+        x, attention_weights = self.encoder(x, encoder_padding_mask, self.drop_n_heads)
         return x, attention_weights, encoder_padding_mask
 
     def decode(self, encoder_output, input, encoder_padding_mask):
         batch_size, T_dec, mel_dim = input.shape
 
+        x = self.decoder_prenet(input, self.decoder_prenet_dropout)
+        # extend the positional encoding (at least double it) if needed
+        if x.shape[1] * self.r > self.decoder_pe.shape[0]:
+            new_T = max(x.shape[1] * self.r, self.decoder_pe.shape[0] * 2)
+            self.decoder_pe = pe.positional_encoding(0, new_T, self.d_decoder)
+        pos_enc = self.decoder_pe[:T_dec * self.r:self.r, :]
+        x = x.scale(math.sqrt(self.d_decoder)) + pos_enc * self.decoder_pe_scalar
+        x = F.dropout(x, self.dropout, training=self.training)
+
         no_future_mask = masking.future_mask(T_dec, dtype=input.dtype)
         decoder_padding_mask = masking.feature_mask(input, axis=-1, dtype=input.dtype)
         decoder_mask = masking.combine_mask(decoder_padding_mask.unsqueeze(-1), no_future_mask)
 
-        decoder_input = self.decoder_prenet(input)
         decoder_output, _, cross_attention_weights = self.decoder(
-            decoder_input,
+            x,
             encoder_output,
             encoder_output,
             encoder_padding_mask,
-            decoder_mask)
+            decoder_mask,
+            self.drop_n_heads)
 
         output_proj = self.final_proj(decoder_output)
         mel_intermediate = paddle.reshape(output_proj, [batch_size, -1, mel_dim])
         stop_logits = self.stop_conditioner(mel_intermediate)
 
+        # cnn postnet
         mel_channel_first = paddle.transpose(mel_intermediate, [0, 2, 1])
         mel_output = self.decoder_postnet(mel_channel_first)
         mel_output = paddle.transpose(mel_output, [0, 2, 1])
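encode and decode now grow the cached positional-encoding table on demand instead of assuming it always covers 1000 steps: whenever the required length exceeds the table, it is rebuilt to at least the required length and at least twice its previous size. A standalone sketch of that rule, with a hypothetical make_pe standing in for pe.positional_encoding:

def extend_pe(pe_table, needed_len, d_model, make_pe):
    # grow the table when needed: at least `needed_len`, at least double the old size
    if needed_len > pe_table.shape[0]:
        new_len = max(needed_len, pe_table.shape[0] * 2)
        pe_table = make_pe(0, new_len, d_model)
    return pe_table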
@@ -352,12 +443,33 @@ class TransformerTTS(nn.Layer):
 
         return decoder_output[:, 1:, :], encoder_attentions, cross_attention_weights
 
+    def set_constants(self, reduction_factor, drop_n_heads, decoder_prenet_dropout):
+        # TODO(chenfeiyu): make a good design for these hyperparameter settings
+        self.r = reduction_factor
+        self.drop_n_heads = drop_n_heads
+        self.decoder_prenet_dropout = decoder_prenet_dropout
 
 
 class TransformerTTSLoss(nn.Layer):
     def __init__(self, stop_loss_scale):
         super(TransformerTTSLoss, self).__init__()
         self.stop_loss_scale = stop_loss_scale
 
-    def forward(self, ):
+    def forward(self, mel_output, mel_intermediate, mel_target, stop_logits, stop_probs):
+        mask = masking.feature_mask(mel_target, axis=-1, dtype=mel_target.dtype)
+        mask1 = paddle.unsqueeze(mask, -1)
+        mel_loss1 = L.masked_l1_loss(mel_output, mel_target, mask1)
+        mel_loss2 = L.masked_l1_loss(mel_intermediate, mel_target, mask1)
+
+        mel_len = mask.shape[-1]
+        last_position = F.one_hot(mask.sum(-1).astype("int64") - 1, num_classes=mel_len)
+        mask2 = mask + last_position.scale(self.stop_loss_scale - 1).astype(mask.dtype)
+        stop_loss = L.masked_softmax_with_cross_entropy(stop_logits, stop_probs.unsqueeze(-1), mask2.unsqueeze(-1))
+
+        loss = mel_loss1 + mel_loss2 + stop_loss
+        details = dict(
+            mel_loss1=mel_loss1,  # output mel loss
+            mel_loss2=mel_loss2,  # intermediate mel loss
+            stop_loss=stop_loss  # stop prob loss
+        )
         return loss, details
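TransformerTTSLoss up-weights the stop classification at the last valid frame: mask2 equals the padding mask everywhere except at each utterance's final position, where it equals stop_loss_scale. A small NumPy sketch of that mask construction (values chosen for illustration):

import numpy as np

stop_loss_scale = 8.0
mask = np.array([[1, 1, 1, 1, 0, 0],    # two utterances, lengths 4 and 5
                 [1, 1, 1, 1, 1, 0]], dtype="float32")
mel_len = mask.shape[-1]
last_position = np.eye(mel_len, dtype="float32")[mask.sum(-1).astype("int64") - 1]
mask2 = mask + last_position * (stop_loss_scale - 1)
# mask2[0] -> [1, 1, 1, 8, 0, 0], mask2[1] -> [1, 1, 1, 1, 8, 0]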
@@ -65,7 +65,7 @@ class TestTransformerTTS(unittest.TestCase):
         net = tts.TransformerTTS(
             128, 0, 64, 128, 80, 4, 128,
             6, 6, 128, 128, 4,
-            3, 10, 0.5)
+            3, 10, 0.1)
         self.net = net
 
     def test_encode_io(self):