617 lines
26 KiB
Python
617 lines
26 KiB
Python
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
import math
|
|
from tqdm import trange
|
|
import paddle
|
|
from paddle import nn
|
|
from paddle.nn import functional as F
|
|
from paddle.nn import initializer as I
|
|
|
|
import parakeet
|
|
from parakeet.modules.attention import _split_heads, _concat_heads, drop_head, scaled_dot_product_attention
|
|
from parakeet.modules.transformer import PositionwiseFFN
|
|
from parakeet.modules import masking
|
|
from parakeet.modules.conv import Conv1dBatchNorm
|
|
from parakeet.modules import positional_encoding as pe
|
|
from parakeet.modules import losses as L
|
|
from parakeet.utils import checkpoint, scheduler
|
|
|
|
# Public names exported by ``from <this module> import *``.
__all__ = [
    "TransformerTTS",
    "TransformerTTSLoss",
]
|
|
|
|
|
|
# Transformer TTS's own implementation of transformer
class MultiheadAttention(nn.Layer):
    """Multihead scaled dot product attention with drop head. See
    [Scheduled DropHead: A Regularization Method for Transformer Models](https://arxiv.org/abs/2004.13342)
    for details.

    Another deviation is that it concats the input query and context vector before
    applying the output projection.
    """

    def __init__(self,
                 model_dim,
                 num_heads,
                 k_dim=None,
                 v_dim=None,
                 k_input_dim=None,
                 v_input_dim=None):
        """
        Args:
            model_dim (int): the feature size of the query, and of the output.
            num_heads (int): the number of attention heads.
            k_dim (int, optional): feature size of the projected query and
                key of each scaled dot product attention head. If not
                provided, it is set to model_dim // num_heads. Defaults to None.
            v_dim (int, optional): feature size of the projected value of
                each scaled dot product attention head. If not provided, it
                is set to model_dim // num_heads. Defaults to None.
            k_input_dim (int, optional): feature size of the key input. If
                not provided, it is set to model_dim. Defaults to None.
            v_input_dim (int, optional): feature size of the value input. If
                not provided, it is set to model_dim. Defaults to None.

        Raises:
            ValueError: if model_dim is not divisible by num_heads
        """
        super(MultiheadAttention, self).__init__()
        if model_dim % num_heads != 0:
            raise ValueError("model_dim must be divisible by num_heads")
        depth = model_dim // num_heads
        k_dim = k_dim or depth
        v_dim = v_dim or depth
        k_input_dim = k_input_dim or model_dim
        v_input_dim = v_input_dim or model_dim
        self.affine_q = nn.Linear(model_dim, num_heads * k_dim)
        self.affine_k = nn.Linear(k_input_dim, num_heads * k_dim)
        self.affine_v = nn.Linear(v_input_dim, num_heads * v_dim)
        # The output projection consumes [query ; concatenated heads], hence
        # the extra model_dim input features.
        self.affine_o = nn.Linear(model_dim + num_heads * v_dim, model_dim)

        self.num_heads = num_heads
        self.model_dim = model_dim

    def forward(self, q, k, v, mask, drop_n_heads=0):
        """
        Compute context vector and attention weights.

        Args:
            q (Tensor): shape(batch_size, time_steps_q, model_dim), the queries.
            k (Tensor): shape(batch_size, time_steps_k, k_input_dim), the keys.
            v (Tensor): shape(batch_size, time_steps_k, v_input_dim), the values.
            mask (Tensor or None): shape(batch_size, times_steps_q, time_steps_k)
                or broadcastable shape, dtype: float32 or float64, the mask.
                A head axis is inserted at dim 1 before use. When None, it is
                passed through to scaled_dot_product_attention unchanged.
            drop_n_heads (int, optional): number of attention heads to drop;
                passed to ``drop_head`` together with ``self.training``.
                Defaults to 0.

        Returns:
            out (Tensor): shape(batch_size, time_steps_q, model_dim), the context vector.
            attention_weights (Tensor): shape(batch_size, num_heads, times_steps_q, time_steps_k), the attention weights.
        """
        q_in = q
        q = _split_heads(self.affine_q(q), self.num_heads)  # (B, h, T, C)
        k = _split_heads(self.affine_k(k), self.num_heads)
        v = _split_heads(self.affine_v(v), self.num_heads)
        if mask is not None:
            mask = paddle.unsqueeze(mask, 1)  # unsqueeze for the h dim

        context_vectors, attention_weights = scaled_dot_product_attention(
            q, k, v, mask, training=self.training)
        context_vectors = drop_head(context_vectors, drop_n_heads,
                                    self.training)
        context_vectors = _concat_heads(context_vectors)  # (B, T, h*C)

        # Deviation from the vanilla transformer: the raw query is concatenated
        # with the context before the output projection.
        concat_feature = paddle.concat([q_in, context_vectors], -1)
        out = self.affine_o(concat_feature)
        return out, attention_weights
|
|
|
|
|
|
class TransformerEncoderLayer(nn.Layer):
    """
    Transformer encoder layer (pre-LN): self attention followed by a
    position-wise feed forward network, each computed as
    ``x + Dropout(SubLayer(LayerNorm(x)))``.
    """

    def __init__(self, d_model, n_heads, d_ffn, dropout=0.):
        """
        Args:
            d_model (int): the feature size of the input, and the output.
            n_heads (int): the number of heads in the internal MultiHeadAttention layer.
            d_ffn (int): the hidden size of the internal PositionwiseFFN.
            dropout (float, optional): the probability of the residual dropout
                applied to each sublayer's output; also passed to
                PositionwiseFFN. Defaults to 0.
        """
        super(TransformerEncoderLayer, self).__init__()
        self.self_mha = MultiheadAttention(d_model, n_heads)
        self.layer_norm1 = nn.LayerNorm([d_model], epsilon=1e-6)

        self.ffn = PositionwiseFFN(d_model, d_ffn, dropout)
        self.layer_norm2 = nn.LayerNorm([d_model], epsilon=1e-6)

        self.dropout = dropout

    def _forward_mha(self, x, mask, drop_n_heads):
        # PreLN scheme: Norm -> SubLayer -> Dropout -> Residual
        x_in = x
        x = self.layer_norm1(x)
        context_vector, attn_weights = self.self_mha(x, x, x, mask,
                                                     drop_n_heads)
        context_vector = x_in + F.dropout(
            context_vector, self.dropout, training=self.training)
        return context_vector, attn_weights

    def _forward_ffn(self, x):
        # PreLN scheme: Norm -> SubLayer -> Dropout -> Residual
        x_in = x
        x = self.layer_norm2(x)
        x = self.ffn(x)
        out = x_in + F.dropout(x, self.dropout, training=self.training)
        return out

    def forward(self, x, mask, drop_n_heads=0):
        """
        Args:
            x (Tensor): shape(batch_size, time_steps, d_model), the encoder input.
            mask (Tensor): shape(batch_size, 1, time_steps), the padding mask.
            drop_n_heads (int, optional): number of self attention heads to
                drop. Defaults to 0.

        Returns:
            x (Tensor): shape(batch_size, time_steps, d_model), the encoded output.
            attn_weights (Tensor): shape(batch_size, n_heads, time_steps, time_steps), self attention.
        """
        x, attn_weights = self._forward_mha(x, mask, drop_n_heads)
        x = self._forward_ffn(x)
        return x, attn_weights
|
|
|
|
|
|
class TransformerDecoderLayer(nn.Layer):
    """
    Transformer decoder layer (pre-LN): self attention, cross attention over
    the encoder output, and a position-wise feed forward network, each
    computed as ``x + Dropout(SubLayer(LayerNorm(x)))``.
    """

    def __init__(self, d_model, n_heads, d_ffn, dropout=0., d_encoder=None):
        """
        Args:
            d_model (int): the feature size of the input, and the output.
            n_heads (int): the number of heads in the internal MultiHeadAttention layers.
            d_ffn (int): the hidden size of the internal PositionwiseFFN.
            dropout (float, optional): the probability of the residual dropout
                applied to each sublayer's output; also passed to
                PositionwiseFFN. Defaults to 0.
            d_encoder (int, optional): the feature size of the encoder output
                used as keys and values of the cross attention. If not
                provided, d_model is used. Defaults to None.
        """
        super(TransformerDecoderLayer, self).__init__()
        self.self_mha = MultiheadAttention(d_model, n_heads)
        self.layer_norm1 = nn.LayerNorm([d_model], epsilon=1e-6)

        self.cross_mha = MultiheadAttention(
            d_model, n_heads, k_input_dim=d_encoder, v_input_dim=d_encoder)
        self.layer_norm2 = nn.LayerNorm([d_model], epsilon=1e-6)

        self.ffn = PositionwiseFFN(d_model, d_ffn, dropout)
        self.layer_norm3 = nn.LayerNorm([d_model], epsilon=1e-6)

        self.dropout = dropout

    def _forward_self_mha(self, x, mask, drop_n_heads):
        # PreLN scheme: Norm -> SubLayer -> Dropout -> Residual
        x_in = x
        x = self.layer_norm1(x)
        context_vector, attn_weights = self.self_mha(x, x, x, mask,
                                                     drop_n_heads)
        context_vector = x_in + F.dropout(
            context_vector, self.dropout, training=self.training)
        return context_vector, attn_weights

    def _forward_cross_mha(self, q, k, v, mask, drop_n_heads):
        # PreLN scheme: Norm -> SubLayer -> Dropout -> Residual
        # Only the query is normalized here; k and v are used as given.
        q_in = q
        q = self.layer_norm2(q)
        context_vector, attn_weights = self.cross_mha(q, k, v, mask,
                                                      drop_n_heads)
        context_vector = q_in + F.dropout(
            context_vector, self.dropout, training=self.training)
        return context_vector, attn_weights

    def _forward_ffn(self, x):
        # PreLN scheme: Norm -> SubLayer -> Dropout -> Residual
        x_in = x
        x = self.layer_norm3(x)
        x = self.ffn(x)
        out = x_in + F.dropout(x, self.dropout, training=self.training)
        return out

    def forward(self, q, k, v, encoder_mask, decoder_mask, drop_n_heads=0):
        """
        Args:
            q (Tensor): shape(batch_size, time_steps_q, d_model), the decoder input.
            k (Tensor): shape(batch_size, time_steps_k, d_encoder), keys.
            v (Tensor): shape(batch_size, time_steps_k, d_encoder), values.
            encoder_mask (Tensor): shape(batch_size, 1, time_steps_k) encoder padding mask.
            decoder_mask (Tensor): shape(batch_size, time_steps_q, time_steps_q)
                or broadcastable shape, the decoder self attention mask
                (e.g. a padding mask combined with a no-future mask).
            drop_n_heads (int, optional): number of heads to drop in both
                attention sublayers. Defaults to 0.

        Returns:
            q (Tensor): shape(batch_size, time_steps_q, d_model), the decoded.
            self_attn_weights (Tensor): shape(batch_size, n_heads, time_steps_q, time_steps_q), decoder self attention.
            cross_attn_weights (Tensor): shape(batch_size, n_heads, time_steps_q, time_steps_k), decoder-encoder cross attention.
        """
        q, self_attn_weights = self._forward_self_mha(q, decoder_mask,
                                                      drop_n_heads)
        q, cross_attn_weights = self._forward_cross_mha(q, k, v, encoder_mask,
                                                        drop_n_heads)
        q = self._forward_ffn(q)
        return q, self_attn_weights, cross_attn_weights
|
|
|
|
|
|
class TransformerEncoder(nn.LayerList):
    """A stack of ``TransformerEncoderLayer``s applied in sequence."""

    def __init__(self, d_model, n_heads, d_ffn, n_layers, dropout=0.):
        """
        Args:
            d_model (int): feature size of the input and output of every layer.
            n_heads (int): number of attention heads per layer.
            d_ffn (int): hidden size of each layer's PositionwiseFFN.
            n_layers (int): number of stacked layers.
            dropout (float, optional): dropout probability of each layer.
                Defaults to 0.
        """
        super(TransformerEncoder, self).__init__()
        for _ in range(n_layers):
            encoder_layer = TransformerEncoderLayer(d_model, n_heads, d_ffn,
                                                    dropout)
            self.append(encoder_layer)

    def forward(self, x, mask, drop_n_heads=0):
        """
        Run the input through every layer, collecting attention maps.

        Args:
            x (Tensor): shape(batch_size, time_steps, feature_size), the input tensor.
            mask (Tensor): shape(batch_size, 1, time_steps), the mask.
            drop_n_heads (int, optional): how many heads to drop. Defaults to 0.

        Returns:
            x (Tensor): shape(batch_size, time_steps, feature_size), the context vector.
            attention_weights (list[Tensor]): one per layer, each of shape
                (batch_size, n_heads, time_steps, time_steps).
        """
        hidden = x
        weights_per_layer = []
        for encoder_layer in self:
            hidden, layer_weights = encoder_layer(hidden, mask, drop_n_heads)
            weights_per_layer.append(layer_weights)
        return hidden, weights_per_layer
|
|
|
|
|
|
class TransformerDecoder(nn.LayerList):
    """A stack of ``TransformerDecoderLayer``s applied in sequence; every
    layer cross-attends to the same encoder output."""

    def __init__(self,
                 d_model,
                 n_heads,
                 d_ffn,
                 n_layers,
                 dropout=0.,
                 d_encoder=None):
        """
        Args:
            d_model (int): feature size of the input and output of every layer.
            n_heads (int): number of attention heads per layer.
            d_ffn (int): hidden size of each layer's PositionwiseFFN.
            n_layers (int): number of stacked layers.
            dropout (float, optional): dropout probability of each layer.
                Defaults to 0.
            d_encoder (int, optional): feature size of the encoder output
                (keys/values of the cross attention). Defaults to None,
                meaning d_model.
        """
        super(TransformerDecoder, self).__init__()
        for _ in range(n_layers):
            self.append(
                TransformerDecoderLayer(
                    d_model, n_heads, d_ffn, dropout, d_encoder=d_encoder))

    def forward(self, q, k, v, encoder_mask, decoder_mask, drop_n_heads=0):
        """
        Args:
            q (Tensor): shape(batch_size, time_steps_q, d_model)
            k (Tensor): shape(batch_size, time_steps_k, d_encoder)
            v (Tensor): shape(batch_size, time_steps_k, d_encoder)
            encoder_mask (Tensor): shape(batch_size, 1, time_steps_k)
            decoder_mask (Tensor): shape(batch_size, time_steps_q, time_steps_q)
            drop_n_heads (int, optional): number of attention heads to drop
                in every layer. Defaults to 0.

        Returns:
            q (Tensor): shape(batch_size, time_steps_q, d_model), the output.
            self_attention_weights (List[Tensor]): one per layer, shape (batch_size, num_heads, decoder_steps, decoder_steps)
            cross_attention_weights (List[Tensor]): one per layer, shape (batch_size, num_heads, decoder_steps, encoder_steps)
        """
        self_attention_weights = []
        cross_attention_weights = []
        for layer in self:
            q, self_attention_weights_i, cross_attention_weights_i = layer(
                q, k, v, encoder_mask, decoder_mask, drop_n_heads)
            self_attention_weights.append(self_attention_weights_i)
            cross_attention_weights.append(cross_attention_weights_i)
        return q, self_attention_weights, cross_attention_weights
|
|
|
|
|
|
class MLPPreNet(nn.Layer):
    """Decoder's prenet: two (Linear -> ReLU -> Dropout) stages followed by a
    linear projection to ``d_output``."""

    def __init__(self, d_input, d_hidden, d_output, dropout):
        """
        Args:
            d_input (int): feature size of the input.
            d_hidden (int): feature size of the two hidden layers.
            d_output (int): feature size of the output.
            dropout (float): default dropout probability, used when
                ``forward`` is not given an explicit rate.
        """
        # (lin + relu + dropout) * n + last projection
        super(MLPPreNet, self).__init__()
        self.lin1 = nn.Linear(d_input, d_hidden)
        self.lin2 = nn.Linear(d_hidden, d_hidden)
        # Must project to d_output: TransformerTTS adds this output to a
        # d_decoder-sized positional encoding. Using d_hidden here only
        # worked when d_hidden happened to equal d_output.
        self.lin3 = nn.Linear(d_hidden, d_output)
        self.dropout = dropout

    def forward(self, x, dropout=None):
        """
        Args:
            x (Tensor): shape(..., d_input), the input features.
            dropout (float, optional): dropout probability for this call.
                When None, ``self.dropout`` is used. Defaults to None.

        Returns:
            Tensor: shape(..., d_output), the projected features.
        """
        # Honour the caller-supplied rate; previously it was silently ignored.
        p = self.dropout if dropout is None else dropout
        l1 = F.dropout(F.relu(self.lin1(x)), p, training=self.training)
        l2 = F.dropout(F.relu(self.lin2(l1)), p, training=self.training)
        l3 = self.lin3(l2)
        return l3
|
|
|
|
|
|
class CNNPostNet(nn.Layer):
    """Stack of Conv1dBatchNorm layers (tanh between them) whose output is
    added back to the input and batch-normalized.

    The residual ``x + convs(x)`` requires d_input == d_output.
    """

    def __init__(self, d_input, d_hidden, d_output, kernel_size, n_layers):
        """
        Args:
            d_input (int): number of input channels.
            d_hidden (int): number of channels of the inner layers.
            d_output (int): number of output channels.
            kernel_size (int or tuple/list): kernel size of every conv; an
                int is wrapped into a 1-tuple.
            n_layers (int): number of conv layers.
        """
        super(CNNPostNet, self).__init__()
        self.convs = nn.LayerList()
        if not isinstance(kernel_size, (tuple, list)):
            kernel_size = (kernel_size, )
        # Pad k - 1 on one side and 0 on the other.
        # NOTE(review): presumably (left, right), i.e. causal padding —
        # confirm against Conv1dBatchNorm. It could also be a non-causal conv.
        padding = (kernel_size[0] - 1, 0)
        final_idx = n_layers - 1
        for idx in range(n_layers):
            in_channels = d_hidden if idx > 0 else d_input
            out_channels = d_hidden if idx < final_idx else d_output
            conv = Conv1dBatchNorm(
                in_channels,
                out_channels,
                kernel_size,
                weight_attr=I.XavierUniform(),
                padding=padding)
            self.convs.append(conv)
        # A stack ending in a normalization layer that must produce a
        # non-zero-centered output can take a long time to learn its
        # scale and bias.
        self.last_bn = nn.BatchNorm1D(d_output)

    def forward(self, x):
        """
        Args:
            x (Tensor): channel-first mel features, e.g.
                shape(batch_size, d_input, time_steps) as passed by
                TransformerTTS.decode.

        Returns:
            Tensor: same layout, with d_output channels.
        """
        residual = x
        final_idx = len(self.convs) - 1
        for idx, conv in enumerate(self.convs):
            x = conv(x)
            if idx < final_idx:
                x = F.tanh(x)
        return self.last_bn(residual + x)
|
|
|
|
|
|
class TransformerTTS(nn.Layer):
    """Transformer TTS acoustic model: text ids -> mel spectrogram frames.

    Encoder: embedding + scaled sinusoid positional encoding + transformer
    encoder. Decoder: MLP prenet over mel frames + positional encoding +
    transformer decoder with cross attention, a linear projection that emits
    ``self.r`` mel frames per decoder step, a CNN postnet, and a 3-way
    per-frame stop classifier.
    """

    def __init__(self,
                 frontend: parakeet.frontend.Phonetics,
                 d_encoder: int,
                 d_decoder: int,
                 d_mel: int,
                 n_heads: int,
                 d_ffn: int,
                 encoder_layers: int,
                 decoder_layers: int,
                 d_prenet: int,
                 d_postnet: int,
                 postnet_layers: int,
                 postnet_kernel_size: int,
                 max_reduction_factor: int,
                 decoder_prenet_dropout: float,
                 dropout: float):
        """
        Args:
            frontend (parakeet.frontend.Phonetics): text frontend.
                ``vocab_size`` and ``vocab.padding_index`` size the embedding,
                and ``predict`` calls it on raw text.
            d_encoder (int): hidden size of the encoder.
            d_decoder (int): hidden size of the decoder.
            d_mel (int): number of mel bands per frame.
            n_heads (int): number of attention heads (encoder and decoder).
            d_ffn (int): hidden size of the position-wise FFNs.
            encoder_layers (int): number of encoder layers.
            decoder_layers (int): number of decoder layers.
            d_prenet (int): hidden size of the decoder prenet.
            d_postnet (int): hidden channels of the CNN postnet.
            postnet_layers (int): number of conv layers in the postnet.
            postnet_kernel_size (int): kernel size of the postnet convs.
            max_reduction_factor (int): maximum number of mel frames emitted
                per decoder step; sizes the final projection.
            decoder_prenet_dropout (float): dropout probability of the
                decoder prenet.
            dropout (float): dropout probability used everywhere else.
        """
        super(TransformerTTS, self).__init__()

        # text frontend (text normalization and g2p)
        self.frontend = frontend

        # encoder
        self.encoder_prenet = nn.Embedding(
            frontend.vocab_size,
            d_encoder,
            padding_idx=frontend.vocab.padding_index,
            weight_attr=I.Uniform(-0.05, 0.05))
        # position encoding matrix may be extended later (see encode)
        self.encoder_pe = pe.sinusoid_positional_encoding(0, 1000, d_encoder)
        # learnable scale applied to the positional encoding
        self.encoder_pe_scalar = self.create_parameter(
            [1], attr=I.Constant(1.))
        self.encoder = TransformerEncoder(d_encoder, n_heads, d_ffn,
                                          encoder_layers, dropout)

        # decoder
        # The prenet is built with decoder_prenet_dropout (the same rate
        # decode() passes to it), not with the global dropout rate.
        self.decoder_prenet = MLPPreNet(d_mel, d_prenet, d_decoder,
                                        decoder_prenet_dropout)
        # position encoding matrix may be extended later (see decode)
        self.decoder_pe = pe.sinusoid_positional_encoding(0, 1000, d_decoder)
        self.decoder_pe_scalar = self.create_parameter(
            [1], attr=I.Constant(1.))
        self.decoder = TransformerDecoder(
            d_decoder,
            n_heads,
            d_ffn,
            decoder_layers,
            dropout,
            d_encoder=d_encoder)
        # Sized for the maximum reduction factor; decode() keeps only the
        # first self.r * d_mel features.
        self.final_proj = nn.Linear(d_decoder, max_reduction_factor * d_mel)
        self.decoder_postnet = CNNPostNet(d_mel, d_postnet, d_mel,
                                          postnet_kernel_size, postnet_layers)
        # 3-way per-frame classifier; class `stop_prob_index` means "stop".
        self.stop_conditioner = nn.Linear(d_mel, 3)

        # specs
        self.padding_idx = frontend.vocab.padding_index
        self.d_encoder = d_encoder
        self.d_decoder = d_decoder
        self.d_mel = d_mel
        self.max_r = max_reduction_factor
        self.dropout = dropout
        self.decoder_prenet_dropout = decoder_prenet_dropout

        # start and end: though it is only used in predict
        # it can also be used in training
        dtype = paddle.get_default_dtype()
        self.start_vec = paddle.full([1, d_mel], 0.5, dtype=dtype)
        self.end_vec = paddle.full([1, d_mel], -0.5, dtype=dtype)
        self.stop_prob_index = 2

        # mutables
        self.r = max_reduction_factor  # set it every call
        self.drop_n_heads = 0

    def forward(self, text, mel):
        """Teacher-forced forward pass.

        Args:
            text (Tensor): shape(batch_size, time_steps_k), dtype int, text ids.
            mel (Tensor): shape(batch_size, time_steps_q, d_mel), decoder input frames.

        Returns:
            dict with keys "mel_output", "mel_intermediate",
            "encoder_attention_weights", "cross_attention_weights" and
            "stop_logits" (see ``encode`` and ``decode``).
        """
        encoded, encoder_attention_weights, encoder_mask = self.encode(text)
        mel_output, mel_intermediate, cross_attention_weights, stop_logits = self.decode(
            encoded, mel, encoder_mask)
        outputs = {
            "mel_output": mel_output,
            "mel_intermediate": mel_intermediate,
            "encoder_attention_weights": encoder_attention_weights,
            "cross_attention_weights": cross_attention_weights,
            "stop_logits": stop_logits,
        }
        return outputs

    def encode(self, text):
        """Encode a batch of text ids.

        Args:
            text (Tensor): shape(batch_size, time_steps), dtype int, text ids;
                padded positions hold ``self.padding_idx``.

        Returns:
            x (Tensor): shape(batch_size, time_steps, d_encoder), the encoded text.
            attention_weights (List[Tensor]): per-layer encoder self attention.
            encoder_padding_mask (Tensor): shape(batch_size, 1, time_steps), the padding mask.
        """
        T_enc = text.shape[-1]
        embed = self.encoder_prenet(text)
        if embed.shape[1] > self.encoder_pe.shape[0]:
            # Grow the cached table (at least doubling it) with the same
            # generator used in __init__, so that existing rows are unchanged.
            new_T = max(embed.shape[1], self.encoder_pe.shape[0] * 2)
            self.encoder_pe = pe.sinusoid_positional_encoding(
                0, new_T, self.d_encoder)
        pos_enc = self.encoder_pe[:T_enc, :]  # (T, C)
        x = embed.scale(math.sqrt(
            self.d_encoder)) + pos_enc * self.encoder_pe_scalar
        x = F.dropout(x, self.dropout, training=self.training)

        # TODO(chenfeiyu): unsqueeze a decoder_time_steps=1 for the mask
        encoder_padding_mask = paddle.unsqueeze(
            masking.id_mask(
                text, self.padding_idx, dtype=x.dtype), 1)
        x, attention_weights = self.encoder(x, encoder_padding_mask,
                                            self.drop_n_heads)
        return x, attention_weights, encoder_padding_mask

    def decode(self, encoder_output, input, encoder_padding_mask):
        """Decode mel frames, ``self.r`` output frames per input frame.

        Args:
            encoder_output (Tensor): shape(batch_size, time_steps_k, d_encoder), output of ``encode``.
            input (Tensor): shape(batch_size, time_steps_q, d_mel), decoder
                input frames, one per decoder step.
            encoder_padding_mask (Tensor): shape(batch_size, 1, time_steps_k).

        Returns:
            mel_output (Tensor): shape(batch_size, time_steps_q * r, d_mel), postnet output.
            mel_intermediate (Tensor): shape(batch_size, time_steps_q * r, d_mel), mel before the postnet.
            cross_attention_weights (List[Tensor]): per-layer decoder-encoder attention.
            stop_logits (Tensor): shape(batch_size, time_steps_q * r, 3), stop classifier logits.
        """
        batch_size, T_dec, mel_dim = input.shape

        x = self.decoder_prenet(input, self.decoder_prenet_dropout)
        # twice its length if needed
        if x.shape[1] * self.r > self.decoder_pe.shape[0]:
            new_T = max(x.shape[1] * self.r, self.decoder_pe.shape[0] * 2)
            self.decoder_pe = pe.sinusoid_positional_encoding(
                0, new_T, self.d_decoder)
        # take every r-th row so decoder step t gets position t * r
        pos_enc = self.decoder_pe[:T_dec * self.r:self.r, :]
        x = x.scale(math.sqrt(
            self.d_decoder)) + pos_enc * self.decoder_pe_scalar
        x = F.dropout(x, self.dropout, training=self.training)

        no_future_mask = masking.future_mask(T_dec, dtype=input.dtype)
        decoder_padding_mask = masking.feature_mask(
            input, axis=-1, dtype=input.dtype)
        decoder_mask = masking.combine_mask(
            decoder_padding_mask.unsqueeze(-1), no_future_mask)
        decoder_output, _, cross_attention_weights = self.decoder(
            x, encoder_output, encoder_output, encoder_padding_mask,
            decoder_mask, self.drop_n_heads)

        # use only parts of it
        output_proj = self.final_proj(decoder_output)[:, :, :self.r * mel_dim]
        mel_intermediate = paddle.reshape(output_proj,
                                          [batch_size, -1, mel_dim])
        stop_logits = self.stop_conditioner(mel_intermediate)

        # cnn postnet
        mel_channel_first = paddle.transpose(mel_intermediate, [0, 2, 1])
        mel_output = self.decoder_postnet(mel_channel_first)
        mel_output = paddle.transpose(mel_output, [0, 2, 1])

        return mel_output, mel_intermediate, cross_attention_weights, stop_logits

    @paddle.no_grad()
    def infer(self, input, max_length=1000, verbose=True):
        """Predict log scale magnitude mel spectrogram from text input.

        Decoding is autoregressive. It starts from ``self.start_vec``, and each
        step re-runs ``decode`` on all frames fed so far.

        Args:
            input (Tensor): shape(1, T), dtype int, input text ids (a batch
                of one sequence).
            max_length (int, optional): upper bound on output mel frames;
                decoding runs at most ``max_length // r + 1`` steps.
                Defaults to 1000.
            verbose (bool, optional): print a message when the stop condition
                is hit. The tqdm progress bar is shown regardless.
                Defaults to True.

        Returns:
            dict with "mel_output" (shape(1, T_out, d_mel)),
            "encoder_attention_weights" (per-layer list) and
            "cross_attention_weights" (per-layer list from the last step).
        """
        decoder_input = paddle.unsqueeze(self.start_vec, 0)  # (B=1, T, C)
        decoder_output = paddle.unsqueeze(self.start_vec, 0)  # (B=1, T, C)

        # encoder the text sequence
        encoder_output, encoder_attentions, encoder_padding_mask = self.encode(
            input)
        for _ in trange(int(max_length // self.r) + 1):
            mel_output, _, cross_attention_weights, stop_logits = self.decode(
                encoder_output, decoder_input, encoder_padding_mask)

            # extract last step and append it to decoder input
            decoder_input = paddle.concat(
                [decoder_input, mel_output[:, -1:, :]], 1)
            # extract last r steps and append it to decoder output
            decoder_output = paddle.concat(
                [decoder_output, mel_output[:, -self.r:, :]], 1)

            # stop condition: (if any ouput frame of the output multiframes hits the stop condition)
            if paddle.any(
                    paddle.argmax(
                        stop_logits[0, -self.r:, :], axis=-1) ==
                    self.stop_prob_index):
                if verbose:
                    print("Hits stop condition.")
                break
        # drop the start frame
        mel_output = decoder_output[:, 1:, :]

        outputs = {
            "mel_output": mel_output,
            "encoder_attention_weights": encoder_attentions,
            "cross_attention_weights": cross_attention_weights,
        }
        return outputs

    @paddle.no_grad()
    def predict(self, input, max_length=1000, verbose=True):
        """Synthesize from raw text: run the frontend, then ``infer``.

        Args:
            input: raw text accepted by ``self.frontend``.
            max_length (int, optional): see ``infer``. Defaults to 1000.
            verbose (bool, optional): see ``infer``. Defaults to True.

        Returns:
            dict of numpy arrays with the same keys as ``infer``; each value
            is ``v[0]`` of the corresponding output.
            NOTE(review): for the attention entries (per-layer lists) ``v[0]``
            selects the first layer, not the first batch item — confirm intended.
        """
        text_ids = paddle.to_tensor(self.frontend(input))
        input = paddle.unsqueeze(text_ids, 0)  # (1, T)
        outputs = self.infer(input, max_length=max_length, verbose=verbose)
        outputs = {k: v[0].numpy() for k, v in outputs.items()}
        return outputs

    def set_constants(self, reduction_factor, drop_n_heads):
        """Set the mutable reduction factor ``r`` and the number of
        attention heads to drop."""
        self.r = reduction_factor
        self.drop_n_heads = drop_n_heads

    @classmethod
    def from_pretrained(cls, frontend, config, checkpoint_path):
        """Build a model from ``config``, load ``checkpoint_path`` and set
        ``r`` / ``drop_n_heads`` from the training schedules at the loaded
        iteration.

        Uses ``cls`` so subclasses get an instance of their own type.
        """
        model = cls(
            frontend,
            d_encoder=config.model.d_encoder,
            d_decoder=config.model.d_decoder,
            d_mel=config.data.d_mel,
            n_heads=config.model.n_heads,
            d_ffn=config.model.d_ffn,
            encoder_layers=config.model.encoder_layers,
            decoder_layers=config.model.decoder_layers,
            d_prenet=config.model.d_prenet,
            d_postnet=config.model.d_postnet,
            postnet_layers=config.model.postnet_layers,
            postnet_kernel_size=config.model.postnet_kernel_size,
            max_reduction_factor=config.model.max_reduction_factor,
            decoder_prenet_dropout=config.model.decoder_prenet_dropout,
            dropout=config.model.dropout)

        iteration = checkpoint.load_parameters(
            model, checkpoint_path=checkpoint_path)
        drop_n_heads = scheduler.StepWise(config.training.drop_n_heads)
        reduction_factor = scheduler.StepWise(config.training.reduction_factor)
        model.set_constants(
            reduction_factor=reduction_factor(iteration),
            drop_n_heads=drop_n_heads(iteration))
        return model
|
|
|
|
|
|
class TransformerTTSLoss(nn.Layer):
    """TransformerTTS training loss: masked L1 on the postnet and pre-postnet
    mel outputs, plus a masked stop-token cross entropy. In the stop-token
    term, the last valid frame of each utterance is weighted by
    ``stop_loss_scale``."""

    def __init__(self, stop_loss_scale):
        """
        Args:
            stop_loss_scale (float): weight of the last valid frame in the
                stop-token loss; other valid frames have weight 1.
        """
        super(TransformerTTSLoss, self).__init__()
        self.stop_loss_scale = stop_loss_scale

    def forward(self, mel_output, mel_intermediate, mel_target, stop_logits,
                stop_probs):
        """
        Args:
            mel_output (Tensor): postnet mel prediction, same shape as mel_target.
            mel_intermediate (Tensor): pre-postnet mel prediction, same shape as mel_target.
            mel_target (Tensor): target mel frames; padded frames are
                detected by ``masking.feature_mask`` along the last axis.
            stop_logits (Tensor): per-frame stop classifier logits.
            stop_probs (Tensor): per-frame stop labels, unsqueezed on the
                last axis before the cross entropy (presumably integer class
                ids — TODO confirm).

        Returns:
            dict with "loss" (the sum of the three terms), "mel_loss1"
            (output mel loss), "mel_loss2" (intermediate mel loss) and
            "stop_loss" (stop prob loss).
        """
        # presumably (batch_size, time_steps): one entry per frame
        frame_mask = masking.feature_mask(
            mel_target, axis=-1, dtype=mel_target.dtype)
        mel_mask = paddle.unsqueeze(frame_mask, -1)  # broadcast over mel bins
        postnet_loss = L.masked_l1_loss(mel_output, mel_target, mel_mask)
        intermediate_loss = L.masked_l1_loss(mel_intermediate, mel_target,
                                             mel_mask)

        # Up-weight the last valid frame: valid count - 1 is its index.
        n_frames = frame_mask.shape[-1]
        last_index = frame_mask.sum(-1).astype("int64") - 1
        last_frame = F.one_hot(last_index, num_classes=n_frames)
        extra_weight = last_frame.scale(self.stop_loss_scale - 1).astype(
            frame_mask.dtype)
        stop_weights = frame_mask + extra_weight
        stop_loss = L.masked_softmax_with_cross_entropy(
            stop_logits, stop_probs.unsqueeze(-1), stop_weights.unsqueeze(-1))

        total = postnet_loss + intermediate_loss + stop_loss
        return {
            "loss": total,
            "mel_loss1": postnet_loss,
            "mel_loss2": intermediate_loss,
            "stop_loss": stop_loss,
        }
|