|
|
|
@ -1,9 +1,11 @@
|
|
|
|
|
import math
|
|
|
|
|
from tqdm import trange
|
|
|
|
|
import paddle
|
|
|
|
|
from paddle import nn
|
|
|
|
|
from paddle.nn import functional as F
|
|
|
|
|
from paddle.nn import initializer as I
|
|
|
|
|
|
|
|
|
|
import parakeet
|
|
|
|
|
from parakeet.modules.attention import _split_heads, _concat_heads, drop_head, scaled_dot_product_attention
|
|
|
|
|
from parakeet.modules.transformer import PositionwiseFFN
|
|
|
|
|
from parakeet.modules import masking
|
|
|
|
@ -111,8 +113,6 @@ class TransformerEncoderLayer(nn.Layer):
|
|
|
|
|
|
|
|
|
|
def _forward_mha(self, x, mask, drop_n_heads):
|
|
|
|
|
# PreLN scheme: Norm -> SubLayer -> Dropout -> Residual
|
|
|
|
|
if mask is not None:
|
|
|
|
|
mask = paddle.unsqueeze(mask, 1)
|
|
|
|
|
x_in = x
|
|
|
|
|
x = self.layer_norm1(x)
|
|
|
|
|
context_vector, attn_weights = self.self_mha(x, x, x, mask, drop_n_heads)
|
|
|
|
@ -131,21 +131,13 @@ class TransformerEncoderLayer(nn.Layer):
|
|
|
|
|
"""
|
|
|
|
|
Args:
|
|
|
|
|
x (Tensor): shape(batch_size, time_steps, d_model), the decoder input.
|
|
|
|
|
mask (Tensor): shape(batch_size, time_steps), the padding mask.
|
|
|
|
|
mask (Tensor): shape(batch_size, 1, time_steps), the padding mask.
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
(x, attn_weights)
|
|
|
|
|
x (Tensor): shape(batch_size, time_steps, d_model), the decoded.
|
|
|
|
|
attn_weights (Tensor), shape(batch_size, n_heads, time_steps, time_steps), self attention.
|
|
|
|
|
"""
|
|
|
|
|
# # pre norm
|
|
|
|
|
# x_in = x
|
|
|
|
|
# x = self.layer_norm1(x)
|
|
|
|
|
# context_vector, attn_weights = self.self_mha(x, x, x, paddle.unsqueeze(mask, 1), drop_n_heads)
|
|
|
|
|
# x = x_in + context_vector # here, the order can be tuned
|
|
|
|
|
|
|
|
|
|
# # pre norm
|
|
|
|
|
# x = x + self.ffn(self.layer_norm2(x))
|
|
|
|
|
x, attn_weights = self._forward_mha(x, mask, drop_n_heads)
|
|
|
|
|
x = self._forward_ffn(x)
|
|
|
|
|
return x, attn_weights
|
|
|
|
@ -188,7 +180,7 @@ class TransformerDecoderLayer(nn.Layer):
|
|
|
|
|
# PreLN scheme: Norm -> SubLayer -> Dropout -> Residual
|
|
|
|
|
q_in = q
|
|
|
|
|
q = self.layer_norm2(q)
|
|
|
|
|
context_vector, attn_weights = self.cross_mha(q, k, v, paddle.unsqueeze(mask, 1), drop_n_heads)
|
|
|
|
|
context_vector, attn_weights = self.cross_mha(q, k, v, mask, drop_n_heads)
|
|
|
|
|
context_vector = q_in + F.dropout(context_vector, self.dropout, training=self.training)
|
|
|
|
|
return context_vector, attn_weights
|
|
|
|
|
|
|
|
|
@ -206,7 +198,7 @@ class TransformerDecoderLayer(nn.Layer):
|
|
|
|
|
q (Tensor): shape(batch_size, time_steps_q, d_model), the decoder input.
|
|
|
|
|
k (Tensor): shape(batch_size, time_steps_k, d_model), keys.
|
|
|
|
|
v (Tensor): shape(batch_size, time_steps_k, d_model), values
|
|
|
|
|
encoder_mask (Tensor): shape(batch_size, time_steps_k) encoder padding mask.
|
|
|
|
|
encoder_mask (Tensor): shape(batch_size, 1, time_steps_k) encoder padding mask.
|
|
|
|
|
decoder_mask (Tensor): shape(batch_size, time_steps_q, time_steps_q) or broadcastable shape, decoder padding mask.
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
@ -214,21 +206,7 @@ class TransformerDecoderLayer(nn.Layer):
|
|
|
|
|
q (Tensor): shape(batch_size, time_steps_q, d_model), the decoded.
|
|
|
|
|
self_attn_weights (Tensor), shape(batch_size, n_heads, time_steps_q, time_steps_q), decoder self attention.
|
|
|
|
|
cross_attn_weights (Tensor), shape(batch_size, n_heads, time_steps_q, time_steps_k), decoder-encoder cross attention.
|
|
|
|
|
"""
|
|
|
|
|
# # pre norm
|
|
|
|
|
# q_in = q
|
|
|
|
|
# q = self.layer_norm1(q)
|
|
|
|
|
# context_vector, self_attn_weights = self.self_mha(q, q, q, decoder_mask, drop_n_heads)
|
|
|
|
|
# q = q_in + context_vector
|
|
|
|
|
|
|
|
|
|
# # pre norm
|
|
|
|
|
# q_in = q
|
|
|
|
|
# q = self.layer_norm2(q)
|
|
|
|
|
# context_vector, cross_attn_weights = self.cross_mha(q, k, v, paddle.unsqueeze(encoder_mask, 1), drop_n_heads)
|
|
|
|
|
# q = q_in + context_vector
|
|
|
|
|
|
|
|
|
|
# # pre norm
|
|
|
|
|
# q = q + self.ffn(self.layer_norm3(q))
|
|
|
|
|
"""
|
|
|
|
|
q, self_attn_weights = self._forward_self_mha(q, decoder_mask, drop_n_heads)
|
|
|
|
|
q, cross_attn_weights = self._forward_cross_mha(q, k, v, encoder_mask, drop_n_heads)
|
|
|
|
|
q = self._forward_ffn(q)
|
|
|
|
@ -245,7 +223,7 @@ class TransformerEncoder(nn.LayerList):
|
|
|
|
|
"""
|
|
|
|
|
Args:
|
|
|
|
|
x (Tensor): shape(batch_size, time_steps, feature_size), the input tensor.
|
|
|
|
|
mask (Tensor): shape(batch_size, time_steps), the mask.
|
|
|
|
|
mask (Tensor): shape(batch_size, 1, time_steps), the mask.
|
|
|
|
|
drop_n_heads (int, optional): how many heads to drop. Defaults to 0.
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
@ -273,7 +251,7 @@ class TransformerDecoder(nn.LayerList):
|
|
|
|
|
q (Tensor): shape(batch_size, time_steps_q, d_model)
|
|
|
|
|
k (Tensor): shape(batch_size, time_steps_k, d_encoder)
|
|
|
|
|
v (Tensor): shape(batch_size, time_steps_k, k_encoder)
|
|
|
|
|
encoder_mask (Tensor): shape(batch_size, time_steps_k)
|
|
|
|
|
encoder_mask (Tensor): shape(batch_size, 1, time_steps_k)
|
|
|
|
|
decoder_mask (Tensor): shape(batch_size, time_steps_q, time_steps_q)
|
|
|
|
|
drop_n_heads (int, optional): [description]. Defaults to 0.
|
|
|
|
|
|
|
|
|
@ -290,21 +268,21 @@ class TransformerDecoder(nn.LayerList):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class MLPPreNet(nn.Layer):
|
|
|
|
|
def __init__(self, d_input, d_hidden, d_output):
|
|
|
|
|
def __init__(self, d_input, d_hidden, d_output, dropout):
|
|
|
|
|
# (lin + relu + dropout) * n + last projection
|
|
|
|
|
super(MLPPreNet, self).__init__()
|
|
|
|
|
self.lin1 = nn.Linear(d_input, d_hidden)
|
|
|
|
|
self.lin2 = nn.Linear(d_hidden, d_hidden)
|
|
|
|
|
# self.lin3 = nn.Linear(d_output, d_output)
|
|
|
|
|
self.lin3 = nn.Linear(d_output, d_output)
|
|
|
|
|
self.dropout = dropout
|
|
|
|
|
|
|
|
|
|
def forward(self, x, dropout):
|
|
|
|
|
# the original code said also use dropout in inference
|
|
|
|
|
l1 = F.dropout(F.relu(self.lin1(x)), dropout, training=self.training)
|
|
|
|
|
l2 = F.dropout(F.relu(self.lin2(l1)), dropout, training=self.training)
|
|
|
|
|
#l3 = self.lin3(l2)
|
|
|
|
|
return l2
|
|
|
|
|
|
|
|
|
|
l1 = F.dropout(F.relu(self.lin1(x)), self.dropout, training=self.training)
|
|
|
|
|
l2 = F.dropout(F.relu(self.lin2(l1)), self.dropout, training=self.training)
|
|
|
|
|
l3 = self.lin3(l2)
|
|
|
|
|
return l3
|
|
|
|
|
|
|
|
|
|
# NOTE: not used in
|
|
|
|
|
class CNNPreNet(nn.Layer):
|
|
|
|
|
def __init__(self, d_input, d_hidden, d_output, kernel_size, n_layers,
|
|
|
|
|
dropout=0.):
|
|
|
|
@ -347,7 +325,6 @@ class CNNPostNet(nn.Layer):
|
|
|
|
|
# NOTE: it can also be a non-causal conv
|
|
|
|
|
|
|
|
|
|
def forward(self, x):
|
|
|
|
|
# why not use pre norms
|
|
|
|
|
x_in = x
|
|
|
|
|
for i, layer in enumerate(self.convs):
|
|
|
|
|
x = layer(x)
|
|
|
|
@ -358,33 +335,60 @@ class CNNPostNet(nn.Layer):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TransformerTTS(nn.Layer):
|
|
|
|
|
def __init__(self, vocab_size, padding_idx, d_encoder, d_decoder, d_mel, n_heads, d_ffn,
|
|
|
|
|
encoder_layers, decoder_layers, d_prenet, d_postnet, postnet_layers,
|
|
|
|
|
postnet_kernel_size, max_reduction_factor, dropout):
|
|
|
|
|
def __init__(self,
|
|
|
|
|
frontend: parakeet.frontend.Phonetics,
|
|
|
|
|
d_encoder: int,
|
|
|
|
|
d_decoder: int,
|
|
|
|
|
d_mel: int,
|
|
|
|
|
n_heads: int,
|
|
|
|
|
d_ffn: int,
|
|
|
|
|
encoder_layers: int,
|
|
|
|
|
decoder_layers: int,
|
|
|
|
|
d_prenet: int,
|
|
|
|
|
d_postnet: int,
|
|
|
|
|
postnet_layers: int,
|
|
|
|
|
postnet_kernel_size: int,
|
|
|
|
|
max_reduction_factor: int,
|
|
|
|
|
decoder_prenet_dropout: float,
|
|
|
|
|
dropout: float):
|
|
|
|
|
super(TransformerTTS, self).__init__()
|
|
|
|
|
|
|
|
|
|
# text frontend (text normalization and g2p)
|
|
|
|
|
self.frontend = frontend
|
|
|
|
|
|
|
|
|
|
# encoder
|
|
|
|
|
self.encoder_prenet = nn.Embedding(vocab_size, d_encoder, padding_idx, weight_attr=I.Uniform(-0.05, 0.05))
|
|
|
|
|
# self.encoder_prenet = CNNPreNet(d_encoder, d_encoder, d_encoder, 5, 3, dropout)
|
|
|
|
|
self.encoder_pe = pe.positional_encoding(0, 1000, d_encoder) # it may be extended later
|
|
|
|
|
self.encoder_pe_scalar = self.create_parameter([1], attr=I.Constant(1.))
|
|
|
|
|
self.encoder = TransformerEncoder(d_encoder, n_heads, d_ffn, encoder_layers, dropout)
|
|
|
|
|
self.encoder_prenet = nn.Embedding(
|
|
|
|
|
frontend.vocab_size, d_encoder,
|
|
|
|
|
padding_idx=frontend.vocab.padding_index,
|
|
|
|
|
weight_attr=I.Uniform(-0.05, 0.05))
|
|
|
|
|
# position encoding matrix may be extended later
|
|
|
|
|
self.encoder_pe = pe.positional_encoding(0, 1000, d_encoder)
|
|
|
|
|
self.encoder_pe_scalar = self.create_parameter(
|
|
|
|
|
[1], attr=I.Constant(1.))
|
|
|
|
|
self.encoder = TransformerEncoder(
|
|
|
|
|
d_encoder, n_heads, d_ffn, encoder_layers, dropout)
|
|
|
|
|
|
|
|
|
|
# decoder
|
|
|
|
|
self.decoder_prenet = MLPPreNet(d_mel, d_prenet, d_decoder)
|
|
|
|
|
self.decoder_pe = pe.positional_encoding(0, 1000, d_decoder) # it may be extended later
|
|
|
|
|
self.decoder_pe_scalar = self.create_parameter([1], attr=I.Constant(1.))
|
|
|
|
|
self.decoder = TransformerDecoder(d_decoder, n_heads, d_ffn, decoder_layers, dropout, d_encoder=d_encoder)
|
|
|
|
|
self.decoder_prenet = MLPPreNet(d_mel, d_prenet, d_decoder, dropout)
|
|
|
|
|
self.decoder_pe = pe.positional_encoding(0, 1000, d_decoder)
|
|
|
|
|
self.decoder_pe_scalar = self.create_parameter(
|
|
|
|
|
[1], attr=I.Constant(1.))
|
|
|
|
|
self.decoder = TransformerDecoder(
|
|
|
|
|
d_decoder, n_heads, d_ffn, decoder_layers, dropout,
|
|
|
|
|
d_encoder=d_encoder)
|
|
|
|
|
self.final_proj = nn.Linear(d_decoder, max_reduction_factor * d_mel)
|
|
|
|
|
self.decoder_postnet = CNNPostNet(d_mel, d_postnet, d_mel, postnet_kernel_size, postnet_layers)
|
|
|
|
|
self.decoder_postnet = CNNPostNet(
|
|
|
|
|
d_mel, d_postnet, d_mel, postnet_kernel_size, postnet_layers)
|
|
|
|
|
self.stop_conditioner = nn.Linear(d_mel, 3)
|
|
|
|
|
|
|
|
|
|
# specs
|
|
|
|
|
self.padding_idx = padding_idx
|
|
|
|
|
self.padding_idx = frontend.vocab.padding_index
|
|
|
|
|
self.d_encoder = d_encoder
|
|
|
|
|
self.d_decoder = d_decoder
|
|
|
|
|
self.d_mel = d_mel
|
|
|
|
|
self.max_r = max_reduction_factor
|
|
|
|
|
self.dropout = dropout
|
|
|
|
|
self.decoder_prenet_dropout = decoder_prenet_dropout
|
|
|
|
|
|
|
|
|
|
# start and end: though it is only used in predict
|
|
|
|
|
# it can also be used in training
|
|
|
|
@ -395,13 +399,19 @@ class TransformerTTS(nn.Layer):
|
|
|
|
|
|
|
|
|
|
# mutables
|
|
|
|
|
self.r = max_reduction_factor # set it every call
|
|
|
|
|
self.decoder_prenet_dropout = 0.0
|
|
|
|
|
self.drop_n_heads = 0
|
|
|
|
|
|
|
|
|
|
def forward(self, text, mel, stop):
|
|
|
|
|
def forward(self, text, mel):
|
|
|
|
|
encoded, encoder_attention_weights, encoder_mask = self.encode(text)
|
|
|
|
|
mel_output, mel_intermediate, cross_attention_weights, stop_logits = self.decode(encoded, mel, encoder_mask)
|
|
|
|
|
return mel_output, mel_intermediate, encoder_attention_weights, cross_attention_weights, stop_logits
|
|
|
|
|
outputs = {
|
|
|
|
|
"mel_output": mel_output,
|
|
|
|
|
"mel_intermediate": mel_intermediate,
|
|
|
|
|
"encoder_attention_weights": encoder_attention_weights,
|
|
|
|
|
"cross_attention_weights": cross_attention_weights,
|
|
|
|
|
"stop_logits": stop_logits,
|
|
|
|
|
}
|
|
|
|
|
return outputs
|
|
|
|
|
|
|
|
|
|
def encode(self, text):
|
|
|
|
|
T_enc = text.shape[-1]
|
|
|
|
@ -414,7 +424,8 @@ class TransformerTTS(nn.Layer):
|
|
|
|
|
x = F.dropout(x, self.dropout, training=self.training)
|
|
|
|
|
|
|
|
|
|
# TODO(chenfeiyu): unsqueeze a decoder_time_steps=1 for the mask
|
|
|
|
|
encoder_padding_mask = masking.id_mask(text, self.padding_idx, dtype=x.dtype)
|
|
|
|
|
encoder_padding_mask = paddle.unsqueeze(
|
|
|
|
|
masking.id_mask(text, self.padding_idx, dtype=x.dtype), 1)
|
|
|
|
|
x, attention_weights = self.encoder(x, encoder_padding_mask, self.drop_n_heads)
|
|
|
|
|
return x, attention_weights, encoder_padding_mask
|
|
|
|
|
|
|
|
|
@ -453,21 +464,26 @@ class TransformerTTS(nn.Layer):
|
|
|
|
|
|
|
|
|
|
return mel_output, mel_intermediate, cross_attention_weights, stop_logits
|
|
|
|
|
|
|
|
|
|
def predict(self, input, max_length=1000, verbose=True):
|
|
|
|
|
"""[summary]
|
|
|
|
|
def predict(self, input, raw_input=True, max_length=1000, verbose=True):
|
|
|
|
|
"""Predict log scale magnitude mel spectrogram from text input.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
input (Tensor): shape (T), dtype int, input text sequencce.
|
|
|
|
|
max_length (int, optional): max decoder steps. Defaults to 1000.
|
|
|
|
|
verbose (bool, optional): display progress bar. Defaults to True.
|
|
|
|
|
"""
|
|
|
|
|
text_input = paddle.unsqueeze(input, 0) # (1, T)
|
|
|
|
|
if raw_input:
|
|
|
|
|
text_ids = paddle.to_tensor(self.frontend(input))
|
|
|
|
|
text_input = paddle.unsqueeze(text_ids, 0) # (1, T)
|
|
|
|
|
else:
|
|
|
|
|
text_input = input
|
|
|
|
|
|
|
|
|
|
decoder_input = paddle.unsqueeze(self.start_vec, 0) # (B=1, T, C)
|
|
|
|
|
decoder_output = paddle.unsqueeze(self.start_vec, 0) # (B=1, T, C)
|
|
|
|
|
|
|
|
|
|
# encoder the text sequence
|
|
|
|
|
encoder_output, encoder_attentions, encoder_padding_mask = self.encode(text_input)
|
|
|
|
|
for _ in range(int(max_length // self.r) + 1):
|
|
|
|
|
for _ in trange(int(max_length // self.r) + 1):
|
|
|
|
|
mel_output, _, cross_attention_weights, stop_logits = self.decode(
|
|
|
|
|
encoder_output, decoder_input, encoder_padding_mask)
|
|
|
|
|
|
|
|
|
@ -477,19 +493,23 @@ class TransformerTTS(nn.Layer):
|
|
|
|
|
decoder_output = paddle.concat([decoder_output, mel_output[:, -self.r:, :]], 1)
|
|
|
|
|
|
|
|
|
|
# stop condition?
|
|
|
|
|
if paddle.argmax(stop_logits[:, -1, :]) == self.stop_prob_index:
|
|
|
|
|
if paddle.any(paddle.argmax(stop_logits[0, :, :], axis=-1) == self.stop_prob_index):
|
|
|
|
|
if verbose:
|
|
|
|
|
print("Hits stop condition.")
|
|
|
|
|
break
|
|
|
|
|
mel_output = decoder_output[:, 1:, :]
|
|
|
|
|
|
|
|
|
|
return decoder_output[:, 1:, :], encoder_attentions, cross_attention_weights
|
|
|
|
|
outputs = {
|
|
|
|
|
"mel_output": mel_output,
|
|
|
|
|
"encoder_attention_weights": encoder_attentions,
|
|
|
|
|
"cross_attention_weights": cross_attention_weights,
|
|
|
|
|
}
|
|
|
|
|
return outputs
|
|
|
|
|
|
|
|
|
|
def set_constants(self, reduction_factor, drop_n_heads, decoder_prenet_dropout):
|
|
|
|
|
# TODO(chenfeiyu): make a good design for these hyperparameter settings
|
|
|
|
|
def set_constants(self, reduction_factor, drop_n_heads):
|
|
|
|
|
self.r = reduction_factor
|
|
|
|
|
self.drop_n_heads = drop_n_heads
|
|
|
|
|
self.decoder_prenet_dropout = decoder_prenet_dropout
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TransformerTTSLoss(nn.Layer):
|
|
|
|
|
def __init__(self, stop_loss_scale):
|
|
|
|
@ -505,12 +525,14 @@ class TransformerTTSLoss(nn.Layer):
|
|
|
|
|
mel_len = mask.shape[-1]
|
|
|
|
|
last_position = F.one_hot(mask.sum(-1).astype("int64") - 1, num_classes=mel_len)
|
|
|
|
|
mask2 = mask + last_position.scale(self.stop_loss_scale - 1).astype(mask.dtype)
|
|
|
|
|
stop_loss = L.masked_softmax_with_cross_entropy(stop_logits, stop_probs.unsqueeze(-1), mask2.unsqueeze(-1))
|
|
|
|
|
stop_loss = L.masked_softmax_with_cross_entropy(
|
|
|
|
|
stop_logits, stop_probs.unsqueeze(-1), mask2.unsqueeze(-1))
|
|
|
|
|
|
|
|
|
|
loss = mel_loss1 + mel_loss2 + stop_loss
|
|
|
|
|
details = dict(
|
|
|
|
|
losses = dict(
|
|
|
|
|
loss=loss, # total loss
|
|
|
|
|
mel_loss1=mel_loss1, # ouput mel loss
|
|
|
|
|
mel_loss2=mel_loss2, # intermediate mel loss
|
|
|
|
|
stop_loss=stop_loss # stop prob loss
|
|
|
|
|
)
|
|
|
|
|
return loss, details
|
|
|
|
|
return losses
|