fix a bug for changing reduction factor in transformer_tts

parent 1f71f65c28
commit 0cdad602e2

@@ -74,10 +74,11 @@ class MultiheadAttention(nn.Layer):
         q = _split_heads(self.affine_q(q), self.num_heads)  # (B, h, T, C)
         k = _split_heads(self.affine_k(k), self.num_heads)
         v = _split_heads(self.affine_v(v), self.num_heads)
-        mask = paddle.unsqueeze(mask, 1)  # unsqueeze for the h dim
+        if mask is not None:
+            mask = paddle.unsqueeze(mask, 1)  # unsqueeze for the h dim
 
         context_vectors, attention_weights = scaled_dot_product_attention(
-            q, k, v, mask)
+            q, k, v, mask, training=self.training)
         context_vectors = drop_head(context_vectors, drop_n_heads, self.training)
         context_vectors = _concat_heads(context_vectors)  # (B, T, h*C)
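
For context on the guard above: the per-head attention logits have shape (B, h, T_q, T_k), while the padding mask is built per example, so it needs a singleton head axis to broadcast. A minimal sketch of that broadcasting, with made-up shapes (not taken from the repo):

import paddle

B, h, T_q, T_k = 2, 4, 5, 7                        # illustrative sizes
logits = paddle.randn([B, h, T_q, T_k])            # per-head attention logits
mask = paddle.ones([B, T_q, T_k])                  # 1 = attend, 0 = ignore
mask = paddle.unsqueeze(mask, 1)                   # (B, 1, T_q, T_k), broadcasts over h
masked_logits = logits + paddle.scale(1.0 - mask, -1e9)
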
@@ -110,9 +111,11 @@ class TransformerEncoderLayer(nn.Layer):
 
     def _forward_mha(self, x, mask, drop_n_heads):
         # PreLN scheme: Norm -> SubLayer -> Dropout -> Residual
+        if mask is not None:
+            mask = paddle.unsqueeze(mask, 1)
         x_in = x
         x = self.layer_norm1(x)
-        context_vector, attn_weights = self.self_mha(x, x, x, paddle.unsqueeze(mask, 1), drop_n_heads)
+        context_vector, attn_weights = self.self_mha(x, x, x, mask, drop_n_heads)
         context_vector = x_in + F.dropout(context_vector, self.dropout, training=self.training)
         return context_vector, attn_weights
@@ -292,14 +295,14 @@ class MLPPreNet(nn.Layer):
         super(MLPPreNet, self).__init__()
         self.lin1 = nn.Linear(d_input, d_hidden)
         self.lin2 = nn.Linear(d_hidden, d_hidden)
-        self.lin3 = nn.Linear(d_output, d_output)
+        # self.lin3 = nn.Linear(d_output, d_output)
 
     def forward(self, x, dropout):
         # the original code said also use dropout in inference
         l1 = F.dropout(F.relu(self.lin1(x)), dropout, training=self.training)
         l2 = F.dropout(F.relu(self.lin2(l1)), dropout, training=self.training)
-        l3 = self.lin3(l2)
-        return l3
+        #l3 = self.lin3(l2)
+        return l2
 
 
 class CNNPreNet(nn.Layer):
@@ -328,7 +331,6 @@ class CNNPreNet(nn.Layer):
 class CNNPostNet(nn.Layer):
     def __init__(self, d_input, d_hidden, d_output, kernel_size, n_layers):
         super(CNNPostNet, self).__init__()
-        self.first_norm = nn.BatchNorm1D(d_output)
         self.convs = nn.LayerList()
         kernel_size = kernel_size if isinstance(kernel_size, (tuple, list)) else (kernel_size, )
         padding = (kernel_size[0] - 1, 0)
@@ -347,7 +349,6 @@ class CNNPostNet(nn.Layer):
     def forward(self, x):
         # why not use pre norms
         x_in = x
-        x = self.first_norm(x)
         for i, layer in enumerate(self.convs):
             x = layer(x)
             if i != (len(self.convs) - 1):
@@ -362,8 +363,8 @@ class TransformerTTS(nn.Layer):
                  postnet_kernel_size, max_reduction_factor, dropout):
         super(TransformerTTS, self).__init__()
         # encoder
-        self.embedding = nn.Embedding(vocab_size, d_encoder, padding_idx, weight_attr=I.Uniform(-0.05, 0.05))
-        self.encoder_prenet = CNNPreNet(d_encoder, d_encoder, d_encoder, 5, 3, dropout)
+        self.encoder_prenet = nn.Embedding(vocab_size, d_encoder, padding_idx, weight_attr=I.Uniform(-0.05, 0.05))
+        # self.encoder_prenet = CNNPreNet(d_encoder, d_encoder, d_encoder, 5, 3, dropout)
         self.encoder_pe = pe.positional_encoding(0, 1000, d_encoder)  # it may be extended later
         self.encoder_pe_scalar = self.create_parameter([1], attr=I.Constant(1.))
         self.encoder = TransformerEncoder(d_encoder, n_heads, d_ffn, encoder_layers, dropout)
@@ -381,6 +382,7 @@ class TransformerTTS(nn.Layer):
         self.padding_idx = padding_idx
         self.d_encoder = d_encoder
         self.d_decoder = d_decoder
+        self.d_mel = d_mel
         self.max_r = max_reduction_factor
         self.dropout = dropout
 
@@ -403,7 +405,7 @@ class TransformerTTS(nn.Layer):
 
     def encode(self, text):
         T_enc = text.shape[-1]
-        embed = self.encoder_prenet(self.embedding(text))
+        embed = self.encoder_prenet(text)
         if embed.shape[1] > self.encoder_pe.shape[0]:
             new_T = max(embed.shape[1], self.encoder_pe.shape[0] * 2)
             self.encoder_pe = pe.positional_encoding(0, new_T, self.d_encoder)
@@ -439,7 +441,8 @@ class TransformerTTS(nn.Layer):
             decoder_mask,
             self.drop_n_heads)
 
-        output_proj = self.final_proj(decoder_output)
+        # use only parts of it
+        output_proj = self.final_proj(decoder_output)[:, :, : self.r * mel_dim]
         mel_intermediate = paddle.reshape(output_proj, [batch_size, -1, mel_dim])
         stop_logits = self.stop_conditioner(mel_intermediate)
 
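
The slicing introduced here is the heart of the reduction-factor fix: the final projection presumably emits enough channels for max_r frames per decoder step, and only the first self.r * mel_dim of them are kept before reshaping into individual mel frames. A minimal shape sketch with made-up sizes (batch_size, T_dec, d_mel, max_r, r below are illustrative, not the model's actual values):

import paddle

batch_size, T_dec, d_mel, max_r, r = 2, 10, 80, 5, 3
proj = paddle.randn([batch_size, T_dec, max_r * d_mel])   # stand-in for the final_proj output
sliced = proj[:, :, : r * d_mel]                          # keep only r frames' worth of channels
mel = paddle.reshape(sliced, [batch_size, -1, d_mel])     # each decoder step now yields r mel frames
assert mel.shape == [batch_size, T_dec * r, d_mel]        # (2, 30, 80)
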
@@ -447,6 +450,7 @@ class TransformerTTS(nn.Layer):
         mel_channel_first = paddle.transpose(mel_intermediate, [0, 2, 1])
         mel_output = self.decoder_postnet(mel_channel_first)
         mel_output = paddle.transpose(mel_output, [0, 2, 1])
+
         return mel_output, mel_intermediate, cross_attention_weights, stop_logits
 
     def predict(self, input, max_length=1000, verbose=True):
@@ -27,7 +27,7 @@ def scaled_dot_product_attention(q, k, v, mask=None, dropout=0.0, training=True)
     scaled_logit = paddle.scale(qk, 1.0 / math.sqrt(d))
 
     if mask is not None:
-        scaled_logit += paddle.scale((1.0 - mask), -1e12)  # hard coded here
+        scaled_logit += paddle.scale((1.0 - mask), -1e9)  # hard coded here
 
     attn_weights = F.softmax(scaled_logit, axis=-1)
     attn_weights = F.dropout(attn_weights, dropout, training=training)
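
For reference, the additive mask drives disallowed positions toward zero probability after the softmax; -1e9 is comfortably large enough for that and (as an assumption about the motivation) is less likely to overflow than -1e12 if the logits are ever handled in reduced precision. A tiny numeric check:

import paddle
import paddle.nn.functional as F

logits = paddle.to_tensor([[0.5, 1.0, 0.2]])
mask = paddle.to_tensor([[1.0, 1.0, 0.0]])        # last position is masked out
masked = logits + paddle.scale(1.0 - mask, -1e9)
print(F.softmax(masked, axis=-1))                 # third weight is effectively 0
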
@@ -92,8 +92,10 @@ class Conv1dBatchNorm(nn.Layer):
                              bias_attr=bias_attr,
                              data_format=data_format)
         # TODO: channel last, but BatchNorm1d does not support channel last layout
-        self.bn = nn.BatchNorm1D(out_channels, data_format=data_format)
+        self.bn = nn.BatchNorm1D(out_channels, momentum=0.99, epsilon=1e-3, data_format=data_format)
 
     def forward(self, x):
-        return self.bn(self.conv(x))
+        x = self.conv(x)
+        x = self.bn(x)
+        return x
 
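
A note on the new momentum=0.99, epsilon=1e-3 arguments: in Paddle's BatchNorm the running statistics are updated as running = momentum * running + (1 - momentum) * batch_stat, so a larger momentum makes them move more slowly (these values also match the defaults of some TensorFlow-based reference implementations, stated here as an assumption). A plain-Python sketch of that update rule:

running_mean, momentum = 0.0, 0.99
for batch_mean in (1.0, 1.2, 0.9):
    running_mean = momentum * running_mean + (1.0 - momentum) * batch_mean
print(running_mean)   # drifts toward the batch statistics only slowly
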
@@ -39,7 +39,9 @@ class PositionwiseFFN(nn.Layer):
         Returns:
             Tensor: shape(*, input_size), the output tensor.
         """
-        return self.linear2(self.dropout(F.relu(self.linear1(x))))
+        l1 = self.dropout(F.relu(self.linear1(x)))
+        l2 = self.linear2(l1)
+        return l2
 
 
 class TransformerEncoderLayer(nn.Layer):
@@ -3,6 +3,7 @@ import matplotlib
 from matplotlib import cm, pyplot
 
 def pack_attention_images(attention_weights, rotate=False):
+    # add a box
     attention_weights = np.pad(attention_weights,
                                [(0, 0), (1, 1), (1, 1)],
                                mode="constant",
@@ -25,3 +26,6 @@ def pack_attention_images(attention_weights, rotate=False):
     img = np.block([[total[i, j] for j in range(cols)] for i in range(rows)])
     return img
 
+
+def min_max_normalize(v):
+    return (v - v.min()) / (v.max() - v.min())
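
A possible use of the new helper (purely illustrative; the shapes and the call site are assumptions, not taken from this commit): scaling attention maps to [0, 1] before packing them into an image grid for plotting.

import numpy as np

weights = np.random.rand(4, 50, 60)                       # hypothetical (heads, T_dec, T_enc) maps
img = pack_attention_images(min_max_normalize(weights))   # normalize, then tile into one image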