fix a bug for changing reduction factor in transformer_tts

chenfeiyu 2020-11-03 11:18:46 +08:00
parent 1f71f65c28
commit 0cdad602e2
5 changed files with 28 additions and 16 deletions


@@ -74,10 +74,11 @@ class MultiheadAttention(nn.Layer):
         q = _split_heads(self.affine_q(q), self.num_heads) # (B, h, T, C)
         k = _split_heads(self.affine_k(k), self.num_heads)
         v = _split_heads(self.affine_v(v), self.num_heads)
-        mask = paddle.unsqueeze(mask, 1) # unsqueeze for the h dim
+        if mask is not None:
+            mask = paddle.unsqueeze(mask, 1) # unsqueeze for the h dim
         context_vectors, attention_weights = scaled_dot_product_attention(
-            q, k, v, mask)
+            q, k, v, mask, training=self.training)
         context_vectors = drop_head(context_vectors, drop_n_heads, self.training)
         context_vectors = _concat_heads(context_vectors) # (B, T, h*C)
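For context, a rough shape sketch of why the mask is unsqueezed at axis 1: the singleton axis lines up with the head dimension, so a single (B, T_q, T_k) mask broadcasts over all h heads of the (B, h, T_q, T_k) attention logits. The sizes below are made up for illustration, not values from the repo.

    import numpy as np

    B, h, T_q, T_k = 2, 4, 6, 6                    # hypothetical sizes
    logits = np.zeros((B, h, T_q, T_k))            # per-head attention logits
    mask = np.ones((B, T_q, T_k))                  # 1.0 = attend, 0.0 = ignore

    mask = np.expand_dims(mask, 1)                 # (B, 1, T_q, T_k): the "h dim"
    masked_logits = logits + (1.0 - mask) * -1e9   # broadcasts across all heads
    print(masked_logits.shape)                     # (2, 4, 6, 6)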
@@ -110,9 +111,11 @@ class TransformerEncoderLayer(nn.Layer):
     def _forward_mha(self, x, mask, drop_n_heads):
         # PreLN scheme: Norm -> SubLayer -> Dropout -> Residual
+        if mask is not None:
+            mask = paddle.unsqueeze(mask, 1)
         x_in = x
         x = self.layer_norm1(x)
-        context_vector, attn_weights = self.self_mha(x, x, x, paddle.unsqueeze(mask, 1), drop_n_heads)
+        context_vector, attn_weights = self.self_mha(x, x, x, mask, drop_n_heads)
         context_vector = x_in + F.dropout(context_vector, self.dropout, training=self.training)
         return context_vector, attn_weights
@@ -292,14 +295,14 @@ class MLPPreNet(nn.Layer):
         super(MLPPreNet, self).__init__()
         self.lin1 = nn.Linear(d_input, d_hidden)
         self.lin2 = nn.Linear(d_hidden, d_hidden)
-        self.lin3 = nn.Linear(d_output, d_output)
+        # self.lin3 = nn.Linear(d_output, d_output)

     def forward(self, x, dropout):
         # the original code said also use dropout in inference
         l1 = F.dropout(F.relu(self.lin1(x)), dropout, training=self.training)
         l2 = F.dropout(F.relu(self.lin2(l1)), dropout, training=self.training)
-        l3 = self.lin3(l2)
-        return l3
+        #l3 = self.lin3(l2)
+        return l2

 class CNNPreNet(nn.Layer):
@@ -328,7 +331,6 @@ class CNNPreNet(nn.Layer):
 class CNNPostNet(nn.Layer):
     def __init__(self, d_input, d_hidden, d_output, kernel_size, n_layers):
         super(CNNPostNet, self).__init__()
-        self.first_norm = nn.BatchNorm1D(d_output)
         self.convs = nn.LayerList()
         kernel_size = kernel_size if isinstance(kernel_size, (tuple, list)) else (kernel_size, )
         padding = (kernel_size[0] - 1, 0)
@@ -347,7 +349,6 @@ class CNNPostNet(nn.Layer):
     def forward(self, x):
         # why not use pre norms
         x_in = x
-        x = self.first_norm(x)
         for i, layer in enumerate(self.convs):
             x = layer(x)
             if i != (len(self.convs) - 1):
@@ -362,8 +363,8 @@ class TransformerTTS(nn.Layer):
                  postnet_kernel_size, max_reduction_factor, dropout):
         super(TransformerTTS, self).__init__()
         # encoder
-        self.embedding = nn.Embedding(vocab_size, d_encoder, padding_idx, weight_attr=I.Uniform(-0.05, 0.05))
-        self.encoder_prenet = CNNPreNet(d_encoder, d_encoder, d_encoder, 5, 3, dropout)
+        self.encoder_prenet = nn.Embedding(vocab_size, d_encoder, padding_idx, weight_attr=I.Uniform(-0.05, 0.05))
+        # self.encoder_prenet = CNNPreNet(d_encoder, d_encoder, d_encoder, 5, 3, dropout)
         self.encoder_pe = pe.positional_encoding(0, 1000, d_encoder) # it may be extended later
         self.encoder_pe_scalar = self.create_parameter([1], attr=I.Constant(1.))
         self.encoder = TransformerEncoder(d_encoder, n_heads, d_ffn, encoder_layers, dropout)
@@ -381,6 +382,7 @@ class TransformerTTS(nn.Layer):
         self.padding_idx = padding_idx
         self.d_encoder = d_encoder
         self.d_decoder = d_decoder
+        self.d_mel = d_mel
         self.max_r = max_reduction_factor
         self.dropout = dropout
@@ -403,7 +405,7 @@ class TransformerTTS(nn.Layer):
     def encode(self, text):
         T_enc = text.shape[-1]
-        embed = self.encoder_prenet(self.embedding(text))
+        embed = self.encoder_prenet(text)
         if embed.shape[1] > self.encoder_pe.shape[0]:
             new_T = max(embed.shape[1], self.encoder_pe.shape[0] * 2)
             self.encoder_pe = pe.positional_encoding(0, new_T, self.d_encoder)
@@ -439,7 +441,8 @@ class TransformerTTS(nn.Layer):
             decoder_mask,
             self.drop_n_heads)
-        output_proj = self.final_proj(decoder_output)
+        # use only parts of it
+        output_proj = self.final_proj(decoder_output)[:, :, : self.r * mel_dim]
         mel_intermediate = paddle.reshape(output_proj, [batch_size, -1, mel_dim])
         stop_logits = self.stop_conditioner(mel_intermediate)
@@ -447,6 +450,7 @@ class TransformerTTS(nn.Layer):
         mel_channel_first = paddle.transpose(mel_intermediate, [0, 2, 1])
         mel_output = self.decoder_postnet(mel_channel_first)
         mel_output = paddle.transpose(mel_output, [0, 2, 1])

         return mel_output, mel_intermediate, cross_attention_weights, stop_logits

     def predict(self, input, max_length=1000, verbose=True):
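The slice of the final projection above appears to be the core of the reduction-factor fix: final_proj always emits max_r * d_mel channels per decoder step, so when the model runs with a smaller reduction factor r, only the first r * d_mel channels should be kept before reshaping into mel frames. A minimal sketch of that slice-and-reshape, using made-up sizes (B, T_dec, d_mel, max_r, r are hypothetical, not values from the repo):

    import numpy as np

    B, T_dec, d_mel, max_r, r = 2, 5, 80, 4, 2   # hypothetical sizes

    # final_proj always produces max_r * d_mel channels per decoder step ...
    output_proj = np.random.randn(B, T_dec, max_r * d_mel)

    # ... but only the first r * d_mel channels are meaningful for the current r
    output_proj = output_proj[:, :, :r * d_mel]

    # each decoder step then unfolds into r mel frames
    mel_intermediate = output_proj.reshape(B, T_dec * r, d_mel)
    print(mel_intermediate.shape)  # (2, 10, 80)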


@@ -27,7 +27,7 @@ def scaled_dot_product_attention(q, k, v, mask=None, dropout=0.0, training=True)
     scaled_logit = paddle.scale(qk, 1.0 / math.sqrt(d))
     if mask is not None:
-        scaled_logit += paddle.scale((1.0 - mask), -1e12) # hard coded here
+        scaled_logit += paddle.scale((1.0 - mask), -1e9) # hard coded here
     attn_weights = F.softmax(scaled_logit, axis=-1)
     attn_weights = F.dropout(attn_weights, dropout, training=training)
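For reference, a small plain-numpy sketch (made-up values) of the additive-mask trick this line implements: positions with mask 0 receive a very large negative logit, so softmax assigns them essentially zero weight. Any constant that is large relative to the logits works, which is why -1e9 here behaves the same as the earlier -1e12.

    import numpy as np

    def softmax(x, axis=-1):
        e = np.exp(x - x.max(axis=axis, keepdims=True))
        return e / e.sum(axis=axis, keepdims=True)

    logits = np.array([[0.5, 1.0, 0.2]])   # hypothetical attention logits
    mask = np.array([[1.0, 1.0, 0.0]])     # last key position is padding

    masked = logits + (1.0 - mask) * -1e9  # push masked logits far down
    print(softmax(masked))                 # third weight is ~0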


@@ -92,8 +92,10 @@ class Conv1dBatchNorm(nn.Layer):
             bias_attr=bias_attr,
             data_format=data_format)
         # TODO: channel last, but BatchNorm1d does not support channel last layout
-        self.bn = nn.BatchNorm1D(out_channels, data_format=data_format)
+        self.bn = nn.BatchNorm1D(out_channels, momentum=0.99, epsilon=1e-3, data_format=data_format)

     def forward(self, x):
-        return self.bn(self.conv(x))
+        x = self.conv(x)
+        x = self.bn(x)
+        return x


@@ -39,7 +39,9 @@ class PositionwiseFFN(nn.Layer):
         Returns:
             Tensor: shape(*, input_size), the output tensor.
         """
-        return self.linear2(self.dropout(F.relu(self.linear1(x))))
+        l1 = self.dropout(F.relu(self.linear1(x)))
+        l2 = self.linear2(l1)
+        return l2

 class TransformerEncoderLayer(nn.Layer):


@@ -3,6 +3,7 @@ import matplotlib
 from matplotlib import cm, pyplot

 def pack_attention_images(attention_weights, rotate=False):
+    # add a box
     attention_weights = np.pad(attention_weights,
                                [(0, 0), (1, 1), (1, 1)],
                                mode="constant",
@@ -25,3 +26,6 @@ def pack_attention_images(attention_weights, rotate=False):
     img = np.block([[total[i, j] for j in range(cols)] for i in range(rows)])
     return img
+
+def min_max_normalize(v):
+    return (v - v.min()) / (v.max() - v.min())