add network of tacotron2 model

This commit is contained in:
lfchener 2020-12-09 09:08:17 +00:00
parent f255eee029
commit b12eda8423
3 changed files with 372 additions and 173 deletions

View File

@ -1,3 +1,17 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from tqdm import trange
import paddle
@ -15,6 +29,7 @@ from parakeet.modules import losses as L
__all__ = ["TransformerTTS", "TransformerTTSLoss"]
# Transformer TTS's own implementation of transformer
class MultiheadAttention(nn.Layer):
"""
@ -25,7 +40,14 @@ class MultiheadAttention(nn.Layer):
Another deviation is that it concats the input query and context vector before
applying the output projection.
"""
def __init__(self, model_dim, num_heads, k_dim=None, v_dim=None, k_input_dim=None, v_input_dim=None):
def __init__(self,
model_dim,
num_heads,
k_dim=None,
v_dim=None,
k_input_dim=None,
v_input_dim=None):
"""
Args:
model_dim (int): the feature size of query.
@ -41,7 +63,7 @@ class MultiheadAttention(nn.Layer):
ValueError: if model_dim is not divisible by num_heads
"""
super(MultiheadAttention, self).__init__()
if model_dim % num_heads !=0:
if model_dim % num_heads != 0:
raise ValueError("model_dim must be divisible by num_heads")
depth = model_dim // num_heads
k_dim = k_dim or depth
@ -52,10 +74,10 @@ class MultiheadAttention(nn.Layer):
self.affine_k = nn.Linear(k_input_dim, num_heads * k_dim)
self.affine_v = nn.Linear(v_input_dim, num_heads * v_dim)
self.affine_o = nn.Linear(model_dim + num_heads * v_dim, model_dim)
self.num_heads = num_heads
self.model_dim = model_dim
def forward(self, q, k, v, mask, drop_n_heads=0):
"""
Compute context vector and attention weights.
@ -72,17 +94,18 @@ class MultiheadAttention(nn.Layer):
attention_weights (Tensor): shape(batch_size, times_steps_q, time_steps_k), the attention weights.
"""
q_in = q
q = _split_heads(self.affine_q(q), self.num_heads) # (B, h, T, C)
q = _split_heads(self.affine_q(q), self.num_heads) # (B, h, T, C)
k = _split_heads(self.affine_k(k), self.num_heads)
v = _split_heads(self.affine_v(v), self.num_heads)
if mask is not None:
mask = paddle.unsqueeze(mask, 1) # unsqueeze for the h dim
mask = paddle.unsqueeze(mask, 1) # unsqueeze for the h dim
context_vectors, attention_weights = scaled_dot_product_attention(
q, k, v, mask, training=self.training)
context_vectors = drop_head(context_vectors, drop_n_heads, self.training)
context_vectors = _concat_heads(context_vectors) # (B, T, h*C)
context_vectors = drop_head(context_vectors, drop_n_heads,
self.training)
context_vectors = _concat_heads(context_vectors) # (B, T, h*C)
concat_feature = paddle.concat([q_in, context_vectors], -1)
out = self.affine_o(concat_feature)
return out, attention_weights
@ -92,6 +115,7 @@ class TransformerEncoderLayer(nn.Layer):
"""
Transformer encoder layer.
"""
def __init__(self, d_model, n_heads, d_ffn, dropout=0.):
"""
Args:
@ -114,8 +138,10 @@ class TransformerEncoderLayer(nn.Layer):
# PreLN scheme: Norm -> SubLayer -> Dropout -> Residual
x_in = x
x = self.layer_norm1(x)
context_vector, attn_weights = self.self_mha(x, x, x, mask, drop_n_heads)
context_vector = x_in + F.dropout(context_vector, self.dropout, training=self.training)
context_vector, attn_weights = self.self_mha(x, x, x, mask,
drop_n_heads)
context_vector = x_in + F.dropout(
context_vector, self.dropout, training=self.training)
return context_vector, attn_weights
def _forward_ffn(self, x):
@ -123,9 +149,9 @@ class TransformerEncoderLayer(nn.Layer):
x_in = x
x = self.layer_norm2(x)
x = self.ffn(x)
out= x_in + F.dropout(x, self.dropout, training=self.training)
out = x_in + F.dropout(x, self.dropout, training=self.training)
return out
def forward(self, x, mask, drop_n_heads=0):
"""
Args:
@ -145,6 +171,7 @@ class TransformerDecoderLayer(nn.Layer):
"""
Transformer decoder layer.
"""
def __init__(self, d_model, n_heads, d_ffn, dropout=0., d_encoder=None):
"""
Args:
@ -157,37 +184,42 @@ class TransformerDecoderLayer(nn.Layer):
super(TransformerDecoderLayer, self).__init__()
self.self_mha = MultiheadAttention(d_model, n_heads)
self.layer_norm1 = nn.LayerNorm([d_model], epsilon=1e-6)
self.cross_mha = MultiheadAttention(d_model, n_heads, k_input_dim=d_encoder, v_input_dim=d_encoder)
self.cross_mha = MultiheadAttention(
d_model, n_heads, k_input_dim=d_encoder, v_input_dim=d_encoder)
self.layer_norm2 = nn.LayerNorm([d_model], epsilon=1e-6)
self.ffn = PositionwiseFFN(d_model, d_ffn, dropout)
self.layer_norm3 = nn.LayerNorm([d_model], epsilon=1e-6)
self.dropout = dropout
def _forward_self_mha(self, x, mask, drop_n_heads):
# PreLN scheme: Norm -> SubLayer -> Dropout -> Residual
x_in = x
x = self.layer_norm1(x)
context_vector, attn_weights = self.self_mha(x, x, x, mask, drop_n_heads)
context_vector = x_in + F.dropout(context_vector, self.dropout, training=self.training)
context_vector, attn_weights = self.self_mha(x, x, x, mask,
drop_n_heads)
context_vector = x_in + F.dropout(
context_vector, self.dropout, training=self.training)
return context_vector, attn_weights
def _forward_cross_mha(self, q, k, v, mask, drop_n_heads):
# PreLN scheme: Norm -> SubLayer -> Dropout -> Residual
q_in = q
q = self.layer_norm2(q)
context_vector, attn_weights = self.cross_mha(q, k, v, mask, drop_n_heads)
context_vector = q_in + F.dropout(context_vector, self.dropout, training=self.training)
context_vector, attn_weights = self.cross_mha(q, k, v, mask,
drop_n_heads)
context_vector = q_in + F.dropout(
context_vector, self.dropout, training=self.training)
return context_vector, attn_weights
def _forward_ffn(self, x):
# PreLN scheme: Norm -> SubLayer -> Dropout -> Residual
x_in = x
x = self.layer_norm3(x)
x = self.ffn(x)
out= x_in + F.dropout(x, self.dropout, training=self.training)
out = x_in + F.dropout(x, self.dropout, training=self.training)
return out
def forward(self, q, k, v, encoder_mask, decoder_mask, drop_n_heads=0):
@ -204,8 +236,10 @@ class TransformerDecoderLayer(nn.Layer):
self_attn_weights (Tensor), shape(batch_size, n_heads, time_steps_q, time_steps_q), decoder self attention.
cross_attn_weights (Tensor), shape(batch_size, n_heads, time_steps_q, time_steps_k), decoder-encoder cross attention.
"""
q, self_attn_weights = self._forward_self_mha(q, decoder_mask, drop_n_heads)
q, cross_attn_weights = self._forward_cross_mha(q, k, v, encoder_mask, drop_n_heads)
q, self_attn_weights = self._forward_self_mha(q, decoder_mask,
drop_n_heads)
q, cross_attn_weights = self._forward_cross_mha(q, k, v, encoder_mask,
drop_n_heads)
q = self._forward_ffn(q)
return q, self_attn_weights, cross_attn_weights
@ -214,7 +248,8 @@ class TransformerEncoder(nn.LayerList):
def __init__(self, d_model, n_heads, d_ffn, n_layers, dropout=0.):
super(TransformerEncoder, self).__init__()
for _ in range(n_layers):
self.append(TransformerEncoderLayer(d_model, n_heads, d_ffn, dropout))
self.append(
TransformerEncoderLayer(d_model, n_heads, d_ffn, dropout))
def forward(self, x, mask, drop_n_heads=0):
"""
@ -236,10 +271,18 @@ class TransformerEncoder(nn.LayerList):
class TransformerDecoder(nn.LayerList):
def __init__(self, d_model, n_heads, d_ffn, n_layers, dropout=0., d_encoder=None):
def __init__(self,
d_model,
n_heads,
d_ffn,
n_layers,
dropout=0.,
d_encoder=None):
super(TransformerDecoder, self).__init__()
for _ in range(n_layers):
self.append(TransformerDecoderLayer(d_model, n_heads, d_ffn, dropout, d_encoder=d_encoder))
self.append(
TransformerDecoderLayer(
d_model, n_heads, d_ffn, dropout, d_encoder=d_encoder))
def forward(self, q, k, v, encoder_mask, decoder_mask, drop_n_heads=0):
"""[summary]
@ -260,7 +303,8 @@ class TransformerDecoder(nn.LayerList):
self_attention_weights = []
cross_attention_weights = []
for layer in self:
q, self_attention_weights_i, cross_attention_weights_i = layer(q, k, v, encoder_mask, decoder_mask, drop_n_heads)
q, self_attention_weights_i, cross_attention_weights_i = layer(
q, k, v, encoder_mask, decoder_mask, drop_n_heads)
self_attention_weights.append(self_attention_weights_i)
cross_attention_weights.append(cross_attention_weights_i)
return q, self_attention_weights, cross_attention_weights
@ -268,6 +312,7 @@ class TransformerDecoder(nn.LayerList):
class MLPPreNet(nn.Layer):
"""Decoder's prenet."""
def __init__(self, d_input, d_hidden, d_output, dropout):
# (lin + relu + dropout) * n + last projection
super(MLPPreNet, self).__init__()
@ -275,16 +320,24 @@ class MLPPreNet(nn.Layer):
self.lin2 = nn.Linear(d_hidden, d_hidden)
self.lin3 = nn.Linear(d_hidden, d_hidden)
self.dropout = dropout
def forward(self, x, dropout):
l1 = F.dropout(F.relu(self.lin1(x)), self.dropout, training=self.training)
l2 = F.dropout(F.relu(self.lin2(l1)), self.dropout, training=self.training)
l1 = F.dropout(
F.relu(self.lin1(x)), self.dropout, training=self.training)
l2 = F.dropout(
F.relu(self.lin2(l1)), self.dropout, training=self.training)
l3 = self.lin3(l2)
return l3
# NOTE: not used in
class CNNPreNet(nn.Layer):
def __init__(self, d_input, d_hidden, d_output, kernel_size, n_layers,
def __init__(self,
d_input,
d_hidden,
d_output,
kernel_size,
n_layers,
dropout=0.):
# (conv + bn + relu + dropout) * n + last projection
super(CNNPreNet, self).__init__()
@ -292,16 +345,21 @@ class CNNPreNet(nn.Layer):
c_in = d_input
for _ in range(n_layers):
self.convs.append(
Conv1dBatchNorm(c_in, d_hidden, kernel_size,
weight_attr=I.XavierUniform(),
padding="same", data_format="NLC"))
Conv1dBatchNorm(
c_in,
d_hidden,
kernel_size,
weight_attr=I.XavierUniform(),
padding="same",
data_format="NLC"))
c_in = d_hidden
self.affine_out = nn.Linear(d_hidden, d_output)
self.dropout = dropout
def forward(self, x):
for layer in self.convs:
x = F.dropout(F.relu(layer(x)), self.dropout, training=self.training)
x = F.dropout(
F.relu(layer(x)), self.dropout, training=self.training)
x = self.affine_out(x)
return x
@ -310,21 +368,25 @@ class CNNPostNet(nn.Layer):
def __init__(self, d_input, d_hidden, d_output, kernel_size, n_layers):
super(CNNPostNet, self).__init__()
self.convs = nn.LayerList()
kernel_size = kernel_size if isinstance(kernel_size, (tuple, list)) else (kernel_size, )
kernel_size = kernel_size if isinstance(kernel_size, (
tuple, list)) else (kernel_size, )
padding = (kernel_size[0] - 1, 0)
for i in range(n_layers):
c_in = d_input if i == 0 else d_hidden
c_out = d_output if i == n_layers - 1 else d_hidden
self.convs.append(
Conv1dBatchNorm(c_in, c_out, kernel_size,
weight_attr=I.XavierUniform(),
padding=padding))
Conv1dBatchNorm(
c_in,
c_out,
kernel_size,
weight_attr=I.XavierUniform(),
padding=padding))
self.last_bn = nn.BatchNorm1D(d_output)
# for a layer that ends with a normalization layer that is targeted to
# output a non zero-central output, it may take a long time to
# train the scale and bias
# NOTE: it can also be a non-causal conv
def forward(self, x):
x_in = x
for i, layer in enumerate(self.convs):
@ -336,19 +398,19 @@ class CNNPostNet(nn.Layer):
class TransformerTTS(nn.Layer):
def __init__(self,
frontend: parakeet.frontend.Phonetics,
d_encoder: int,
d_decoder: int,
d_mel: int,
def __init__(self,
frontend: parakeet.frontend.Phonetics,
d_encoder: int,
d_decoder: int,
d_mel: int,
n_heads: int,
d_ffn: int,
encoder_layers: int,
decoder_layers: int,
d_prenet: int,
d_postnet: int,
postnet_layers: int,
postnet_kernel_size: int,
encoder_layers: int,
decoder_layers: int,
d_prenet: int,
d_postnet: int,
postnet_layers: int,
postnet_kernel_size: int,
max_reduction_factor: int,
decoder_prenet_dropout: float,
dropout: float):
@ -359,29 +421,34 @@ class TransformerTTS(nn.Layer):
# encoder
self.encoder_prenet = nn.Embedding(
frontend.vocab_size, d_encoder,
padding_idx=frontend.vocab.padding_index,
frontend.vocab_size,
d_encoder,
padding_idx=frontend.vocab.padding_index,
weight_attr=I.Uniform(-0.05, 0.05))
# position encoding matrix may be extended later
self.encoder_pe = pe.positional_encoding(0, 1000, d_encoder)
self.encoder_pe = pe.positional_encoding(0, 1000, d_encoder)
self.encoder_pe_scalar = self.create_parameter(
[1], attr=I.Constant(1.))
self.encoder = TransformerEncoder(
d_encoder, n_heads, d_ffn, encoder_layers, dropout)
self.encoder = TransformerEncoder(d_encoder, n_heads, d_ffn,
encoder_layers, dropout)
# decoder
self.decoder_prenet = MLPPreNet(d_mel, d_prenet, d_decoder, dropout)
self.decoder_pe = pe.positional_encoding(0, 1000, d_decoder)
self.decoder_pe_scalar = self.create_parameter(
[1], attr=I.Constant(1.))
self.decoder = TransformerDecoder(
d_decoder, n_heads, d_ffn, decoder_layers, dropout,
d_decoder,
n_heads,
d_ffn,
decoder_layers,
dropout,
d_encoder=d_encoder)
self.final_proj = nn.Linear(d_decoder, max_reduction_factor * d_mel)
self.decoder_postnet = CNNPostNet(
d_mel, d_postnet, d_mel, postnet_kernel_size, postnet_layers)
self.decoder_postnet = CNNPostNet(d_mel, d_postnet, d_mel,
postnet_kernel_size, postnet_layers)
self.stop_conditioner = nn.Linear(d_mel, 3)
# specs
self.padding_idx = frontend.vocab.padding_index
self.d_encoder = d_encoder
@ -390,21 +457,22 @@ class TransformerTTS(nn.Layer):
self.max_r = max_reduction_factor
self.dropout = dropout
self.decoder_prenet_dropout = decoder_prenet_dropout
# start and end: though it is only used in predict
# it can also be used in training
dtype = paddle.get_default_dtype()
self.start_vec = paddle.full([1, d_mel], 0.5, dtype=dtype)
self.end_vec = paddle.full([1, d_mel], -0.5, dtype=dtype)
self.stop_prob_index = 2
# mutables
self.r = max_reduction_factor # set it every call
self.r = max_reduction_factor # set it every call
self.drop_n_heads = 0
def forward(self, text, mel):
encoded, encoder_attention_weights, encoder_mask = self.encode(text)
mel_output, mel_intermediate, cross_attention_weights, stop_logits = self.decode(encoded, mel, encoder_mask)
mel_output, mel_intermediate, cross_attention_weights, stop_logits = self.decode(
encoded, mel, encoder_mask)
outputs = {
"mel_output": mel_output,
"mel_intermediate": mel_intermediate,
@ -420,51 +488,54 @@ class TransformerTTS(nn.Layer):
if embed.shape[1] > self.encoder_pe.shape[0]:
new_T = max(embed.shape[1], self.encoder_pe.shape[0] * 2)
self.encoder_pe = pe.positional_encoding(0, new_T, self.d_encoder)
pos_enc = self.encoder_pe[:T_enc, :] # (T, C)
x = embed.scale(math.sqrt(self.d_encoder)) + pos_enc * self.encoder_pe_scalar
pos_enc = self.encoder_pe[:T_enc, :] # (T, C)
x = embed.scale(math.sqrt(
self.d_encoder)) + pos_enc * self.encoder_pe_scalar
x = F.dropout(x, self.dropout, training=self.training)
# TODO(chenfeiyu): unsqueeze a decoder_time_steps=1 for the mask
encoder_padding_mask = paddle.unsqueeze(
masking.id_mask(text, self.padding_idx, dtype=x.dtype), 1)
x, attention_weights = self.encoder(x, encoder_padding_mask, self.drop_n_heads)
masking.id_mask(
text, self.padding_idx, dtype=x.dtype), 1)
x, attention_weights = self.encoder(x, encoder_padding_mask,
self.drop_n_heads)
return x, attention_weights, encoder_padding_mask
def decode(self, encoder_output, input, encoder_padding_mask):
batch_size, T_dec, mel_dim = input.shape
x = self.decoder_prenet(input, self.decoder_prenet_dropout)
# twice its length if needed
if x.shape[1] * self.r > self.decoder_pe.shape[0]:
new_T = max(x.shape[1] * self.r, self.decoder_pe.shape[0] * 2)
self.decoder_pe = pe.positional_encoding(0, new_T, self.d_decoder)
pos_enc = self.decoder_pe[:T_dec*self.r:self.r, :]
x = x.scale(math.sqrt(self.d_decoder)) + pos_enc * self.decoder_pe_scalar
pos_enc = self.decoder_pe[:T_dec * self.r:self.r, :]
x = x.scale(math.sqrt(
self.d_decoder)) + pos_enc * self.decoder_pe_scalar
x = F.dropout(x, self.dropout, training=self.training)
no_future_mask = masking.future_mask(T_dec, dtype=input.dtype)
decoder_padding_mask = masking.feature_mask(input, axis=-1, dtype=input.dtype)
decoder_mask = masking.combine_mask(decoder_padding_mask.unsqueeze(-1), no_future_mask)
decoder_padding_mask = masking.feature_mask(
input, axis=-1, dtype=input.dtype)
decoder_mask = masking.combine_mask(
decoder_padding_mask.unsqueeze(-1), no_future_mask)
decoder_output, _, cross_attention_weights = self.decoder(
x,
encoder_output,
encoder_output,
encoder_padding_mask,
decoder_mask,
self.drop_n_heads)
x, encoder_output, encoder_output, encoder_padding_mask,
decoder_mask, self.drop_n_heads)
# use only parts of it
output_proj = self.final_proj(decoder_output)[:, :, : self.r * mel_dim]
mel_intermediate = paddle.reshape(output_proj, [batch_size, -1, mel_dim])
output_proj = self.final_proj(decoder_output)[:, :, :self.r * mel_dim]
mel_intermediate = paddle.reshape(output_proj,
[batch_size, -1, mel_dim])
stop_logits = self.stop_conditioner(mel_intermediate)
# cnn postnet
mel_channel_first = paddle.transpose(mel_intermediate, [0, 2, 1])
mel_output = self.decoder_postnet(mel_channel_first)
mel_output = paddle.transpose(mel_output, [0, 2, 1])
return mel_output, mel_intermediate, cross_attention_weights, stop_logits
def predict(self, input, raw_input=True, max_length=1000, verbose=True):
"""Predict log scale magnitude mel spectrogram from text input.
@ -475,26 +546,32 @@ class TransformerTTS(nn.Layer):
"""
if raw_input:
text_ids = paddle.to_tensor(self.frontend(input))
text_input = paddle.unsqueeze(text_ids, 0) # (1, T)
text_input = paddle.unsqueeze(text_ids, 0) # (1, T)
else:
text_input = input
decoder_input = paddle.unsqueeze(self.start_vec, 0) # (B=1, T, C)
decoder_output = paddle.unsqueeze(self.start_vec, 0) # (B=1, T, C)
decoder_input = paddle.unsqueeze(self.start_vec, 0) # (B=1, T, C)
decoder_output = paddle.unsqueeze(self.start_vec, 0) # (B=1, T, C)
# encoder the text sequence
encoder_output, encoder_attentions, encoder_padding_mask = self.encode(text_input)
for _ in trange(int(max_length // self.r) + 1):
encoder_output, encoder_attentions, encoder_padding_mask = self.encode(
text_input)
for _ in range(int(max_length // self.r) + 1):
mel_output, _, cross_attention_weights, stop_logits = self.decode(
encoder_output, decoder_input, encoder_padding_mask)
# extract last step and append it to decoder input
decoder_input = paddle.concat([decoder_input, mel_output[:, -1:, :]], 1)
decoder_input = paddle.concat(
[decoder_input, mel_output[:, -1:, :]], 1)
# extract last r steps and append it to decoder output
decoder_output = paddle.concat([decoder_output, mel_output[:, -self.r:, :]], 1)
decoder_output = paddle.concat(
[decoder_output, mel_output[:, -self.r:, :]], 1)
# stop condition: (if any ouput frame of the output multiframes hits the stop condition)
if paddle.any(paddle.argmax(stop_logits[0, -self.r:, :], axis=-1) == self.stop_prob_index):
if paddle.any(
paddle.argmax(
stop_logits[0, -self.r:, :], axis=-1) ==
self.stop_prob_index):
if verbose:
print("Hits stop condition.")
break
@ -516,24 +593,28 @@ class TransformerTTSLoss(nn.Layer):
def __init__(self, stop_loss_scale):
super(TransformerTTSLoss, self).__init__()
self.stop_loss_scale = stop_loss_scale
def forward(self, mel_output, mel_intermediate, mel_target, stop_logits, stop_probs):
mask = masking.feature_mask(mel_target, axis=-1, dtype=mel_target.dtype)
def forward(self, mel_output, mel_intermediate, mel_target, stop_logits,
stop_probs):
mask = masking.feature_mask(
mel_target, axis=-1, dtype=mel_target.dtype)
mask1 = paddle.unsqueeze(mask, -1)
mel_loss1 = L.masked_l1_loss(mel_output, mel_target, mask1)
mel_loss2 = L.masked_l1_loss(mel_intermediate, mel_target, mask1)
mel_len = mask.shape[-1]
last_position = F.one_hot(mask.sum(-1).astype("int64") - 1, num_classes=mel_len)
mask2 = mask + last_position.scale(self.stop_loss_scale - 1).astype(mask.dtype)
last_position = F.one_hot(
mask.sum(-1).astype("int64") - 1, num_classes=mel_len)
mask2 = mask + last_position.scale(self.stop_loss_scale - 1).astype(
mask.dtype)
stop_loss = L.masked_softmax_with_cross_entropy(
stop_logits, stop_probs.unsqueeze(-1), mask2.unsqueeze(-1))
loss = mel_loss1 + mel_loss2 + stop_loss
loss = mel_loss1 + mel_loss2 + stop_loss
losses = dict(
loss=loss, # total loss
mel_loss1=mel_loss1, # ouput mel loss
mel_loss2=mel_loss2, # intermediate mel loss
loss=loss, # total loss
mel_loss1=mel_loss1, # ouput mel loss
mel_loss2=mel_loss2, # intermediate mel loss
stop_loss=stop_loss # stop prob loss
)
return losses
@ -542,26 +623,29 @@ class TransformerTTSLoss(nn.Layer):
class AdaptiveTransformerTTSLoss(nn.Layer):
def __init__(self):
super(AdaptiveTransformerTTSLoss, self).__init__()
def forward(self, mel_output, mel_intermediate, mel_target, stop_logits, stop_probs):
mask = masking.feature_mask(mel_target, axis=-1, dtype=mel_target.dtype)
def forward(self, mel_output, mel_intermediate, mel_target, stop_logits,
stop_probs):
mask = masking.feature_mask(
mel_target, axis=-1, dtype=mel_target.dtype)
mask1 = paddle.unsqueeze(mask, -1)
mel_loss1 = L.masked_l1_loss(mel_output, mel_target, mask1)
mel_loss2 = L.masked_l1_loss(mel_intermediate, mel_target, mask1)
batch_size, mel_len = mask.shape
valid_lengths = mask.sum(-1).astype("int64")
last_position = F.one_hot(valid_lengths - 1, num_classes=mel_len)
stop_loss_scale = valid_lengths.sum() / batch_size - 1
mask2 = mask + last_position.scale(stop_loss_scale - 1).astype(mask.dtype)
mask2 = mask + last_position.scale(stop_loss_scale - 1).astype(
mask.dtype)
stop_loss = L.masked_softmax_with_cross_entropy(
stop_logits, stop_probs.unsqueeze(-1), mask2.unsqueeze(-1))
loss = mel_loss1 + mel_loss2 + stop_loss
loss = mel_loss1 + mel_loss2 + stop_loss
losses = dict(
loss=loss, # total loss
mel_loss1=mel_loss1, # ouput mel loss
mel_loss2=mel_loss2, # intermediate mel loss
loss=loss, # total loss
mel_loss1=mel_loss1, # ouput mel loss
mel_loss2=mel_loss2, # intermediate mel loss
stop_loss=stop_loss # stop prob loss
)
return losses
return losses

View File

@ -1,10 +1,30 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import numpy as np
import paddle
from paddle import nn
from paddle.nn import functional as F
def scaled_dot_product_attention(q, k, v, mask=None, dropout=0.0, training=True):
def scaled_dot_product_attention(q,
k,
v,
mask=None,
dropout=0.0,
training=True):
"""
scaled dot product attention with mask. Assume q, k, v all have the same
leader dimensions(denoted as * in descriptions below). Dropout is applied to
@ -22,18 +42,19 @@ def scaled_dot_product_attention(q, k, v, mask=None, dropout=0.0, training=True)
out (Tensor): shape(*, T_q, d_v), the context vector.
attn_weights (Tensor): shape(*, T_q, T_k), the attention weights.
"""
d = q.shape[-1] # we only support imperative execution
d = q.shape[-1] # we only support imperative execution
qk = paddle.matmul(q, k, transpose_y=True)
scaled_logit = paddle.scale(qk, 1.0 / math.sqrt(d))
if mask is not None:
scaled_logit += paddle.scale((1.0 - mask), -1e9) # hard coded here
scaled_logit += paddle.scale((1.0 - mask), -1e9) # hard coded here
attn_weights = F.softmax(scaled_logit, axis=-1)
attn_weights = F.dropout(attn_weights, dropout, training=training)
out = paddle.matmul(attn_weights, v)
return out, attn_weights
def drop_head(x, drop_n_heads, training):
"""
Drop n heads from multiple context vectors.
@ -48,12 +69,12 @@ def drop_head(x, drop_n_heads, training):
"""
if not training or (drop_n_heads == 0):
return x
batch_size, num_heads, _, _ = x.shape
# drop all heads
if num_heads == drop_n_heads:
return paddle.zeros_like(x)
mask = np.ones([batch_size, num_heads])
mask[:, :drop_n_heads] = 0
for subarray in mask:
@ -63,18 +84,21 @@ def drop_head(x, drop_n_heads, training):
out = x * paddle.to_tensor(mask)
return out
def _split_heads(x, num_heads):
batch_size, time_steps, _ = x.shape
x = paddle.reshape(x, [batch_size, time_steps, num_heads, -1])
x = paddle.transpose(x, [0, 2, 1, 3])
return x
def _concat_heads(x):
batch_size, _, time_steps, _ = x.shape
x = paddle.transpose(x, [0, 2, 1, 3])
x = paddle.reshape(x, [batch_size, time_steps, -1])
return x
# Standard implementations of Monohead Attention & Multihead Attention
class MonoheadAttention(nn.Layer):
def __init__(self, model_dim, dropout=0.0, k_dim=None, v_dim=None):
@ -99,10 +123,10 @@ class MonoheadAttention(nn.Layer):
self.affine_k = nn.Linear(model_dim, k_dim)
self.affine_v = nn.Linear(model_dim, v_dim)
self.affine_o = nn.Linear(v_dim, model_dim)
self.model_dim = model_dim
self.dropout = dropout
def forward(self, q, k, v, mask):
"""
Compute context vector and attention weights.
@ -119,22 +143,28 @@ class MonoheadAttention(nn.Layer):
out (Tensor), shape(batch_size, time_steps_q, model_dim), the context vector.
attention_weights (Tensor): shape(batch_size, times_steps_q, time_steps_k), the attention weights.
"""
q = self.affine_q(q) # (B, T, C)
q = self.affine_q(q) # (B, T, C)
k = self.affine_k(k)
v = self.affine_v(v)
context_vectors, attention_weights = scaled_dot_product_attention(
q, k, v, mask, self.dropout, self.training)
out = self.affine_o(context_vectors)
return out, attention_weights
class MultiheadAttention(nn.Layer):
"""
Multihead scaled dot product attention.
"""
def __init__(self, model_dim, num_heads, dropout=0.0, k_dim=None, v_dim=None):
def __init__(self,
model_dim,
num_heads,
dropout=0.0,
k_dim=None,
v_dim=None):
"""
Multihead Attention module.
@ -154,7 +184,7 @@ class MultiheadAttention(nn.Layer):
ValueError: if model_dim is not divisible by num_heads
"""
super(MultiheadAttention, self).__init__()
if model_dim % num_heads !=0:
if model_dim % num_heads != 0:
raise ValueError("model_dim must be divisible by num_heads")
depth = model_dim // num_heads
k_dim = k_dim or depth
@ -163,11 +193,11 @@ class MultiheadAttention(nn.Layer):
self.affine_k = nn.Linear(model_dim, num_heads * k_dim)
self.affine_v = nn.Linear(model_dim, num_heads * v_dim)
self.affine_o = nn.Linear(num_heads * v_dim, model_dim)
self.num_heads = num_heads
self.model_dim = model_dim
self.dropout = dropout
def forward(self, q, k, v, mask):
"""
Compute context vector and attention weights.
@ -184,14 +214,67 @@ class MultiheadAttention(nn.Layer):
out (Tensor), shape(batch_size, time_steps_q, model_dim), the context vector.
attention_weights (Tensor): shape(batch_size, times_steps_q, time_steps_k), the attention weights.
"""
q = _split_heads(self.affine_q(q), self.num_heads) # (B, h, T, C)
q = _split_heads(self.affine_q(q), self.num_heads) # (B, h, T, C)
k = _split_heads(self.affine_k(k), self.num_heads)
v = _split_heads(self.affine_v(v), self.num_heads)
mask = paddle.unsqueeze(mask, 1) # unsqueeze for the h dim
mask = paddle.unsqueeze(mask, 1) # unsqueeze for the h dim
context_vectors, attention_weights = scaled_dot_product_attention(
q, k, v, mask, self.dropout, self.training)
# NOTE: there is more sophisticated implementation: Scheduled DropHead
context_vectors = _concat_heads(context_vectors) # (B, T, h*C)
context_vectors = _concat_heads(context_vectors) # (B, T, h*C)
out = self.affine_o(context_vectors)
return out, attention_weights
class LocationSensitiveAttention(nn.Layer):
def __init__(self,
d_query: int,
d_key: int,
d_attention: int,
location_filters: int,
location_kernel_size: int):
super().__init__()
self.query_layer = nn.Linear(d_query, d_attention, bias_attr=False)
self.key_layer = nn.Linear(d_key, d_attention, bias_attr=False)
self.value = nn.Linear(d_attention, 1, bias_attr=False)
#Location Layer
self.location_conv = nn.Conv1D(
2,
location_filters,
location_kernel_size,
1,
int((location_kernel_size - 1) / 2),
1,
bias_attr=False,
data_format='NLC')
self.location_layer = nn.Linear(
location_filters, d_attention, bias_attr=False)
def forward(self,
query,
processed_key,
value,
attention_weights_cat,
mask=None):
processed_query = self.query_layer(paddle.unsqueeze(query, axis=[1]))
processed_attention_weights = self.location_layer(
self.location_conv(attention_weights_cat))
alignment = self.value(
paddle.tanh(processed_attention_weights + processed_key +
processed_query))
if mask is not None:
alignment = alignment + (1.0 - mask) * -1e9
attention_weights = F.softmax(alignment, axis=1)
attention_context = paddle.matmul(
attention_weights, value, transpose_x=True)
attention_weights = paddle.squeeze(attention_weights, axis=[-1])
attention_context = paddle.squeeze(attention_context, axis=[1])
return attention_context, attention_weights

View File

@ -1,6 +1,21 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
from paddle import nn
class Conv1dCell(nn.Conv1D):
"""
A subclass of Conv1d layer, which can be used like an RNN cell. It can take
@ -14,43 +29,46 @@ class Conv1dCell(nn.Conv1D):
As a result, these arguments are removed form the initializer.
"""
def __init__(self,
def __init__(self,
in_channels,
out_channels,
kernel_size,
dilation=1,
weight_attr=None,
bias_attr=None):
_dilation = dilation[0] if isinstance(dilation, (tuple, list)) else dilation
_kernel_size = kernel_size[0] if isinstance(kernel_size, (tuple, list)) else kernel_size
_dilation = dilation[0] if isinstance(dilation,
(tuple, list)) else dilation
_kernel_size = kernel_size[0] if isinstance(kernel_size, (
tuple, list)) else kernel_size
self._r = 1 + (_kernel_size - 1) * _dilation
super(Conv1dCell, self).__init__(
in_channels,
out_channels,
kernel_size,
padding=(self._r - 1, 0),
dilation=dilation,
weight_attr=weight_attr,
bias_attr=bias_attr,
in_channels,
out_channels,
kernel_size,
padding=(self._r - 1, 0),
dilation=dilation,
weight_attr=weight_attr,
bias_attr=bias_attr,
data_format="NCL")
@property
def receptive_field(self):
return self._r
def start_sequence(self):
if self.training:
raise Exception("only use start_sequence in evaluation")
self._buffer = None
self._reshaped_weight = paddle.reshape(
self.weight, (self._out_channels, -1))
self._reshaped_weight = paddle.reshape(self.weight,
(self._out_channels, -1))
def initialize_buffer(self, x_t):
batch_size, _ = x_t.shape
self._buffer = paddle.zeros(
(batch_size, self._in_channels, self.receptive_field),
(batch_size, self._in_channels, self.receptive_field),
dtype=x_t.dtype)
def update_buffer(self, x_t):
self._buffer = paddle.concat(
[self._buffer[:, :, 1:], paddle.unsqueeze(x_t, -1)], -1)
@ -66,7 +84,7 @@ class Conv1dCell(nn.Conv1D):
if self.receptive_field > 1:
if self._buffer is None:
self.initialize_buffer(x_t)
# update buffer
self.update_buffer(x_t)
if self._dilation[0] > 1:
@ -82,20 +100,34 @@ class Conv1dCell(nn.Conv1D):
class Conv1dBatchNorm(nn.Layer):
def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0,
weight_attr=None, bias_attr=None, data_format="NCL"):
def __init__(self,
in_channels,
out_channels,
kernel_size,
stride=1,
padding=0,
weight_attr=None,
bias_attr=None,
data_format="NCL",
momentum=0.9,
epsilon=1e-05):
super(Conv1dBatchNorm, self).__init__()
# TODO(chenfeiyu): carefully initialize Conv1d's weight
self.conv = nn.Conv1D(in_channels, out_channels, kernel_size, stride,
padding=padding,
weight_attr=weight_attr,
bias_attr=bias_attr,
data_format=data_format)
# TODO: channel last, but BatchNorm1d does not support channel last layout
self.bn = nn.BatchNorm1D(out_channels, momentum=0.99, epsilon=1e-3, data_format=data_format)
self.conv = nn.Conv1D(
in_channels,
out_channels,
kernel_size,
stride,
padding=padding,
weight_attr=weight_attr,
bias_attr=bias_attr,
data_format=data_format)
self.bn = nn.BatchNorm1D(
out_channels,
momentum=momentum,
epsilon=epsilon,
data_format=data_format)
def forward(self, x):
x = self.conv(x)
x = self.bn(x)
return x