add network of tacotron2 model

lfchener 2020-12-09 09:08:17 +00:00
parent f255eee029
commit b12eda8423
3 changed files with 372 additions and 173 deletions

View File

@@ -1,3 +1,17 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from tqdm import trange
import paddle
@@ -15,6 +29,7 @@ from parakeet.modules import losses as L
__all__ = ["TransformerTTS", "TransformerTTSLoss"]


# Transformer TTS's own implementation of transformer
class MultiheadAttention(nn.Layer):
    """
@@ -25,7 +40,14 @@ class MultiheadAttention(nn.Layer):
    Another deviation is that it concats the input query and context vector before
    applying the output projection.
    """

    def __init__(self,
                 model_dim,
                 num_heads,
                 k_dim=None,
                 v_dim=None,
                 k_input_dim=None,
                 v_input_dim=None):
        """
        Args:
            model_dim (int): the feature size of query.
@@ -41,7 +63,7 @@ class MultiheadAttention(nn.Layer):
            ValueError: if model_dim is not divisible by num_heads
        """
        super(MultiheadAttention, self).__init__()
        if model_dim % num_heads != 0:
            raise ValueError("model_dim must be divisible by num_heads")
        depth = model_dim // num_heads
        k_dim = k_dim or depth
@@ -52,10 +74,10 @@ class MultiheadAttention(nn.Layer):
        self.affine_k = nn.Linear(k_input_dim, num_heads * k_dim)
        self.affine_v = nn.Linear(v_input_dim, num_heads * v_dim)
        self.affine_o = nn.Linear(model_dim + num_heads * v_dim, model_dim)

        self.num_heads = num_heads
        self.model_dim = model_dim

    def forward(self, q, k, v, mask, drop_n_heads=0):
        """
        Compute context vector and attention weights.
@@ -72,17 +94,18 @@ class MultiheadAttention(nn.Layer):
            attention_weights (Tensor): shape(batch_size, times_steps_q, time_steps_k), the attention weights.
        """
        q_in = q
        q = _split_heads(self.affine_q(q), self.num_heads)  # (B, h, T, C)
        k = _split_heads(self.affine_k(k), self.num_heads)
        v = _split_heads(self.affine_v(v), self.num_heads)
        if mask is not None:
            mask = paddle.unsqueeze(mask, 1)  # unsqueeze for the h dim

        context_vectors, attention_weights = scaled_dot_product_attention(
            q, k, v, mask, training=self.training)
        context_vectors = drop_head(context_vectors, drop_n_heads,
                                    self.training)
        context_vectors = _concat_heads(context_vectors)  # (B, T, h*C)

        concat_feature = paddle.concat([q_in, context_vectors], -1)
        out = self.affine_o(concat_feature)
        return out, attention_weights
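As an editorial aside (not part of the commit), here is a minimal shape sketch of this attention variant, assuming the class above is in scope and that its query projection (defined just above the shown hunk) maps model_dim to num_heads * k_dim; all sizes are made up:

    import paddle

    # hypothetical sizes: 256-wide decoder queries attend over 512-wide encoder states
    attn = MultiheadAttention(model_dim=256, num_heads=4, k_input_dim=512, v_input_dim=512)
    q = paddle.randn([2, 10, 256])    # (batch, T_q, model_dim)
    kv = paddle.randn([2, 20, 512])   # (batch, T_k, d_encoder)
    out, weights = attn(q, kv, kv, mask=None)
    print(out.shape)       # [2, 10, 256]  -- query concatenated with context, then projected
    print(weights.shape)   # [2, 4, 10, 20] -- per-head attention weights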
@@ -92,6 +115,7 @@ class TransformerEncoderLayer(nn.Layer):
    """
    Transformer encoder layer.
    """

    def __init__(self, d_model, n_heads, d_ffn, dropout=0.):
        """
        Args:
@@ -114,8 +138,10 @@ class TransformerEncoderLayer(nn.Layer):
        # PreLN scheme: Norm -> SubLayer -> Dropout -> Residual
        x_in = x
        x = self.layer_norm1(x)
        context_vector, attn_weights = self.self_mha(x, x, x, mask,
                                                     drop_n_heads)
        context_vector = x_in + F.dropout(
            context_vector, self.dropout, training=self.training)
        return context_vector, attn_weights

    def _forward_ffn(self, x):
@@ -123,9 +149,9 @@ class TransformerEncoderLayer(nn.Layer):
        x_in = x
        x = self.layer_norm2(x)
        x = self.ffn(x)
        out = x_in + F.dropout(x, self.dropout, training=self.training)
        return out

    def forward(self, x, mask, drop_n_heads=0):
        """
        Args:
@@ -145,6 +171,7 @@ class TransformerDecoderLayer(nn.Layer):
    """
    Transformer decoder layer.
    """

    def __init__(self, d_model, n_heads, d_ffn, dropout=0., d_encoder=None):
        """
        Args:
@@ -157,37 +184,42 @@ class TransformerDecoderLayer(nn.Layer):
        super(TransformerDecoderLayer, self).__init__()
        self.self_mha = MultiheadAttention(d_model, n_heads)
        self.layer_norm1 = nn.LayerNorm([d_model], epsilon=1e-6)

        self.cross_mha = MultiheadAttention(
            d_model, n_heads, k_input_dim=d_encoder, v_input_dim=d_encoder)
        self.layer_norm2 = nn.LayerNorm([d_model], epsilon=1e-6)

        self.ffn = PositionwiseFFN(d_model, d_ffn, dropout)
        self.layer_norm3 = nn.LayerNorm([d_model], epsilon=1e-6)

        self.dropout = dropout

    def _forward_self_mha(self, x, mask, drop_n_heads):
        # PreLN scheme: Norm -> SubLayer -> Dropout -> Residual
        x_in = x
        x = self.layer_norm1(x)
        context_vector, attn_weights = self.self_mha(x, x, x, mask,
                                                     drop_n_heads)
        context_vector = x_in + F.dropout(
            context_vector, self.dropout, training=self.training)
        return context_vector, attn_weights

    def _forward_cross_mha(self, q, k, v, mask, drop_n_heads):
        # PreLN scheme: Norm -> SubLayer -> Dropout -> Residual
        q_in = q
        q = self.layer_norm2(q)
        context_vector, attn_weights = self.cross_mha(q, k, v, mask,
                                                      drop_n_heads)
        context_vector = q_in + F.dropout(
            context_vector, self.dropout, training=self.training)
        return context_vector, attn_weights

    def _forward_ffn(self, x):
        # PreLN scheme: Norm -> SubLayer -> Dropout -> Residual
        x_in = x
        x = self.layer_norm3(x)
        x = self.ffn(x)
        out = x_in + F.dropout(x, self.dropout, training=self.training)
        return out

    def forward(self, q, k, v, encoder_mask, decoder_mask, drop_n_heads=0):
@@ -204,8 +236,10 @@ class TransformerDecoderLayer(nn.Layer):
            self_attn_weights (Tensor), shape(batch_size, n_heads, time_steps_q, time_steps_q), decoder self attention.
            cross_attn_weights (Tensor), shape(batch_size, n_heads, time_steps_q, time_steps_k), decoder-encoder cross attention.
        """
        q, self_attn_weights = self._forward_self_mha(q, decoder_mask,
                                                      drop_n_heads)
        q, cross_attn_weights = self._forward_cross_mha(q, k, v, encoder_mask,
                                                        drop_n_heads)

        q = self._forward_ffn(q)
        return q, self_attn_weights, cross_attn_weights
@@ -214,7 +248,8 @@ class TransformerEncoder(nn.LayerList):
    def __init__(self, d_model, n_heads, d_ffn, n_layers, dropout=0.):
        super(TransformerEncoder, self).__init__()
        for _ in range(n_layers):
            self.append(
                TransformerEncoderLayer(d_model, n_heads, d_ffn, dropout))

    def forward(self, x, mask, drop_n_heads=0):
        """
@@ -236,10 +271,18 @@ class TransformerEncoder(nn.LayerList):
class TransformerDecoder(nn.LayerList):
    def __init__(self,
                 d_model,
                 n_heads,
                 d_ffn,
                 n_layers,
                 dropout=0.,
                 d_encoder=None):
        super(TransformerDecoder, self).__init__()
        for _ in range(n_layers):
            self.append(
                TransformerDecoderLayer(
                    d_model, n_heads, d_ffn, dropout, d_encoder=d_encoder))

    def forward(self, q, k, v, encoder_mask, decoder_mask, drop_n_heads=0):
        """[summary]
@@ -260,7 +303,8 @@ class TransformerDecoder(nn.LayerList):
        self_attention_weights = []
        cross_attention_weights = []
        for layer in self:
            q, self_attention_weights_i, cross_attention_weights_i = layer(
                q, k, v, encoder_mask, decoder_mask, drop_n_heads)
            self_attention_weights.append(self_attention_weights_i)
            cross_attention_weights.append(cross_attention_weights_i)
        return q, self_attention_weights, cross_attention_weights
@@ -268,6 +312,7 @@ class TransformerDecoder(nn.LayerList):
class MLPPreNet(nn.Layer):
    """Decoder's prenet."""

    def __init__(self, d_input, d_hidden, d_output, dropout):
        # (lin + relu + dropout) * n + last projection
        super(MLPPreNet, self).__init__()
@@ -275,16 +320,24 @@ class MLPPreNet(nn.Layer):
        self.lin2 = nn.Linear(d_hidden, d_hidden)
        self.lin3 = nn.Linear(d_hidden, d_hidden)
        self.dropout = dropout

    def forward(self, x, dropout):
        l1 = F.dropout(
            F.relu(self.lin1(x)), self.dropout, training=self.training)
        l2 = F.dropout(
            F.relu(self.lin2(l1)), self.dropout, training=self.training)
        l3 = self.lin3(l2)
        return l3


# NOTE: not used in
class CNNPreNet(nn.Layer):
    def __init__(self,
                 d_input,
                 d_hidden,
                 d_output,
                 kernel_size,
                 n_layers,
                 dropout=0.):
        # (conv + bn + relu + dropout) * n + last projection
        super(CNNPreNet, self).__init__()
@@ -292,16 +345,21 @@ class CNNPreNet(nn.Layer):
        c_in = d_input
        for _ in range(n_layers):
            self.convs.append(
                Conv1dBatchNorm(
                    c_in,
                    d_hidden,
                    kernel_size,
                    weight_attr=I.XavierUniform(),
                    padding="same",
                    data_format="NLC"))
            c_in = d_hidden
        self.affine_out = nn.Linear(d_hidden, d_output)
        self.dropout = dropout

    def forward(self, x):
        for layer in self.convs:
            x = F.dropout(
                F.relu(layer(x)), self.dropout, training=self.training)
        x = self.affine_out(x)
        return x
@@ -310,21 +368,25 @@ class CNNPostNet(nn.Layer):
    def __init__(self, d_input, d_hidden, d_output, kernel_size, n_layers):
        super(CNNPostNet, self).__init__()
        self.convs = nn.LayerList()
        kernel_size = kernel_size if isinstance(kernel_size, (
            tuple, list)) else (kernel_size, )
        padding = (kernel_size[0] - 1, 0)
        for i in range(n_layers):
            c_in = d_input if i == 0 else d_hidden
            c_out = d_output if i == n_layers - 1 else d_hidden
            self.convs.append(
                Conv1dBatchNorm(
                    c_in,
                    c_out,
                    kernel_size,
                    weight_attr=I.XavierUniform(),
                    padding=padding))
        self.last_bn = nn.BatchNorm1D(d_output)
        # for a layer that ends with a normalization layer that is targeted to
        # output a non zero-central output, it may take a long time to
        # train the scale and bias
        # NOTE: it can also be a non-causal conv

    def forward(self, x):
        x_in = x
        for i, layer in enumerate(self.convs):
@@ -336,19 +398,19 @@ class CNNPostNet(nn.Layer):
class TransformerTTS(nn.Layer):
    def __init__(self,
                 frontend: parakeet.frontend.Phonetics,
                 d_encoder: int,
                 d_decoder: int,
                 d_mel: int,
                 n_heads: int,
                 d_ffn: int,
                 encoder_layers: int,
                 decoder_layers: int,
                 d_prenet: int,
                 d_postnet: int,
                 postnet_layers: int,
                 postnet_kernel_size: int,
                 max_reduction_factor: int,
                 decoder_prenet_dropout: float,
                 dropout: float):
@@ -359,29 +421,34 @@ class TransformerTTS(nn.Layer):
        # encoder
        self.encoder_prenet = nn.Embedding(
            frontend.vocab_size,
            d_encoder,
            padding_idx=frontend.vocab.padding_index,
            weight_attr=I.Uniform(-0.05, 0.05))
        # position encoding matrix may be extended later
        self.encoder_pe = pe.positional_encoding(0, 1000, d_encoder)
        self.encoder_pe_scalar = self.create_parameter(
            [1], attr=I.Constant(1.))
        self.encoder = TransformerEncoder(d_encoder, n_heads, d_ffn,
                                          encoder_layers, dropout)

        # decoder
        self.decoder_prenet = MLPPreNet(d_mel, d_prenet, d_decoder, dropout)
        self.decoder_pe = pe.positional_encoding(0, 1000, d_decoder)
        self.decoder_pe_scalar = self.create_parameter(
            [1], attr=I.Constant(1.))
        self.decoder = TransformerDecoder(
            d_decoder,
            n_heads,
            d_ffn,
            decoder_layers,
            dropout,
            d_encoder=d_encoder)
        self.final_proj = nn.Linear(d_decoder, max_reduction_factor * d_mel)
        self.decoder_postnet = CNNPostNet(d_mel, d_postnet, d_mel,
                                          postnet_kernel_size, postnet_layers)
        self.stop_conditioner = nn.Linear(d_mel, 3)

        # specs
        self.padding_idx = frontend.vocab.padding_index
        self.d_encoder = d_encoder
@@ -390,21 +457,22 @@ class TransformerTTS(nn.Layer):
        self.max_r = max_reduction_factor
        self.dropout = dropout
        self.decoder_prenet_dropout = decoder_prenet_dropout

        # start and end: though it is only used in predict
        # it can also be used in training
        dtype = paddle.get_default_dtype()
        self.start_vec = paddle.full([1, d_mel], 0.5, dtype=dtype)
        self.end_vec = paddle.full([1, d_mel], -0.5, dtype=dtype)
        self.stop_prob_index = 2

        # mutables
        self.r = max_reduction_factor  # set it every call
        self.drop_n_heads = 0

    def forward(self, text, mel):
        encoded, encoder_attention_weights, encoder_mask = self.encode(text)
        mel_output, mel_intermediate, cross_attention_weights, stop_logits = self.decode(
            encoded, mel, encoder_mask)
        outputs = {
            "mel_output": mel_output,
            "mel_intermediate": mel_intermediate,
@@ -420,51 +488,54 @@ class TransformerTTS(nn.Layer):
        if embed.shape[1] > self.encoder_pe.shape[0]:
            new_T = max(embed.shape[1], self.encoder_pe.shape[0] * 2)
            self.encoder_pe = pe.positional_encoding(0, new_T, self.d_encoder)
        pos_enc = self.encoder_pe[:T_enc, :]  # (T, C)
        x = embed.scale(math.sqrt(
            self.d_encoder)) + pos_enc * self.encoder_pe_scalar
        x = F.dropout(x, self.dropout, training=self.training)

        # TODO(chenfeiyu): unsqueeze a decoder_time_steps=1 for the mask
        encoder_padding_mask = paddle.unsqueeze(
            masking.id_mask(
                text, self.padding_idx, dtype=x.dtype), 1)
        x, attention_weights = self.encoder(x, encoder_padding_mask,
                                            self.drop_n_heads)
        return x, attention_weights, encoder_padding_mask

    def decode(self, encoder_output, input, encoder_padding_mask):
        batch_size, T_dec, mel_dim = input.shape
        x = self.decoder_prenet(input, self.decoder_prenet_dropout)
        # twice its length if needed
        if x.shape[1] * self.r > self.decoder_pe.shape[0]:
            new_T = max(x.shape[1] * self.r, self.decoder_pe.shape[0] * 2)
            self.decoder_pe = pe.positional_encoding(0, new_T, self.d_decoder)
        pos_enc = self.decoder_pe[:T_dec * self.r:self.r, :]
        x = x.scale(math.sqrt(
            self.d_decoder)) + pos_enc * self.decoder_pe_scalar
        x = F.dropout(x, self.dropout, training=self.training)

        no_future_mask = masking.future_mask(T_dec, dtype=input.dtype)
        decoder_padding_mask = masking.feature_mask(
            input, axis=-1, dtype=input.dtype)
        decoder_mask = masking.combine_mask(
            decoder_padding_mask.unsqueeze(-1), no_future_mask)
        decoder_output, _, cross_attention_weights = self.decoder(
            x, encoder_output, encoder_output, encoder_padding_mask,
            decoder_mask, self.drop_n_heads)

        # use only parts of it
        output_proj = self.final_proj(decoder_output)[:, :, :self.r * mel_dim]
        mel_intermediate = paddle.reshape(output_proj,
                                          [batch_size, -1, mel_dim])
        stop_logits = self.stop_conditioner(mel_intermediate)

        # cnn postnet
        mel_channel_first = paddle.transpose(mel_intermediate, [0, 2, 1])
        mel_output = self.decoder_postnet(mel_channel_first)
        mel_output = paddle.transpose(mel_output, [0, 2, 1])

        return mel_output, mel_intermediate, cross_attention_weights, stop_logits

    def predict(self, input, raw_input=True, max_length=1000, verbose=True):
        """Predict log scale magnitude mel spectrogram from text input.
@@ -475,26 +546,32 @@ class TransformerTTS(nn.Layer):
        """
        if raw_input:
            text_ids = paddle.to_tensor(self.frontend(input))
            text_input = paddle.unsqueeze(text_ids, 0)  # (1, T)
        else:
            text_input = input

        decoder_input = paddle.unsqueeze(self.start_vec, 0)  # (B=1, T, C)
        decoder_output = paddle.unsqueeze(self.start_vec, 0)  # (B=1, T, C)

        # encoder the text sequence
        encoder_output, encoder_attentions, encoder_padding_mask = self.encode(
            text_input)
        for _ in range(int(max_length // self.r) + 1):
            mel_output, _, cross_attention_weights, stop_logits = self.decode(
                encoder_output, decoder_input, encoder_padding_mask)

            # extract last step and append it to decoder input
            decoder_input = paddle.concat(
                [decoder_input, mel_output[:, -1:, :]], 1)
            # extract last r steps and append it to decoder output
            decoder_output = paddle.concat(
                [decoder_output, mel_output[:, -self.r:, :]], 1)

            # stop condition: (if any ouput frame of the output multiframes hits the stop condition)
            if paddle.any(
                    paddle.argmax(
                        stop_logits[0, -self.r:, :], axis=-1) ==
                    self.stop_prob_index):
                if verbose:
                    print("Hits stop condition.")
                break
@@ -516,24 +593,28 @@ class TransformerTTSLoss(nn.Layer):
    def __init__(self, stop_loss_scale):
        super(TransformerTTSLoss, self).__init__()
        self.stop_loss_scale = stop_loss_scale

    def forward(self, mel_output, mel_intermediate, mel_target, stop_logits,
                stop_probs):
        mask = masking.feature_mask(
            mel_target, axis=-1, dtype=mel_target.dtype)
        mask1 = paddle.unsqueeze(mask, -1)
        mel_loss1 = L.masked_l1_loss(mel_output, mel_target, mask1)
        mel_loss2 = L.masked_l1_loss(mel_intermediate, mel_target, mask1)

        mel_len = mask.shape[-1]
        last_position = F.one_hot(
            mask.sum(-1).astype("int64") - 1, num_classes=mel_len)
        mask2 = mask + last_position.scale(self.stop_loss_scale - 1).astype(
            mask.dtype)
        stop_loss = L.masked_softmax_with_cross_entropy(
            stop_logits, stop_probs.unsqueeze(-1), mask2.unsqueeze(-1))

        loss = mel_loss1 + mel_loss2 + stop_loss
        losses = dict(
            loss=loss,  # total loss
            mel_loss1=mel_loss1,  # ouput mel loss
            mel_loss2=mel_loss2,  # intermediate mel loss
            stop_loss=stop_loss  # stop prob loss
        )
        return losses
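The mask2 construction above is the heart of the stop-token weighting: every valid frame gets weight 1, and only the last valid frame of each utterance gets weight stop_loss_scale. A tiny editorial sketch of just that step, with made-up values:

    import paddle
    import paddle.nn.functional as F

    mask = paddle.to_tensor([[1., 1., 1., 0.],
                             [1., 1., 0., 0.]])      # (B, T) validity mask
    stop_loss_scale = 8.0                            # hypothetical value
    last_position = F.one_hot(mask.sum(-1).astype("int64") - 1, num_classes=4)
    mask2 = mask + last_position.scale(stop_loss_scale - 1).astype(mask.dtype)
    # mask2 == [[1., 1., 8., 0.],
    #           [1., 8., 0., 0.]]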
@@ -542,26 +623,29 @@ class TransformerTTSLoss(nn.Layer):
class AdaptiveTransformerTTSLoss(nn.Layer):
    def __init__(self):
        super(AdaptiveTransformerTTSLoss, self).__init__()

    def forward(self, mel_output, mel_intermediate, mel_target, stop_logits,
                stop_probs):
        mask = masking.feature_mask(
            mel_target, axis=-1, dtype=mel_target.dtype)
        mask1 = paddle.unsqueeze(mask, -1)
        mel_loss1 = L.masked_l1_loss(mel_output, mel_target, mask1)
        mel_loss2 = L.masked_l1_loss(mel_intermediate, mel_target, mask1)

        batch_size, mel_len = mask.shape
        valid_lengths = mask.sum(-1).astype("int64")
        last_position = F.one_hot(valid_lengths - 1, num_classes=mel_len)
        stop_loss_scale = valid_lengths.sum() / batch_size - 1
        mask2 = mask + last_position.scale(stop_loss_scale - 1).astype(
            mask.dtype)
        stop_loss = L.masked_softmax_with_cross_entropy(
            stop_logits, stop_probs.unsqueeze(-1), mask2.unsqueeze(-1))

        loss = mel_loss1 + mel_loss2 + stop_loss
        losses = dict(
            loss=loss,  # total loss
            mel_loss1=mel_loss1,  # ouput mel loss
            mel_loss2=mel_loss2,  # intermediate mel loss
            stop_loss=stop_loss  # stop prob loss
        )
        return losses

View File

@@ -1,10 +1,30 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math

import numpy as np
import paddle
from paddle import nn
from paddle.nn import functional as F


def scaled_dot_product_attention(q,
                                 k,
                                 v,
                                 mask=None,
                                 dropout=0.0,
                                 training=True):
    """
    scaled dot product attention with mask. Assume q, k, v all have the same
    leader dimensions(denoted as * in descriptions below). Dropout is applied to
@@ -22,18 +42,19 @@ def scaled_dot_product_attention(q, k, v, mask=None, dropout=0.0, training=True)
        out (Tensor): shape(*, T_q, d_v), the context vector.
        attn_weights (Tensor): shape(*, T_q, T_k), the attention weights.
    """
    d = q.shape[-1]  # we only support imperative execution
    qk = paddle.matmul(q, k, transpose_y=True)
    scaled_logit = paddle.scale(qk, 1.0 / math.sqrt(d))

    if mask is not None:
        scaled_logit += paddle.scale((1.0 - mask), -1e9)  # hard coded here

    attn_weights = F.softmax(scaled_logit, axis=-1)
    attn_weights = F.dropout(attn_weights, dropout, training=training)
    out = paddle.matmul(attn_weights, v)
    return out, attn_weights


def drop_head(x, drop_n_heads, training):
    """
    Drop n heads from multiple context vectors.
@@ -48,12 +69,12 @@ def drop_head(x, drop_n_heads, training):
    """
    if not training or (drop_n_heads == 0):
        return x

    batch_size, num_heads, _, _ = x.shape
    # drop all heads
    if num_heads == drop_n_heads:
        return paddle.zeros_like(x)

    mask = np.ones([batch_size, num_heads])
    mask[:, :drop_n_heads] = 0
    for subarray in mask:
@@ -63,18 +84,21 @@ def drop_head(x, drop_n_heads, training):
    out = x * paddle.to_tensor(mask)
    return out


def _split_heads(x, num_heads):
    batch_size, time_steps, _ = x.shape
    x = paddle.reshape(x, [batch_size, time_steps, num_heads, -1])
    x = paddle.transpose(x, [0, 2, 1, 3])
    return x


def _concat_heads(x):
    batch_size, _, time_steps, _ = x.shape
    x = paddle.transpose(x, [0, 2, 1, 3])
    x = paddle.reshape(x, [batch_size, time_steps, -1])
    return x


# Standard implementations of Monohead Attention & Multihead Attention
class MonoheadAttention(nn.Layer):
    def __init__(self, model_dim, dropout=0.0, k_dim=None, v_dim=None):
@@ -99,10 +123,10 @@ class MonoheadAttention(nn.Layer):
        self.affine_k = nn.Linear(model_dim, k_dim)
        self.affine_v = nn.Linear(model_dim, v_dim)
        self.affine_o = nn.Linear(v_dim, model_dim)

        self.model_dim = model_dim
        self.dropout = dropout

    def forward(self, q, k, v, mask):
        """
        Compute context vector and attention weights.
@@ -119,22 +143,28 @@ class MonoheadAttention(nn.Layer):
            out (Tensor), shape(batch_size, time_steps_q, model_dim), the context vector.
            attention_weights (Tensor): shape(batch_size, times_steps_q, time_steps_k), the attention weights.
        """
        q = self.affine_q(q)  # (B, T, C)
        k = self.affine_k(k)
        v = self.affine_v(v)

        context_vectors, attention_weights = scaled_dot_product_attention(
            q, k, v, mask, self.dropout, self.training)

        out = self.affine_o(context_vectors)
        return out, attention_weights


class MultiheadAttention(nn.Layer):
    """
    Multihead scaled dot product attention.
    """

    def __init__(self,
                 model_dim,
                 num_heads,
                 dropout=0.0,
                 k_dim=None,
                 v_dim=None):
        """
        Multihead Attention module.
@@ -154,7 +184,7 @@ class MultiheadAttention(nn.Layer):
            ValueError: if model_dim is not divisible by num_heads
        """
        super(MultiheadAttention, self).__init__()
        if model_dim % num_heads != 0:
            raise ValueError("model_dim must be divisible by num_heads")
        depth = model_dim // num_heads
        k_dim = k_dim or depth
@@ -163,11 +193,11 @@ class MultiheadAttention(nn.Layer):
        self.affine_k = nn.Linear(model_dim, num_heads * k_dim)
        self.affine_v = nn.Linear(model_dim, num_heads * v_dim)
        self.affine_o = nn.Linear(num_heads * v_dim, model_dim)

        self.num_heads = num_heads
        self.model_dim = model_dim
        self.dropout = dropout

    def forward(self, q, k, v, mask):
        """
        Compute context vector and attention weights.
@@ -184,14 +214,67 @@ class MultiheadAttention(nn.Layer):
            out (Tensor), shape(batch_size, time_steps_q, model_dim), the context vector.
            attention_weights (Tensor): shape(batch_size, times_steps_q, time_steps_k), the attention weights.
        """
        q = _split_heads(self.affine_q(q), self.num_heads)  # (B, h, T, C)
        k = _split_heads(self.affine_k(k), self.num_heads)
        v = _split_heads(self.affine_v(v), self.num_heads)
        mask = paddle.unsqueeze(mask, 1)  # unsqueeze for the h dim

        context_vectors, attention_weights = scaled_dot_product_attention(
            q, k, v, mask, self.dropout, self.training)
        # NOTE: there is more sophisticated implementation: Scheduled DropHead
        context_vectors = _concat_heads(context_vectors)  # (B, T, h*C)
        out = self.affine_o(context_vectors)
        return out, attention_weights


class LocationSensitiveAttention(nn.Layer):
    def __init__(self,
                 d_query: int,
                 d_key: int,
                 d_attention: int,
                 location_filters: int,
                 location_kernel_size: int):
        super().__init__()

        self.query_layer = nn.Linear(d_query, d_attention, bias_attr=False)
        self.key_layer = nn.Linear(d_key, d_attention, bias_attr=False)
        self.value = nn.Linear(d_attention, 1, bias_attr=False)

        # Location Layer
        self.location_conv = nn.Conv1D(
            2,
            location_filters,
            location_kernel_size,
            1,
            int((location_kernel_size - 1) / 2),
            1,
            bias_attr=False,
            data_format='NLC')
        self.location_layer = nn.Linear(
            location_filters, d_attention, bias_attr=False)

    def forward(self,
                query,
                processed_key,
                value,
                attention_weights_cat,
                mask=None):
        processed_query = self.query_layer(paddle.unsqueeze(query, axis=[1]))
        processed_attention_weights = self.location_layer(
            self.location_conv(attention_weights_cat))
        alignment = self.value(
            paddle.tanh(processed_attention_weights + processed_key +
                        processed_query))

        if mask is not None:
            alignment = alignment + (1.0 - mask) * -1e9

        attention_weights = F.softmax(alignment, axis=1)
        attention_context = paddle.matmul(
            attention_weights, value, transpose_x=True)

        attention_weights = paddle.squeeze(attention_weights, axis=[-1])
        attention_context = paddle.squeeze(attention_context, axis=[1])

        return attention_context, attention_weights
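A shape walk-through of the new location-sensitive attention (editorial sketch with made-up dimensions; in Tacotron 2-style usage the query is the current decoder state, the keys/values are encoder outputs, and attention_weights_cat stacks the previous and cumulative alignments along the last axis):

    import paddle

    attn = LocationSensitiveAttention(
        d_query=1024, d_key=512, d_attention=128,
        location_filters=32, location_kernel_size=31)

    B, T = 2, 40
    query = paddle.randn([B, 1024])         # current decoder state
    keys = paddle.randn([B, T, 512])        # encoder outputs
    processed_key = attn.key_layer(keys)    # can be precomputed once per utterance
    weights_cat = paddle.zeros([B, T, 2])   # previous + cumulative weights, NLC layout

    context, weights = attn(query, processed_key, keys, weights_cat)
    print(context.shape)   # [2, 512]
    print(weights.shape)   # [2, 40]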

View File

@@ -1,6 +1,21 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
from paddle import nn


class Conv1dCell(nn.Conv1D):
    """
    A subclass of Conv1d layer, which can be used like an RNN cell. It can take
@@ -14,43 +29,46 @@ class Conv1dCell(nn.Conv1D):
    As a result, these arguments are removed form the initializer.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 dilation=1,
                 weight_attr=None,
                 bias_attr=None):
        _dilation = dilation[0] if isinstance(dilation,
                                              (tuple, list)) else dilation
        _kernel_size = kernel_size[0] if isinstance(kernel_size, (
            tuple, list)) else kernel_size
        self._r = 1 + (_kernel_size - 1) * _dilation
        super(Conv1dCell, self).__init__(
            in_channels,
            out_channels,
            kernel_size,
            padding=(self._r - 1, 0),
            dilation=dilation,
            weight_attr=weight_attr,
            bias_attr=bias_attr,
            data_format="NCL")

    @property
    def receptive_field(self):
        return self._r

    def start_sequence(self):
        if self.training:
            raise Exception("only use start_sequence in evaluation")
        self._buffer = None
        self._reshaped_weight = paddle.reshape(self.weight,
                                               (self._out_channels, -1))

    def initialize_buffer(self, x_t):
        batch_size, _ = x_t.shape
        self._buffer = paddle.zeros(
            (batch_size, self._in_channels, self.receptive_field),
            dtype=x_t.dtype)

    def update_buffer(self, x_t):
        self._buffer = paddle.concat(
            [self._buffer[:, :, 1:], paddle.unsqueeze(x_t, -1)], -1)
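An editorial note on the causal geometry (hypothetical sizes, assuming the class above is in scope): the cell left-pads by receptive_field - 1 frames, and the step-wise buffer it keeps in evaluation mode holds exactly one receptive field of history per sample.

    import paddle

    cell = Conv1dCell(in_channels=80, out_channels=256, kernel_size=5, dilation=2)
    print(cell.receptive_field)            # 1 + (5 - 1) * 2 = 9; internal padding is (8, 0)

    cell.eval()
    cell.start_sequence()                  # step-by-step mode is evaluation-only
    cell.initialize_buffer(paddle.zeros([4, 80]))
    print(cell._buffer.shape)              # [4, 80, 9] -- one receptive field of history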
@@ -66,7 +84,7 @@ class Conv1dCell(nn.Conv1D):
        if self.receptive_field > 1:
            if self._buffer is None:
                self.initialize_buffer(x_t)

            # update buffer
            self.update_buffer(x_t)
            if self._dilation[0] > 1:
@@ -82,20 +100,34 @@ class Conv1dCell(nn.Conv1D):
class Conv1dBatchNorm(nn.Layer):
    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 weight_attr=None,
                 bias_attr=None,
                 data_format="NCL",
                 momentum=0.9,
                 epsilon=1e-05):
        super(Conv1dBatchNorm, self).__init__()
        self.conv = nn.Conv1D(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding=padding,
            weight_attr=weight_attr,
            bias_attr=bias_attr,
            data_format=data_format)
        self.bn = nn.BatchNorm1D(
            out_channels,
            momentum=momentum,
            epsilon=epsilon,
            data_format=data_format)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        return x
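To close, a smoke test of the rewritten Conv1dBatchNorm (editorial sketch, arbitrary sizes); note that momentum and epsilon are now constructor arguments defaulting to 0.9 and 1e-5, replacing the previously hard-coded 0.99 and 1e-3.

    import paddle

    conv_bn = Conv1dBatchNorm(in_channels=80, out_channels=256, kernel_size=5, padding=2)
    x = paddle.randn([4, 80, 100])   # (batch, channels, time) for the default "NCL" layout
    y = conv_bn(x)
    print(y.shape)                   # [4, 256, 100]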