diff --git a/parakeet/models/transformer_tts.py b/parakeet/models/transformer_tts.py
index a0029ed..3c65c8d 100644
--- a/parakeet/models/transformer_tts.py
+++ b/parakeet/models/transformer_tts.py
@@ -74,10 +74,11 @@ class MultiheadAttention(nn.Layer):
         q = _split_heads(self.affine_q(q), self.num_heads)  # (B, h, T, C)
         k = _split_heads(self.affine_k(k), self.num_heads)
         v = _split_heads(self.affine_v(v), self.num_heads)
-        mask = paddle.unsqueeze(mask, 1)  # unsqueeze for the h dim
+        if mask is not None:
+            mask = paddle.unsqueeze(mask, 1)  # unsqueeze for the h dim
 
         context_vectors, attention_weights = scaled_dot_product_attention(
-            q, k, v, mask)
+            q, k, v, mask, training=self.training)
         context_vectors = drop_head(context_vectors, drop_n_heads, self.training)
         context_vectors = _concat_heads(context_vectors)  # (B, T, h*C)
 
@@ -110,9 +111,11 @@ class TransformerEncoderLayer(nn.Layer):
 
     def _forward_mha(self, x, mask, drop_n_heads):
         # PreLN scheme: Norm -> SubLayer -> Dropout -> Residual
+        if mask is not None:
+            mask = paddle.unsqueeze(mask, 1)
         x_in = x
         x = self.layer_norm1(x)
-        context_vector, attn_weights = self.self_mha(x, x, x, paddle.unsqueeze(mask, 1), drop_n_heads)
+        context_vector, attn_weights = self.self_mha(x, x, x, mask, drop_n_heads)
         context_vector = x_in + F.dropout(context_vector, self.dropout, training=self.training)
         return context_vector, attn_weights
 
@@ -292,14 +295,14 @@ class MLPPreNet(nn.Layer):
         super(MLPPreNet, self).__init__()
         self.lin1 = nn.Linear(d_input, d_hidden)
         self.lin2 = nn.Linear(d_hidden, d_hidden)
-        self.lin3 = nn.Linear(d_output, d_output)
+        # self.lin3 = nn.Linear(d_output, d_output)
 
     def forward(self, x, dropout):
         # the original code said also use dropout in inference
         l1 = F.dropout(F.relu(self.lin1(x)), dropout, training=self.training)
         l2 = F.dropout(F.relu(self.lin2(l1)), dropout, training=self.training)
-        l3 = self.lin3(l2)
-        return l3
+        #l3 = self.lin3(l2)
+        return l2
 
 
 class CNNPreNet(nn.Layer):
@@ -328,7 +331,6 @@ class CNNPreNet(nn.Layer):
 class CNNPostNet(nn.Layer):
     def __init__(self, d_input, d_hidden, d_output, kernel_size, n_layers):
         super(CNNPostNet, self).__init__()
-        self.first_norm = nn.BatchNorm1D(d_output)
         self.convs = nn.LayerList()
         kernel_size = kernel_size if isinstance(kernel_size, (tuple, list)) else (kernel_size, )
         padding = (kernel_size[0] - 1, 0)
@@ -347,7 +349,6 @@ class CNNPostNet(nn.Layer):
     def forward(self, x):
         # why not use pre norms
         x_in = x
-        x = self.first_norm(x)
         for i, layer in enumerate(self.convs):
             x = layer(x)
             if i != (len(self.convs) - 1):
@@ -362,8 +363,8 @@ class TransformerTTS(nn.Layer):
                 postnet_kernel_size, max_reduction_factor, dropout):
         super(TransformerTTS, self).__init__()
         # encoder
-        self.embedding = nn.Embedding(vocab_size, d_encoder, padding_idx, weight_attr=I.Uniform(-0.05, 0.05))
-        self.encoder_prenet = CNNPreNet(d_encoder, d_encoder, d_encoder, 5, 3, dropout)
+        self.encoder_prenet = nn.Embedding(vocab_size, d_encoder, padding_idx, weight_attr=I.Uniform(-0.05, 0.05))
+        # self.encoder_prenet = CNNPreNet(d_encoder, d_encoder, d_encoder, 5, 3, dropout)
         self.encoder_pe = pe.positional_encoding(0, 1000, d_encoder)  # it may be extended later
         self.encoder_pe_scalar = self.create_parameter([1], attr=I.Constant(1.))
         self.encoder = TransformerEncoder(d_encoder, n_heads, d_ffn, encoder_layers, dropout)
@@ -381,6 +382,7 @@ class TransformerTTS(nn.Layer):
         self.padding_idx = padding_idx
         self.d_encoder = d_encoder
         self.d_decoder = d_decoder
+        self.d_mel = d_mel
         self.max_r = max_reduction_factor
         self.dropout = dropout
 
@@ -403,7 +405,7 @@ class TransformerTTS(nn.Layer):
 
     def encode(self, text):
         T_enc = text.shape[-1]
-        embed = self.encoder_prenet(self.embedding(text))
+        embed = self.encoder_prenet(text)
         if embed.shape[1] > self.encoder_pe.shape[0]:
             new_T = max(embed.shape[1], self.encoder_pe.shape[0] * 2)
             self.encoder_pe = pe.positional_encoding(0, new_T, self.d_encoder)
@@ -439,7 +441,8 @@ class TransformerTTS(nn.Layer):
             decoder_mask,
             self.drop_n_heads)
 
-        output_proj = self.final_proj(decoder_output)
+        # use only parts of it
+        output_proj = self.final_proj(decoder_output)[:, :, : self.r * mel_dim]
         mel_intermediate = paddle.reshape(output_proj, [batch_size, -1, mel_dim])
         stop_logits = self.stop_conditioner(mel_intermediate)
 
@@ -447,6 +450,7 @@ class TransformerTTS(nn.Layer):
         mel_channel_first = paddle.transpose(mel_intermediate, [0, 2, 1])
         mel_output = self.decoder_postnet(mel_channel_first)
         mel_output = paddle.transpose(mel_output, [0, 2, 1])
+
        return mel_output, mel_intermediate, cross_attention_weights, stop_logits
 
     def predict(self, input, max_length=1000, verbose=True):
diff --git a/parakeet/modules/attention.py b/parakeet/modules/attention.py
index dd630fe..d7053b4 100644
--- a/parakeet/modules/attention.py
+++ b/parakeet/modules/attention.py
@@ -27,7 +27,7 @@ def scaled_dot_product_attention(q, k, v, mask=None, dropout=0.0, training=True)
     scaled_logit = paddle.scale(qk, 1.0 / math.sqrt(d))
 
     if mask is not None:
-        scaled_logit += paddle.scale((1.0 - mask), -1e12)  # hard coded here
+        scaled_logit += paddle.scale((1.0 - mask), -1e9)  # hard coded here
 
     attn_weights = F.softmax(scaled_logit, axis=-1)
     attn_weights = F.dropout(attn_weights, dropout, training=training)
diff --git a/parakeet/modules/conv.py b/parakeet/modules/conv.py
index c8f854c..698cda2 100644
--- a/parakeet/modules/conv.py
+++ b/parakeet/modules/conv.py
@@ -92,8 +92,10 @@ class Conv1dBatchNorm(nn.Layer):
                               bias_attr=bias_attr,
                               data_format=data_format)
         # TODO: channel last, but BatchNorm1d does not support channel last layout
-        self.bn = nn.BatchNorm1D(out_channels, data_format=data_format)
+        self.bn = nn.BatchNorm1D(out_channels, momentum=0.99, epsilon=1e-3, data_format=data_format)
 
     def forward(self, x):
-        return self.bn(self.conv(x))
+        x = self.conv(x)
+        x = self.bn(x)
+        return x
 
diff --git a/parakeet/modules/transformer.py b/parakeet/modules/transformer.py
index d06b1d5..f262923 100644
--- a/parakeet/modules/transformer.py
+++ b/parakeet/modules/transformer.py
@@ -39,7 +39,9 @@ class PositionwiseFFN(nn.Layer):
         Returns:
             Tensor: shape(*, input_size), the output tensor.
         """
-        return self.linear2(self.dropout(F.relu(self.linear1(x))))
+        l1 = self.dropout(F.relu(self.linear1(x)))
+        l2 = self.linear2(l1)
+        return l2
 
 
 class TransformerEncoderLayer(nn.Layer):
diff --git a/parakeet/utils/display.py b/parakeet/utils/display.py
index e32aaa7..314578b 100644
--- a/parakeet/utils/display.py
+++ b/parakeet/utils/display.py
@@ -3,6 +3,7 @@ import matplotlib
 from matplotlib import cm, pyplot
 
 def pack_attention_images(attention_weights, rotate=False):
+    # add a box
     attention_weights = np.pad(attention_weights,
                                [(0, 0), (1, 1), (1, 1)],
                                mode="constant",
@@ -25,3 +26,6 @@
 
     img = np.block([[total[i, j] for j in range(cols)] for i in range(rows)])
     return img
+
+def min_max_normalize(v):
+    return (v - v.min()) / (v.max() - v.min())
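
Note (not part of the patch): a minimal NumPy sketch of the masking convention the patched scaled_dot_product_attention assumes, where the mask is 1.0 for positions to attend to and 0.0 for padded positions, and masked logits receive a large negative additive offset (here -1e9, as in the hunk above) before the softmax. The function name and arrays below are illustrative only, not repo code.

    import numpy as np

    def masked_softmax(logits, mask=None, neg=-1e9):
        # additive masking: push masked logits toward -inf, then softmax
        if mask is not None:
            logits = logits + (1.0 - mask) * neg
        exp = np.exp(logits - logits.max(axis=-1, keepdims=True))
        return exp / exp.sum(axis=-1, keepdims=True)

    # toy check: the last key position is padding and gets ~0 attention weight
    print(masked_softmax(np.array([[2.0, 1.0, 0.5]]), np.array([[1.0, 1.0, 0.0]])))

Either -1e12 or -1e9 acts as an effective minus infinity after the softmax; the diff only switches to the more conventional constant.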