ParakeetRebeccaRosario/parakeet/models/transformer_tts.py

import math
import paddle
from paddle import nn
from paddle.nn import functional as F
from paddle.nn import initializer as I

from parakeet.modules.attention import _split_heads, _concat_heads, drop_head, scaled_dot_product_attention
from parakeet.modules.transformer import PositionwiseFFN
from parakeet.modules import masking
from parakeet.modules.conv import Conv1dBatchNorm
from parakeet.modules import positional_encoding as pe
__all__ = ["TransformerTTS"]


# Transformer TTS's own implementation of transformer
class MultiheadAttention(nn.Layer):
"""
Multihead scaled dot product attention with drop head. See
[Scheduled DropHead: A Regularization Method for Transformer Models](https://arxiv.org/abs/2004.13342)
for details.
Another deviation is that it concats the input query and context vector before
applying the output projection.
"""
def __init__(self, model_dim, num_heads, k_dim=None, v_dim=None, k_input_dim=None, v_input_dim=None):
"""
Args:
model_dim (int): the feature size of query.
num_heads (int): the number of attention heads.
k_dim (int, optional): feature size of the key of each scaled dot
product attention. If not provided, it is set to
model_dim / num_heads. Defaults to None.
            v_dim (int, optional): feature size of the value of each scaled dot
                product attention. If not provided, it is set to
                model_dim / num_heads. Defaults to None.
            k_input_dim (int, optional): feature size of the key input. If not
                provided, it is set to model_dim. Defaults to None.
            v_input_dim (int, optional): feature size of the value input. If not
                provided, it is set to model_dim. Defaults to None.
Raises:
ValueError: if model_dim is not divisible by num_heads
"""
super(MultiheadAttention, self).__init__()
        if model_dim % num_heads != 0:
raise ValueError("model_dim must be divisible by num_heads")
depth = model_dim // num_heads
k_dim = k_dim or depth
v_dim = v_dim or depth
k_input_dim = k_input_dim or model_dim
v_input_dim = v_input_dim or model_dim
self.affine_q = nn.Linear(model_dim, num_heads * k_dim)
self.affine_k = nn.Linear(k_input_dim, num_heads * k_dim)
self.affine_v = nn.Linear(v_input_dim, num_heads * v_dim)
self.affine_o = nn.Linear(model_dim + num_heads * v_dim, model_dim)
self.num_heads = num_heads
self.model_dim = model_dim
def forward(self, q, k, v, mask, drop_n_heads=0):
"""
Compute context vector and attention weights.
Args:
q (Tensor): shape(batch_size, time_steps_q, model_dim), the queries.
k (Tensor): shape(batch_size, time_steps_k, model_dim), the keys.
v (Tensor): shape(batch_size, time_steps_k, model_dim), the values.
            mask (Tensor): shape(batch_size, time_steps_q, time_steps_k) or
                broadcastable shape, dtype: float32 or float64, the mask.
            drop_n_heads (int, optional): the number of attention heads to drop
                (see Scheduled DropHead). Defaults to 0.
Returns:
(out, attention_weights)
out (Tensor), shape(batch_size, time_steps_q, model_dim), the context vector.
            attention_weights (Tensor): shape(batch_size, num_heads, time_steps_q, time_steps_k), the attention weights.
"""
q_in = q
q = _split_heads(self.affine_q(q), self.num_heads) # (B, h, T, C)
k = _split_heads(self.affine_k(k), self.num_heads)
v = _split_heads(self.affine_v(v), self.num_heads)
mask = paddle.unsqueeze(mask, 1) # unsqueeze for the h dim
context_vectors, attention_weights = scaled_dot_product_attention(
q, k, v, mask)
context_vectors = drop_head(context_vectors, drop_n_heads, self.training)
context_vectors = _concat_heads(context_vectors) # (B, T, h*C)
concat_feature = paddle.concat([q_in, context_vectors], -1)
out = self.affine_o(concat_feature)
return out, attention_weights
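
# Usage sketch for MultiheadAttention. Shapes and the mask convention below are
# illustrative assumptions, not part of the API:
#
#     attn = MultiheadAttention(model_dim=256, num_heads=4)
#     q = paddle.randn([2, 50, 256])      # (batch, T_q, model_dim)
#     kv = paddle.randn([2, 60, 256])     # (batch, T_k, model_dim)
#     mask = paddle.ones([2, 50, 60])     # assumed convention: 1 = valid position
#     out, weights = attn(q, kv, kv, mask, drop_n_heads=1)
#     # out: (2, 50, 256), weights: (2, 4, 50, 60)
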
class TransformerEncoderLayer(nn.Layer):
"""
Transformer encoder layer.
"""
def __init__(self, d_model, n_heads, d_ffn, dropout=0.):
"""
Args:
d_model (int): the feature size of the input, and the output.
n_heads (int): the number of heads in the internal MultiHeadAttention layer.
d_ffn (int): the hidden size of the internal PositionwiseFFN.
            dropout (float, optional): the probability of the dropout in the
                internal PositionwiseFFN. Defaults to 0.
"""
super(TransformerEncoderLayer, self).__init__()
self.self_mha = MultiheadAttention(d_model, n_heads)
self.layer_norm1 = nn.LayerNorm([d_model], epsilon=1e-6)
self.ffn = PositionwiseFFN(d_model, d_ffn, dropout)
self.layer_norm2 = nn.LayerNorm([d_model], epsilon=1e-6)
def forward(self, x, mask):
"""
Args:
            x (Tensor): shape(batch_size, time_steps, d_model), the encoder input.
mask (Tensor): shape(batch_size, time_steps), the padding mask.
Returns:
(x, attn_weights)
            x (Tensor): shape(batch_size, time_steps, d_model), the encoded output.
attn_weights (Tensor), shape(batch_size, n_heads, time_steps, time_steps), self attention.
"""
# pre norm
x_in = x
x = self.layer_norm1(x)
context_vector, attn_weights = self.self_mha(x, x, x, paddle.unsqueeze(mask, 1))
        x = x_in + context_vector  # residual connection; the norm/residual order can be tuned
# pre norm
x = x + self.ffn(self.layer_norm2(x))
return x, attn_weights
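
# Sketch of how an encoder layer is called (shapes are illustrative assumptions):
#
#     layer = TransformerEncoderLayer(d_model=256, n_heads=4, d_ffn=1024, dropout=0.1)
#     x = paddle.randn([2, 50, 256])     # (batch, T, d_model)
#     mask = paddle.ones([2, 50])        # per-position padding mask
#     y, attn = layer(x, mask)           # y: (2, 50, 256), attn: (2, 4, 50, 50)
#     # internally the mask is unsqueezed to (2, 1, 50) so it broadcasts over
#     # the query time dimension.
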
class TransformerDecoderLayer(nn.Layer):
"""
Transformer decoder layer.
"""
def __init__(self, d_model, n_heads, d_ffn, dropout=0., d_encoder=None):
"""
Args:
d_model (int): the feature size of the input, and the output.
n_heads (int): the number of heads in the internal MultiHeadAttention layer.
d_ffn (int): the hidden size of the internal PositionwiseFFN.
            dropout (float, optional): the probability of the dropout in the
                internal PositionwiseFFN. Defaults to 0.
            d_encoder (int, optional): the feature size of the encoder output
                (the keys and values of the cross attention). If not provided,
                it is set to d_model. Defaults to None.
"""
super(TransformerDecoderLayer, self).__init__()
self.self_mha = MultiheadAttention(d_model, n_heads)
self.layer_norm1 = nn.LayerNorm([d_model], epsilon=1e-6)
self.cross_mha = MultiheadAttention(d_model, n_heads, k_input_dim=d_encoder, v_input_dim=d_encoder)
self.layer_norm2 = nn.LayerNorm([d_model], epsilon=1e-6)
self.ffn = PositionwiseFFN(d_model, d_ffn, dropout)
self.layer_norm3 = nn.LayerNorm([d_model], epsilon=1e-6)
def forward(self, q, k, v, encoder_mask, decoder_mask):
"""
Args:
q (Tensor): shape(batch_size, time_steps_q, d_model), the decoder input.
            k (Tensor): shape(batch_size, time_steps_k, d_encoder), keys (the encoder output).
            v (Tensor): shape(batch_size, time_steps_k, d_encoder), values (the encoder output).
            encoder_mask (Tensor): shape(batch_size, time_steps_k), encoder padding mask.
            decoder_mask (Tensor): shape(batch_size, time_steps_q, time_steps_q) or broadcastable shape, combined decoder padding and future mask.
Returns:
(q, self_attn_weights, cross_attn_weights)
            q (Tensor): shape(batch_size, time_steps_q, d_model), the decoder output.
self_attn_weights (Tensor), shape(batch_size, n_heads, time_steps_q, time_steps_q), decoder self attention.
cross_attn_weights (Tensor), shape(batch_size, n_heads, time_steps_q, time_steps_k), decoder-encoder cross attention.
        """
        # pre norm
        q_in = q
        q = self.layer_norm1(q)
        context_vector, self_attn_weights = self.self_mha(q, q, q, decoder_mask)
        q = q_in + context_vector
# pre norm
q_in = q
q = self.layer_norm2(q)
context_vector, cross_attn_weights = self.cross_mha(q, k, v, paddle.unsqueeze(encoder_mask, 1))
q = q_in + context_vector
# pre norm
q = q + self.ffn(self.layer_norm3(q))
return q, self_attn_weights, cross_attn_weights
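
# Sketch of a decoder layer call (shapes are illustrative assumptions; the
# decoder mask combines a causal mask with the target padding mask, as done in
# TransformerTTS.decode below):
#
#     layer = TransformerDecoderLayer(d_model=256, n_heads=4, d_ffn=1024,
#                                     dropout=0.1, d_encoder=512)
#     q = paddle.randn([2, 30, 256])                      # decoder input
#     kv = paddle.randn([2, 50, 512])                     # encoder output
#     enc_mask = paddle.ones([2, 50])                     # encoder padding mask
#     dec_mask = paddle.tril(paddle.ones([2, 30, 30]))    # causal mask
#     q, self_attn, cross_attn = layer(q, kv, kv, enc_mask, dec_mask)
#     # self_attn: (2, 4, 30, 30), cross_attn: (2, 4, 30, 50)
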
class TransformerEncoder(nn.LayerList):
def __init__(self, d_model, n_heads, d_ffn, n_layers, dropout=0.):
super(TransformerEncoder, self).__init__()
for _ in range(n_layers):
self.append(TransformerEncoderLayer(d_model, n_heads, d_ffn, dropout))
def forward(self, x, mask):
attention_weights = []
for layer in self:
x, attention_weights_i = layer(x, mask)
attention_weights.append(attention_weights_i)
return x, attention_weights
class TransformerDecoder(nn.LayerList):
def __init__(self, d_model, n_heads, d_ffn, n_layers, dropout=0., d_encoder=None):
        super(TransformerDecoder, self).__init__()
        for _ in range(n_layers):
            self.append(TransformerDecoderLayer(d_model, n_heads, d_ffn, dropout, d_encoder=d_encoder))

    def forward(self, q, k, v, encoder_mask, decoder_mask):
        self_attention_weights = []
        cross_attention_weights = []
        for layer in self:
            q, self_attention_weights_i, cross_attention_weights_i = layer(q, k, v, encoder_mask, decoder_mask)
            self_attention_weights.append(self_attention_weights_i)
            cross_attention_weights.append(cross_attention_weights_i)
        return q, self_attention_weights, cross_attention_weights
class MLPPreNet(nn.Layer):
    def __init__(self, d_input, d_hidden, d_output, dropout):
        super(MLPPreNet, self).__init__()
        self.lin1 = nn.Linear(d_input, d_hidden)
        self.lin2 = nn.Linear(d_hidden, d_output)
        self.dropout = dropout

    def forward(self, x):
        # the original implementation also applies dropout at inference, so use
        # the functional form with training=True rather than nn.Dropout
        l1 = F.dropout(F.relu(self.lin1(x)), self.dropout, training=True)
        l2 = F.dropout(F.relu(self.lin2(l1)), self.dropout, training=True)
        return l2
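
# Prenet usage sketch (sizes are illustrative assumptions):
#
#     prenet = MLPPreNet(d_input=80, d_hidden=256, d_output=256, dropout=0.5)
#     y = prenet(paddle.randn([2, 30, 80]))     # (2, 30, 256)
#
# Note that the prenet is intended to keep dropout active at inference as well
# (see the comment in ``forward`` above).
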
class CNNPostNet(nn.Layer):
    def __init__(self, d_input, d_hidden, d_output, kernel_size, n_layers):
        super(CNNPostNet, self).__init__()
        self.convs = nn.LayerList()
        kernel_size = kernel_size if isinstance(kernel_size, (tuple, list)) else (kernel_size, )
        padding = (kernel_size[0] - 1, 0)
        for i in range(n_layers):
            c_in = d_input if i == 0 else d_hidden
            c_out = d_output if i == n_layers - 1 else d_hidden
            self.convs.append(
                Conv1dBatchNorm(c_in, c_out, kernel_size, padding=padding))
        self.last_norm = nn.BatchNorm1D(d_output)

    def forward(self, x):
        x_in = x
        for layer in self.convs:
            x = paddle.tanh(layer(x))
        x = self.last_norm(x + x_in)
        return x
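
# Postnet padding sketch: with a left-only padding of kernel_size - 1, each
# convolution keeps the time dimension unchanged, so the residual ``x + x_in``
# lines up. Sizes below are illustrative assumptions:
#
#     postnet = CNNPostNet(d_input=80, d_hidden=256, d_output=80,
#                          kernel_size=5, n_layers=5)
#     mel = paddle.randn([2, 80, 100])          # channel-first: (B, d_mel, T)
#     out = postnet(mel)                        # (2, 80, 100)
#     # output length = T + (kernel_size - 1) - (kernel_size - 1) = T
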
class TransformerTTS(nn.Layer):
    def __init__(self, vocab_size, padding_idx, d_encoder, d_decoder, d_mel, n_heads, d_ffn,
                 encoder_layers, decoder_layers, d_prenet, d_postnet, postnet_layers,
                 postnet_kernel_size, max_reduction_factor, dropout):
        super(TransformerTTS, self).__init__()
# initial pe scalar is 1, though it is trainable
self.pe_scalar = self.create_parameter([1], attr=I.Constant(1.))
# encoder
self.encoder_prenet = nn.Embedding(vocab_size, d_encoder, padding_idx)
self.encoder_pe = pe.positional_encoding(0, 1000, d_encoder) # it may be extended later
self.encoder = TransformerEncoder(d_encoder, n_heads, d_ffn, encoder_layers, dropout)
# decoder
self.decoder_prenet = MLPPreNet(d_mel, d_prenet, d_decoder, dropout)
self.decoder_pe = pe.positional_encoding(0, 1000, d_decoder) # it may be extended later
self.decoder = TransformerDecoder(d_decoder, n_heads, d_ffn, decoder_layers, dropout, d_encoder=d_encoder)
self.final_proj = nn.Linear(d_decoder, max_reduction_factor * d_mel)
self.decoder_postnet = CNNPostNet(d_mel, d_postnet, d_mel, postnet_kernel_size, postnet_layers)
self.stop_conditioner = nn.Linear(d_mel, 3)
# specs
self.padding_idx = padding_idx
self.d_encoder = d_encoder
self.d_decoder = d_decoder
        # start and end vectors: only used in predict here, but they could
        # also be used in training
dtype = paddle.get_default_dtype()
self.start_vec = paddle.full([1, d_mel], 0, dtype=dtype)
self.end_vec = paddle.full([1, d_mel], 0, dtype=dtype)
self.stop_prob_index = 2
self.max_r = max_reduction_factor
        self.r = max_reduction_factor  # the reduction factor actually used; set before each call
    def forward(self, text, mel):
        encoded, encoder_attention_weights, encoder_mask = self.encode(text)
        mel_output, mel_intermediate, cross_attention_weights, stop_logits = self.decode(encoded, mel, encoder_mask)
        # stop_logits are returned as well so the stop-token loss can be computed
        return mel_output, mel_intermediate, encoder_attention_weights, cross_attention_weights, stop_logits

def encode(self, text):
T_enc = text.shape[-1]
embed = self.encoder_prenet(text)
        pos_enc = self.encoder_pe[:T_enc, :]  # (T, C); avoid shadowing the module alias ``pe``
        x = embed.scale(math.sqrt(self.d_encoder)) + pos_enc * self.pe_scalar

        encoder_padding_mask = masking.id_mask(text, self.padding_idx, dtype=x.dtype)

        x = F.dropout(x, training=self.training)
x, attention_weights = self.encoder(x, encoder_padding_mask)
return x, attention_weights, encoder_padding_mask
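
    # ``encode`` sketch (values are illustrative assumptions): the text embedding
    # is scaled by sqrt(d_encoder) and summed with the positional encoding scaled
    # by the learnable ``pe_scalar``; ``masking.id_mask`` marks non-padding
    # positions, e.g. with padding_idx=0:
    #
    #     text = paddle.to_tensor([[5, 7, 2, 0, 0]])
    #     # id_mask(text, 0) -> [[1., 1., 1., 0., 0.]]  (assumed convention)
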
def decode(self, encoder_output, input, encoder_padding_mask):
batch_size, T_dec, mel_dim = input.shape
no_future_mask = masking.future_mask(T_dec, dtype=input.dtype)
decoder_padding_mask = masking.feature_mask(input, axis=-1, dtype=input.dtype)
decoder_mask = masking.combine_mask(decoder_padding_mask.unsqueeze(-1), no_future_mask)

        decoder_input = self.decoder_prenet(input)
        # add the scaled decoder positional encoding, mirroring ``encode``
        pos_enc = self.decoder_pe[:T_dec, :]
        decoder_input = decoder_input.scale(math.sqrt(self.d_decoder)) + pos_enc * self.pe_scalar
        decoder_output, _, cross_attention_weights = self.decoder(
            decoder_input,
            encoder_output,
            encoder_output,
            encoder_padding_mask,
            decoder_mask)
output_proj = self.final_proj(decoder_output)
mel_intermediate = paddle.reshape(output_proj, [batch_size, -1, mel_dim])
stop_logits = self.stop_conditioner(mel_intermediate)
mel_channel_first = paddle.transpose(mel_intermediate, [0, 2, 1])
mel_output = self.decoder_postnet(mel_channel_first)
mel_output = paddle.transpose(mel_output, [0, 2, 1])
return mel_output, mel_intermediate, cross_attention_weights, stop_logits
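
    # ``decode`` mask sketch: the decoder mask combines the lower-triangular
    # future mask with the per-frame padding mask. Values below are illustrative
    # and assume 1 = valid and an elementwise-AND combine:
    #
    #     future_mask            padding_mask          combined
    #     [[1, 0, 0],            [[1],                 [[1, 0, 0],
    #      [1, 1, 0],      AND    [1],          =       [1, 1, 0],
    #      [1, 1, 1]]             [0]]                  [0, 0, 0]]
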
def predict(self, input, max_length=1000, verbose=True):
"""[summary]
Args:
input (Tensor): shape (T), dtype int, input text sequencce.
max_length (int, optional): max decoder steps. Defaults to 1000.
verbose (bool, optional): display progress bar. Defaults to True.
"""
text_input = paddle.unsqueeze(input, 0) # (1, T)
decoder_input = paddle.unsqueeze(self.start_vec, 0) # (B=1, T, C)
decoder_output = paddle.unsqueeze(self.start_vec, 0) # (B=1, T, C)
# encoder the text sequence
encoder_output, encoder_attentions, encoder_padding_mask = self.encode(text_input)
for _ in range(int(max_length // self.r) + 1):
mel_output, _, cross_attention_weights, stop_logits = self.decode(
encoder_output, decoder_input, encoder_padding_mask)
# extract last step and append it to decoder input
decoder_input = paddle.concat([decoder_input, mel_output[:, -1:, :]], 1)
# extract last r steps and append it to decoder output
decoder_output = paddle.concat([decoder_output, mel_output[:, -self.r:, :]], 1)
# stop condition?
if paddle.argmax(stop_logits[:, -1, :]) == self.stop_prob_index:
if verbose:
print("Hits stop condition.")
break
return decoder_output[:, 1:, :], encoder_attentions, cross_attention_weights
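
# ``TransformerTTS.predict`` usage sketch (hyperparameter values are
# illustrative assumptions):
#
#     model = TransformerTTS(vocab_size=128, padding_idx=0, d_encoder=512,
#                            d_decoder=256, d_mel=80, n_heads=4, d_ffn=1024,
#                            encoder_layers=4, decoder_layers=4, d_prenet=256,
#                            d_postnet=256, postnet_layers=5,
#                            postnet_kernel_size=5, max_reduction_factor=5,
#                            dropout=0.1)
#     model.eval()
#     model.r = 2                             # reduction factor for this call
#     text = paddle.to_tensor([5, 7, 2, 9])   # (T,) token ids
#     with paddle.no_grad():
#         mel, enc_attn, cross_attn = model.predict(text, max_length=500)
#     # mel: (1, T_out, 80)
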
class TransformerTTSLoss(nn.Layer):
    def __init__(self, stop_loss_scale):
        super(TransformerTTSLoss, self).__init__()
        self.stop_loss_scale = stop_loss_scale

    def forward(self, mel_output, mel_intermediate, mel_target, stop_logits, stop_labels):
        # minimal sketch: L1 on both mel outputs plus a scaled stop-token
        # cross entropy; the argument layout is an assumption, not a fixed API
        mel_loss1 = F.l1_loss(mel_output, mel_target)
        mel_loss2 = F.l1_loss(mel_intermediate, mel_target)
        stop_loss = F.cross_entropy(stop_logits, stop_labels)
        loss = mel_loss1 + mel_loss2 + self.stop_loss_scale * stop_loss
        details = dict(mel_loss1=mel_loss1, mel_loss2=mel_loss2, stop_loss=stop_loss)
        return loss, details
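
# Loss usage sketch (argument names follow the sketched ``forward`` above and
# are assumptions about the training loop, not a fixed API):
#
#     criterion = TransformerTTSLoss(stop_loss_scale=8.0)
#     loss, details = criterion(mel_output, mel_intermediate, mel_target,
#                               stop_logits, stop_labels)
#     loss.backward()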