add network of tacotron2 model
This commit is contained in:
parent
f255eee029
commit
b12eda8423
|
@ -1,3 +1,17 @@
|
|||
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import math
|
||||
from tqdm import trange
|
||||
import paddle
|
||||
|
@ -15,6 +29,7 @@ from parakeet.modules import losses as L
|
|||
|
||||
__all__ = ["TransformerTTS", "TransformerTTSLoss"]
|
||||
|
||||
|
||||
# Transformer TTS's own implementation of transformer
|
||||
class MultiheadAttention(nn.Layer):
|
||||
"""
|
||||
|
@ -25,7 +40,14 @@ class MultiheadAttention(nn.Layer):
|
|||
Another deviation is that it concats the input query and context vector before
|
||||
applying the output projection.
|
||||
"""
|
||||
def __init__(self, model_dim, num_heads, k_dim=None, v_dim=None, k_input_dim=None, v_input_dim=None):
|
||||
|
||||
def __init__(self,
|
||||
model_dim,
|
||||
num_heads,
|
||||
k_dim=None,
|
||||
v_dim=None,
|
||||
k_input_dim=None,
|
||||
v_input_dim=None):
|
||||
"""
|
||||
Args:
|
||||
model_dim (int): the feature size of query.
|
||||
|
@ -80,7 +102,8 @@ class MultiheadAttention(nn.Layer):
|
|||
|
||||
context_vectors, attention_weights = scaled_dot_product_attention(
|
||||
q, k, v, mask, training=self.training)
|
||||
context_vectors = drop_head(context_vectors, drop_n_heads, self.training)
|
||||
context_vectors = drop_head(context_vectors, drop_n_heads,
|
||||
self.training)
|
||||
context_vectors = _concat_heads(context_vectors) # (B, T, h*C)
|
||||
|
||||
concat_feature = paddle.concat([q_in, context_vectors], -1)
|
||||
|
@ -92,6 +115,7 @@ class TransformerEncoderLayer(nn.Layer):
|
|||
"""
|
||||
Transformer encoder layer.
|
||||
"""
|
||||
|
||||
def __init__(self, d_model, n_heads, d_ffn, dropout=0.):
|
||||
"""
|
||||
Args:
|
||||
|
@ -114,8 +138,10 @@ class TransformerEncoderLayer(nn.Layer):
|
|||
# PreLN scheme: Norm -> SubLayer -> Dropout -> Residual
|
||||
x_in = x
|
||||
x = self.layer_norm1(x)
|
||||
context_vector, attn_weights = self.self_mha(x, x, x, mask, drop_n_heads)
|
||||
context_vector = x_in + F.dropout(context_vector, self.dropout, training=self.training)
|
||||
context_vector, attn_weights = self.self_mha(x, x, x, mask,
|
||||
drop_n_heads)
|
||||
context_vector = x_in + F.dropout(
|
||||
context_vector, self.dropout, training=self.training)
|
||||
return context_vector, attn_weights
|
||||
|
||||
def _forward_ffn(self, x):
|
||||
|
@ -145,6 +171,7 @@ class TransformerDecoderLayer(nn.Layer):
|
|||
"""
|
||||
Transformer decoder layer.
|
||||
"""
|
||||
|
||||
def __init__(self, d_model, n_heads, d_ffn, dropout=0., d_encoder=None):
|
||||
"""
|
||||
Args:
|
||||
|
@ -158,7 +185,8 @@ class TransformerDecoderLayer(nn.Layer):
|
|||
self.self_mha = MultiheadAttention(d_model, n_heads)
|
||||
self.layer_norm1 = nn.LayerNorm([d_model], epsilon=1e-6)
|
||||
|
||||
self.cross_mha = MultiheadAttention(d_model, n_heads, k_input_dim=d_encoder, v_input_dim=d_encoder)
|
||||
self.cross_mha = MultiheadAttention(
|
||||
d_model, n_heads, k_input_dim=d_encoder, v_input_dim=d_encoder)
|
||||
self.layer_norm2 = nn.LayerNorm([d_model], epsilon=1e-6)
|
||||
|
||||
self.ffn = PositionwiseFFN(d_model, d_ffn, dropout)
|
||||
|
@ -170,16 +198,20 @@ class TransformerDecoderLayer(nn.Layer):
|
|||
# PreLN scheme: Norm -> SubLayer -> Dropout -> Residual
|
||||
x_in = x
|
||||
x = self.layer_norm1(x)
|
||||
context_vector, attn_weights = self.self_mha(x, x, x, mask, drop_n_heads)
|
||||
context_vector = x_in + F.dropout(context_vector, self.dropout, training=self.training)
|
||||
context_vector, attn_weights = self.self_mha(x, x, x, mask,
|
||||
drop_n_heads)
|
||||
context_vector = x_in + F.dropout(
|
||||
context_vector, self.dropout, training=self.training)
|
||||
return context_vector, attn_weights
|
||||
|
||||
def _forward_cross_mha(self, q, k, v, mask, drop_n_heads):
|
||||
# PreLN scheme: Norm -> SubLayer -> Dropout -> Residual
|
||||
q_in = q
|
||||
q = self.layer_norm2(q)
|
||||
context_vector, attn_weights = self.cross_mha(q, k, v, mask, drop_n_heads)
|
||||
context_vector = q_in + F.dropout(context_vector, self.dropout, training=self.training)
|
||||
context_vector, attn_weights = self.cross_mha(q, k, v, mask,
|
||||
drop_n_heads)
|
||||
context_vector = q_in + F.dropout(
|
||||
context_vector, self.dropout, training=self.training)
|
||||
return context_vector, attn_weights
|
||||
|
||||
def _forward_ffn(self, x):
|
||||
|
@ -204,8 +236,10 @@ class TransformerDecoderLayer(nn.Layer):
|
|||
self_attn_weights (Tensor), shape(batch_size, n_heads, time_steps_q, time_steps_q), decoder self attention.
|
||||
cross_attn_weights (Tensor), shape(batch_size, n_heads, time_steps_q, time_steps_k), decoder-encoder cross attention.
|
||||
"""
|
||||
q, self_attn_weights = self._forward_self_mha(q, decoder_mask, drop_n_heads)
|
||||
q, cross_attn_weights = self._forward_cross_mha(q, k, v, encoder_mask, drop_n_heads)
|
||||
q, self_attn_weights = self._forward_self_mha(q, decoder_mask,
|
||||
drop_n_heads)
|
||||
q, cross_attn_weights = self._forward_cross_mha(q, k, v, encoder_mask,
|
||||
drop_n_heads)
|
||||
q = self._forward_ffn(q)
|
||||
return q, self_attn_weights, cross_attn_weights
|
||||
|
||||
|
@ -214,7 +248,8 @@ class TransformerEncoder(nn.LayerList):
|
|||
def __init__(self, d_model, n_heads, d_ffn, n_layers, dropout=0.):
|
||||
super(TransformerEncoder, self).__init__()
|
||||
for _ in range(n_layers):
|
||||
self.append(TransformerEncoderLayer(d_model, n_heads, d_ffn, dropout))
|
||||
self.append(
|
||||
TransformerEncoderLayer(d_model, n_heads, d_ffn, dropout))
|
||||
|
||||
def forward(self, x, mask, drop_n_heads=0):
|
||||
"""
|
||||
|
@ -236,10 +271,18 @@ class TransformerEncoder(nn.LayerList):
|
|||
|
||||
|
||||
class TransformerDecoder(nn.LayerList):
|
||||
def __init__(self, d_model, n_heads, d_ffn, n_layers, dropout=0., d_encoder=None):
|
||||
def __init__(self,
|
||||
d_model,
|
||||
n_heads,
|
||||
d_ffn,
|
||||
n_layers,
|
||||
dropout=0.,
|
||||
d_encoder=None):
|
||||
super(TransformerDecoder, self).__init__()
|
||||
for _ in range(n_layers):
|
||||
self.append(TransformerDecoderLayer(d_model, n_heads, d_ffn, dropout, d_encoder=d_encoder))
|
||||
self.append(
|
||||
TransformerDecoderLayer(
|
||||
d_model, n_heads, d_ffn, dropout, d_encoder=d_encoder))
|
||||
|
||||
def forward(self, q, k, v, encoder_mask, decoder_mask, drop_n_heads=0):
|
||||
"""[summary]
|
||||
|
@ -260,7 +303,8 @@ class TransformerDecoder(nn.LayerList):
|
|||
self_attention_weights = []
|
||||
cross_attention_weights = []
|
||||
for layer in self:
|
||||
q, self_attention_weights_i, cross_attention_weights_i = layer(q, k, v, encoder_mask, decoder_mask, drop_n_heads)
|
||||
q, self_attention_weights_i, cross_attention_weights_i = layer(
|
||||
q, k, v, encoder_mask, decoder_mask, drop_n_heads)
|
||||
self_attention_weights.append(self_attention_weights_i)
|
||||
cross_attention_weights.append(cross_attention_weights_i)
|
||||
return q, self_attention_weights, cross_attention_weights
|
||||
|
@ -268,6 +312,7 @@ class TransformerDecoder(nn.LayerList):
|
|||
|
||||
class MLPPreNet(nn.Layer):
|
||||
"""Decoder's prenet."""
|
||||
|
||||
def __init__(self, d_input, d_hidden, d_output, dropout):
|
||||
# (lin + relu + dropout) * n + last projection
|
||||
super(MLPPreNet, self).__init__()
|
||||
|
@ -277,14 +322,22 @@ class MLPPreNet(nn.Layer):
|
|||
self.dropout = dropout
|
||||
|
||||
def forward(self, x, dropout):
|
||||
l1 = F.dropout(F.relu(self.lin1(x)), self.dropout, training=self.training)
|
||||
l2 = F.dropout(F.relu(self.lin2(l1)), self.dropout, training=self.training)
|
||||
l1 = F.dropout(
|
||||
F.relu(self.lin1(x)), self.dropout, training=self.training)
|
||||
l2 = F.dropout(
|
||||
F.relu(self.lin2(l1)), self.dropout, training=self.training)
|
||||
l3 = self.lin3(l2)
|
||||
return l3
|
||||
|
||||
|
||||
# NOTE: not used in
|
||||
class CNNPreNet(nn.Layer):
|
||||
def __init__(self, d_input, d_hidden, d_output, kernel_size, n_layers,
|
||||
def __init__(self,
|
||||
d_input,
|
||||
d_hidden,
|
||||
d_output,
|
||||
kernel_size,
|
||||
n_layers,
|
||||
dropout=0.):
|
||||
# (conv + bn + relu + dropout) * n + last projection
|
||||
super(CNNPreNet, self).__init__()
|
||||
|
@ -292,16 +345,21 @@ class CNNPreNet(nn.Layer):
|
|||
c_in = d_input
|
||||
for _ in range(n_layers):
|
||||
self.convs.append(
|
||||
Conv1dBatchNorm(c_in, d_hidden, kernel_size,
|
||||
Conv1dBatchNorm(
|
||||
c_in,
|
||||
d_hidden,
|
||||
kernel_size,
|
||||
weight_attr=I.XavierUniform(),
|
||||
padding="same", data_format="NLC"))
|
||||
padding="same",
|
||||
data_format="NLC"))
|
||||
c_in = d_hidden
|
||||
self.affine_out = nn.Linear(d_hidden, d_output)
|
||||
self.dropout = dropout
|
||||
|
||||
def forward(self, x):
|
||||
for layer in self.convs:
|
||||
x = F.dropout(F.relu(layer(x)), self.dropout, training=self.training)
|
||||
x = F.dropout(
|
||||
F.relu(layer(x)), self.dropout, training=self.training)
|
||||
x = self.affine_out(x)
|
||||
return x
|
||||
|
||||
|
@ -310,13 +368,17 @@ class CNNPostNet(nn.Layer):
|
|||
def __init__(self, d_input, d_hidden, d_output, kernel_size, n_layers):
|
||||
super(CNNPostNet, self).__init__()
|
||||
self.convs = nn.LayerList()
|
||||
kernel_size = kernel_size if isinstance(kernel_size, (tuple, list)) else (kernel_size, )
|
||||
kernel_size = kernel_size if isinstance(kernel_size, (
|
||||
tuple, list)) else (kernel_size, )
|
||||
padding = (kernel_size[0] - 1, 0)
|
||||
for i in range(n_layers):
|
||||
c_in = d_input if i == 0 else d_hidden
|
||||
c_out = d_output if i == n_layers - 1 else d_hidden
|
||||
self.convs.append(
|
||||
Conv1dBatchNorm(c_in, c_out, kernel_size,
|
||||
Conv1dBatchNorm(
|
||||
c_in,
|
||||
c_out,
|
||||
kernel_size,
|
||||
weight_attr=I.XavierUniform(),
|
||||
padding=padding))
|
||||
self.last_bn = nn.BatchNorm1D(d_output)
|
||||
|
@ -359,15 +421,16 @@ class TransformerTTS(nn.Layer):
|
|||
|
||||
# encoder
|
||||
self.encoder_prenet = nn.Embedding(
|
||||
frontend.vocab_size, d_encoder,
|
||||
frontend.vocab_size,
|
||||
d_encoder,
|
||||
padding_idx=frontend.vocab.padding_index,
|
||||
weight_attr=I.Uniform(-0.05, 0.05))
|
||||
# position encoding matrix may be extended later
|
||||
self.encoder_pe = pe.positional_encoding(0, 1000, d_encoder)
|
||||
self.encoder_pe_scalar = self.create_parameter(
|
||||
[1], attr=I.Constant(1.))
|
||||
self.encoder = TransformerEncoder(
|
||||
d_encoder, n_heads, d_ffn, encoder_layers, dropout)
|
||||
self.encoder = TransformerEncoder(d_encoder, n_heads, d_ffn,
|
||||
encoder_layers, dropout)
|
||||
|
||||
# decoder
|
||||
self.decoder_prenet = MLPPreNet(d_mel, d_prenet, d_decoder, dropout)
|
||||
|
@ -375,11 +438,15 @@ class TransformerTTS(nn.Layer):
|
|||
self.decoder_pe_scalar = self.create_parameter(
|
||||
[1], attr=I.Constant(1.))
|
||||
self.decoder = TransformerDecoder(
|
||||
d_decoder, n_heads, d_ffn, decoder_layers, dropout,
|
||||
d_decoder,
|
||||
n_heads,
|
||||
d_ffn,
|
||||
decoder_layers,
|
||||
dropout,
|
||||
d_encoder=d_encoder)
|
||||
self.final_proj = nn.Linear(d_decoder, max_reduction_factor * d_mel)
|
||||
self.decoder_postnet = CNNPostNet(
|
||||
d_mel, d_postnet, d_mel, postnet_kernel_size, postnet_layers)
|
||||
self.decoder_postnet = CNNPostNet(d_mel, d_postnet, d_mel,
|
||||
postnet_kernel_size, postnet_layers)
|
||||
self.stop_conditioner = nn.Linear(d_mel, 3)
|
||||
|
||||
# specs
|
||||
|
@ -404,7 +471,8 @@ class TransformerTTS(nn.Layer):
|
|||
|
||||
def forward(self, text, mel):
|
||||
encoded, encoder_attention_weights, encoder_mask = self.encode(text)
|
||||
mel_output, mel_intermediate, cross_attention_weights, stop_logits = self.decode(encoded, mel, encoder_mask)
|
||||
mel_output, mel_intermediate, cross_attention_weights, stop_logits = self.decode(
|
||||
encoded, mel, encoder_mask)
|
||||
outputs = {
|
||||
"mel_output": mel_output,
|
||||
"mel_intermediate": mel_intermediate,
|
||||
|
@ -421,13 +489,16 @@ class TransformerTTS(nn.Layer):
|
|||
new_T = max(embed.shape[1], self.encoder_pe.shape[0] * 2)
|
||||
self.encoder_pe = pe.positional_encoding(0, new_T, self.d_encoder)
|
||||
pos_enc = self.encoder_pe[:T_enc, :] # (T, C)
|
||||
x = embed.scale(math.sqrt(self.d_encoder)) + pos_enc * self.encoder_pe_scalar
|
||||
x = embed.scale(math.sqrt(
|
||||
self.d_encoder)) + pos_enc * self.encoder_pe_scalar
|
||||
x = F.dropout(x, self.dropout, training=self.training)
|
||||
|
||||
# TODO(chenfeiyu): unsqueeze a decoder_time_steps=1 for the mask
|
||||
encoder_padding_mask = paddle.unsqueeze(
|
||||
masking.id_mask(text, self.padding_idx, dtype=x.dtype), 1)
|
||||
x, attention_weights = self.encoder(x, encoder_padding_mask, self.drop_n_heads)
|
||||
masking.id_mask(
|
||||
text, self.padding_idx, dtype=x.dtype), 1)
|
||||
x, attention_weights = self.encoder(x, encoder_padding_mask,
|
||||
self.drop_n_heads)
|
||||
return x, attention_weights, encoder_padding_mask
|
||||
|
||||
def decode(self, encoder_output, input, encoder_padding_mask):
|
||||
|
@ -439,23 +510,23 @@ class TransformerTTS(nn.Layer):
|
|||
new_T = max(x.shape[1] * self.r, self.decoder_pe.shape[0] * 2)
|
||||
self.decoder_pe = pe.positional_encoding(0, new_T, self.d_decoder)
|
||||
pos_enc = self.decoder_pe[:T_dec * self.r:self.r, :]
|
||||
x = x.scale(math.sqrt(self.d_decoder)) + pos_enc * self.decoder_pe_scalar
|
||||
x = x.scale(math.sqrt(
|
||||
self.d_decoder)) + pos_enc * self.decoder_pe_scalar
|
||||
x = F.dropout(x, self.dropout, training=self.training)
|
||||
|
||||
no_future_mask = masking.future_mask(T_dec, dtype=input.dtype)
|
||||
decoder_padding_mask = masking.feature_mask(input, axis=-1, dtype=input.dtype)
|
||||
decoder_mask = masking.combine_mask(decoder_padding_mask.unsqueeze(-1), no_future_mask)
|
||||
decoder_padding_mask = masking.feature_mask(
|
||||
input, axis=-1, dtype=input.dtype)
|
||||
decoder_mask = masking.combine_mask(
|
||||
decoder_padding_mask.unsqueeze(-1), no_future_mask)
|
||||
decoder_output, _, cross_attention_weights = self.decoder(
|
||||
x,
|
||||
encoder_output,
|
||||
encoder_output,
|
||||
encoder_padding_mask,
|
||||
decoder_mask,
|
||||
self.drop_n_heads)
|
||||
x, encoder_output, encoder_output, encoder_padding_mask,
|
||||
decoder_mask, self.drop_n_heads)
|
||||
|
||||
# use only parts of it
|
||||
output_proj = self.final_proj(decoder_output)[:, :, :self.r * mel_dim]
|
||||
mel_intermediate = paddle.reshape(output_proj, [batch_size, -1, mel_dim])
|
||||
mel_intermediate = paddle.reshape(output_proj,
|
||||
[batch_size, -1, mel_dim])
|
||||
stop_logits = self.stop_conditioner(mel_intermediate)
|
||||
|
||||
# cnn postnet
|
||||
|
@ -483,18 +554,24 @@ class TransformerTTS(nn.Layer):
|
|||
decoder_output = paddle.unsqueeze(self.start_vec, 0) # (B=1, T, C)
|
||||
|
||||
# encoder the text sequence
|
||||
encoder_output, encoder_attentions, encoder_padding_mask = self.encode(text_input)
|
||||
for _ in trange(int(max_length // self.r) + 1):
|
||||
encoder_output, encoder_attentions, encoder_padding_mask = self.encode(
|
||||
text_input)
|
||||
for _ in range(int(max_length // self.r) + 1):
|
||||
mel_output, _, cross_attention_weights, stop_logits = self.decode(
|
||||
encoder_output, decoder_input, encoder_padding_mask)
|
||||
|
||||
# extract last step and append it to decoder input
|
||||
decoder_input = paddle.concat([decoder_input, mel_output[:, -1:, :]], 1)
|
||||
decoder_input = paddle.concat(
|
||||
[decoder_input, mel_output[:, -1:, :]], 1)
|
||||
# extract last r steps and append it to decoder output
|
||||
decoder_output = paddle.concat([decoder_output, mel_output[:, -self.r:, :]], 1)
|
||||
decoder_output = paddle.concat(
|
||||
[decoder_output, mel_output[:, -self.r:, :]], 1)
|
||||
|
||||
# stop condition: (if any ouput frame of the output multiframes hits the stop condition)
|
||||
if paddle.any(paddle.argmax(stop_logits[0, -self.r:, :], axis=-1) == self.stop_prob_index):
|
||||
if paddle.any(
|
||||
paddle.argmax(
|
||||
stop_logits[0, -self.r:, :], axis=-1) ==
|
||||
self.stop_prob_index):
|
||||
if verbose:
|
||||
print("Hits stop condition.")
|
||||
break
|
||||
|
@ -517,15 +594,19 @@ class TransformerTTSLoss(nn.Layer):
|
|||
super(TransformerTTSLoss, self).__init__()
|
||||
self.stop_loss_scale = stop_loss_scale
|
||||
|
||||
def forward(self, mel_output, mel_intermediate, mel_target, stop_logits, stop_probs):
|
||||
mask = masking.feature_mask(mel_target, axis=-1, dtype=mel_target.dtype)
|
||||
def forward(self, mel_output, mel_intermediate, mel_target, stop_logits,
|
||||
stop_probs):
|
||||
mask = masking.feature_mask(
|
||||
mel_target, axis=-1, dtype=mel_target.dtype)
|
||||
mask1 = paddle.unsqueeze(mask, -1)
|
||||
mel_loss1 = L.masked_l1_loss(mel_output, mel_target, mask1)
|
||||
mel_loss2 = L.masked_l1_loss(mel_intermediate, mel_target, mask1)
|
||||
|
||||
mel_len = mask.shape[-1]
|
||||
last_position = F.one_hot(mask.sum(-1).astype("int64") - 1, num_classes=mel_len)
|
||||
mask2 = mask + last_position.scale(self.stop_loss_scale - 1).astype(mask.dtype)
|
||||
last_position = F.one_hot(
|
||||
mask.sum(-1).astype("int64") - 1, num_classes=mel_len)
|
||||
mask2 = mask + last_position.scale(self.stop_loss_scale - 1).astype(
|
||||
mask.dtype)
|
||||
stop_loss = L.masked_softmax_with_cross_entropy(
|
||||
stop_logits, stop_probs.unsqueeze(-1), mask2.unsqueeze(-1))
|
||||
|
||||
|
@ -543,8 +624,10 @@ class AdaptiveTransformerTTSLoss(nn.Layer):
|
|||
def __init__(self):
|
||||
super(AdaptiveTransformerTTSLoss, self).__init__()
|
||||
|
||||
def forward(self, mel_output, mel_intermediate, mel_target, stop_logits, stop_probs):
|
||||
mask = masking.feature_mask(mel_target, axis=-1, dtype=mel_target.dtype)
|
||||
def forward(self, mel_output, mel_intermediate, mel_target, stop_logits,
|
||||
stop_probs):
|
||||
mask = masking.feature_mask(
|
||||
mel_target, axis=-1, dtype=mel_target.dtype)
|
||||
mask1 = paddle.unsqueeze(mask, -1)
|
||||
mel_loss1 = L.masked_l1_loss(mel_output, mel_target, mask1)
|
||||
mel_loss2 = L.masked_l1_loss(mel_intermediate, mel_target, mask1)
|
||||
|
@ -553,7 +636,8 @@ class AdaptiveTransformerTTSLoss(nn.Layer):
|
|||
valid_lengths = mask.sum(-1).astype("int64")
|
||||
last_position = F.one_hot(valid_lengths - 1, num_classes=mel_len)
|
||||
stop_loss_scale = valid_lengths.sum() / batch_size - 1
|
||||
mask2 = mask + last_position.scale(stop_loss_scale - 1).astype(mask.dtype)
|
||||
mask2 = mask + last_position.scale(stop_loss_scale - 1).astype(
|
||||
mask.dtype)
|
||||
stop_loss = L.masked_softmax_with_cross_entropy(
|
||||
stop_logits, stop_probs.unsqueeze(-1), mask2.unsqueeze(-1))
|
||||
|
||||
|
|
|
@ -1,10 +1,30 @@
|
|||
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import math
|
||||
import numpy as np
|
||||
import paddle
|
||||
from paddle import nn
|
||||
from paddle.nn import functional as F
|
||||
|
||||
def scaled_dot_product_attention(q, k, v, mask=None, dropout=0.0, training=True):
|
||||
|
||||
def scaled_dot_product_attention(q,
|
||||
k,
|
||||
v,
|
||||
mask=None,
|
||||
dropout=0.0,
|
||||
training=True):
|
||||
"""
|
||||
scaled dot product attention with mask. Assume q, k, v all have the same
|
||||
leader dimensions(denoted as * in descriptions below). Dropout is applied to
|
||||
|
@ -34,6 +54,7 @@ def scaled_dot_product_attention(q, k, v, mask=None, dropout=0.0, training=True)
|
|||
out = paddle.matmul(attn_weights, v)
|
||||
return out, attn_weights
|
||||
|
||||
|
||||
def drop_head(x, drop_n_heads, training):
|
||||
"""
|
||||
Drop n heads from multiple context vectors.
|
||||
|
@ -63,18 +84,21 @@ def drop_head(x, drop_n_heads, training):
|
|||
out = x * paddle.to_tensor(mask)
|
||||
return out
|
||||
|
||||
|
||||
def _split_heads(x, num_heads):
|
||||
batch_size, time_steps, _ = x.shape
|
||||
x = paddle.reshape(x, [batch_size, time_steps, num_heads, -1])
|
||||
x = paddle.transpose(x, [0, 2, 1, 3])
|
||||
return x
|
||||
|
||||
|
||||
def _concat_heads(x):
|
||||
batch_size, _, time_steps, _ = x.shape
|
||||
x = paddle.transpose(x, [0, 2, 1, 3])
|
||||
x = paddle.reshape(x, [batch_size, time_steps, -1])
|
||||
return x
|
||||
|
||||
|
||||
# Standard implementations of Monohead Attention & Multihead Attention
|
||||
class MonoheadAttention(nn.Layer):
|
||||
def __init__(self, model_dim, dropout=0.0, k_dim=None, v_dim=None):
|
||||
|
@ -134,7 +158,13 @@ class MultiheadAttention(nn.Layer):
|
|||
"""
|
||||
Multihead scaled dot product attention.
|
||||
"""
|
||||
def __init__(self, model_dim, num_heads, dropout=0.0, k_dim=None, v_dim=None):
|
||||
|
||||
def __init__(self,
|
||||
model_dim,
|
||||
num_heads,
|
||||
dropout=0.0,
|
||||
k_dim=None,
|
||||
v_dim=None):
|
||||
"""
|
||||
Multihead Attention module.
|
||||
|
||||
|
@ -195,3 +225,56 @@ class MultiheadAttention(nn.Layer):
|
|||
context_vectors = _concat_heads(context_vectors) # (B, T, h*C)
|
||||
out = self.affine_o(context_vectors)
|
||||
return out, attention_weights
|
||||
|
||||
|
||||
class LocationSensitiveAttention(nn.Layer):
|
||||
def __init__(self,
|
||||
d_query: int,
|
||||
d_key: int,
|
||||
d_attention: int,
|
||||
location_filters: int,
|
||||
location_kernel_size: int):
|
||||
super().__init__()
|
||||
|
||||
self.query_layer = nn.Linear(d_query, d_attention, bias_attr=False)
|
||||
self.key_layer = nn.Linear(d_key, d_attention, bias_attr=False)
|
||||
self.value = nn.Linear(d_attention, 1, bias_attr=False)
|
||||
|
||||
#Location Layer
|
||||
self.location_conv = nn.Conv1D(
|
||||
2,
|
||||
location_filters,
|
||||
location_kernel_size,
|
||||
1,
|
||||
int((location_kernel_size - 1) / 2),
|
||||
1,
|
||||
bias_attr=False,
|
||||
data_format='NLC')
|
||||
self.location_layer = nn.Linear(
|
||||
location_filters, d_attention, bias_attr=False)
|
||||
|
||||
def forward(self,
|
||||
query,
|
||||
processed_key,
|
||||
value,
|
||||
attention_weights_cat,
|
||||
mask=None):
|
||||
|
||||
processed_query = self.query_layer(paddle.unsqueeze(query, axis=[1]))
|
||||
processed_attention_weights = self.location_layer(
|
||||
self.location_conv(attention_weights_cat))
|
||||
alignment = self.value(
|
||||
paddle.tanh(processed_attention_weights + processed_key +
|
||||
processed_query))
|
||||
|
||||
if mask is not None:
|
||||
alignment = alignment + (1.0 - mask) * -1e9
|
||||
|
||||
attention_weights = F.softmax(alignment, axis=1)
|
||||
attention_context = paddle.matmul(
|
||||
attention_weights, value, transpose_x=True)
|
||||
|
||||
attention_weights = paddle.squeeze(attention_weights, axis=[-1])
|
||||
attention_context = paddle.squeeze(attention_context, axis=[1])
|
||||
|
||||
return attention_context, attention_weights
|
||||
|
|
|
@ -1,6 +1,21 @@
|
|||
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import paddle
|
||||
from paddle import nn
|
||||
|
||||
|
||||
class Conv1dCell(nn.Conv1D):
|
||||
"""
|
||||
A subclass of Conv1d layer, which can be used like an RNN cell. It can take
|
||||
|
@ -14,6 +29,7 @@ class Conv1dCell(nn.Conv1D):
|
|||
|
||||
As a result, these arguments are removed form the initializer.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
|
@ -21,8 +37,10 @@ class Conv1dCell(nn.Conv1D):
|
|||
dilation=1,
|
||||
weight_attr=None,
|
||||
bias_attr=None):
|
||||
_dilation = dilation[0] if isinstance(dilation, (tuple, list)) else dilation
|
||||
_kernel_size = kernel_size[0] if isinstance(kernel_size, (tuple, list)) else kernel_size
|
||||
_dilation = dilation[0] if isinstance(dilation,
|
||||
(tuple, list)) else dilation
|
||||
_kernel_size = kernel_size[0] if isinstance(kernel_size, (
|
||||
tuple, list)) else kernel_size
|
||||
self._r = 1 + (_kernel_size - 1) * _dilation
|
||||
super(Conv1dCell, self).__init__(
|
||||
in_channels,
|
||||
|
@ -42,8 +60,8 @@ class Conv1dCell(nn.Conv1D):
|
|||
if self.training:
|
||||
raise Exception("only use start_sequence in evaluation")
|
||||
self._buffer = None
|
||||
self._reshaped_weight = paddle.reshape(
|
||||
self.weight, (self._out_channels, -1))
|
||||
self._reshaped_weight = paddle.reshape(self.weight,
|
||||
(self._out_channels, -1))
|
||||
|
||||
def initialize_buffer(self, x_t):
|
||||
batch_size, _ = x_t.shape
|
||||
|
@ -82,20 +100,34 @@ class Conv1dCell(nn.Conv1D):
|
|||
|
||||
|
||||
class Conv1dBatchNorm(nn.Layer):
|
||||
def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0,
|
||||
weight_attr=None, bias_attr=None, data_format="NCL"):
|
||||
def __init__(self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size,
|
||||
stride=1,
|
||||
padding=0,
|
||||
weight_attr=None,
|
||||
bias_attr=None,
|
||||
data_format="NCL",
|
||||
momentum=0.9,
|
||||
epsilon=1e-05):
|
||||
super(Conv1dBatchNorm, self).__init__()
|
||||
# TODO(chenfeiyu): carefully initialize Conv1d's weight
|
||||
self.conv = nn.Conv1D(in_channels, out_channels, kernel_size, stride,
|
||||
self.conv = nn.Conv1D(
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size,
|
||||
stride,
|
||||
padding=padding,
|
||||
weight_attr=weight_attr,
|
||||
bias_attr=bias_attr,
|
||||
data_format=data_format)
|
||||
# TODO: channel last, but BatchNorm1d does not support channel last layout
|
||||
self.bn = nn.BatchNorm1D(out_channels, momentum=0.99, epsilon=1e-3, data_format=data_format)
|
||||
self.bn = nn.BatchNorm1D(
|
||||
out_channels,
|
||||
momentum=momentum,
|
||||
epsilon=epsilon,
|
||||
data_format=data_format)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv(x)
|
||||
x = self.bn(x)
|
||||
return x
|
||||
|
||||
|
|
Loading…
Reference in New Issue