diff --git a/parakeet/frontend/phonectic.py b/parakeet/frontend/phonectic.py
index 97748fe..cda0fc7 100644
--- a/parakeet/frontend/phonectic.py
+++ b/parakeet/frontend/phonectic.py
@@ -44,6 +44,7 @@ class English(Phonetics):
     def __call__(self, sentence):
         return self.numericalize(self.phoneticize(sentence))
 
+    @property
     def vocab_size(self):
         return len(self.vocab)
 
@@ -88,6 +89,7 @@ class Chinese(Phonetics):
     def __call__(self, sentence):
         return self.numericalize(self.phoneticize(sentence))
 
+    @property
     def vocab_size(self):
         return len(self.vocab)
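The `@property` additions above are load-bearing for the model changes below, where `vocab_size` is read as an attribute (`frontend.vocab_size`) rather than called. A minimal sketch of the difference, assuming the `English` frontend can be constructed with its defaults:

```python
from paddle import nn
from parakeet.frontend.phonectic import English

frontend = English()  # assumption: default construction builds a vocab

# After this change, vocab_size is a plain attribute read:
emb = nn.Embedding(frontend.vocab_size, 512)  # an int, as nn.Embedding expects

# Before it, frontend.vocab_size evaluated to a bound method, so the call
# above would have failed unless written as frontend.vocab_size().
```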
diff --git a/parakeet/models/transformer_tts.py b/parakeet/models/transformer_tts.py
index 3c65c8d..6f1b62f 100644
--- a/parakeet/models/transformer_tts.py
+++ b/parakeet/models/transformer_tts.py
@@ -1,9 +1,11 @@
 import math
+from tqdm import trange
 
 import paddle
 from paddle import nn
 from paddle.nn import functional as F
 from paddle.nn import initializer as I
 
+import parakeet
 from parakeet.modules.attention import _split_heads, _concat_heads, drop_head, scaled_dot_product_attention
 from parakeet.modules.transformer import PositionwiseFFN
 from parakeet.modules import masking
@@ -111,8 +113,6 @@ class TransformerEncoderLayer(nn.Layer):
     def _forward_mha(self, x, mask, drop_n_heads):
         # PreLN scheme: Norm -> SubLayer -> Dropout -> Residual
-        if mask is not None:
-            mask = paddle.unsqueeze(mask, 1)
         x_in = x
         x = self.layer_norm1(x)
         context_vector, attn_weights = self.self_mha(x, x, x, mask, drop_n_heads)
@@ -131,21 +131,13 @@
         """
         Args:
             x (Tensor): shape(batch_size, time_steps, d_model), the decoder input.
-            mask (Tensor): shape(batch_size, time_steps), the padding mask.
+            mask (Tensor): shape(batch_size, 1, time_steps), the padding mask.
 
         Returns:
             (x, attn_weights)
             x (Tensor): shape(batch_size, time_steps, d_model), the decoded.
             attn_weights (Tensor), shape(batch_size, n_heads, time_steps, time_steps), self attention.
         """
-        # # pre norm
-        # x_in = x
-        # x = self.layer_norm1(x)
-        # context_vector, attn_weights = self.self_mha(x, x, x, paddle.unsqueeze(mask, 1), drop_n_heads)
-        # x = x_in + context_vector # here, the order can be tuned
-
-        # # pre norm
-        # x = x + self.ffn(self.layer_norm2(x))
         x, attn_weights = self._forward_mha(x, mask, drop_n_heads)
         x = self._forward_ffn(x)
         return x, attn_weights
@@ -188,7 +180,7 @@ class TransformerDecoderLayer(nn.Layer):
         # PreLN scheme: Norm -> SubLayer -> Dropout -> Residual
         q_in = q
         q = self.layer_norm2(q)
-        context_vector, attn_weights = self.cross_mha(q, k, v, paddle.unsqueeze(mask, 1), drop_n_heads)
+        context_vector, attn_weights = self.cross_mha(q, k, v, mask, drop_n_heads)
         context_vector = q_in + F.dropout(context_vector, self.dropout, training=self.training)
         return context_vector, attn_weights
@@ -206,7 +198,7 @@
             q (Tensor): shape(batch_size, time_steps_q, d_model), the decoder input.
             k (Tensor): shape(batch_size, time_steps_k, d_model), keys.
             v (Tensor): shape(batch_size, time_steps_k, d_model), values
-            encoder_mask (Tensor): shape(batch_size, time_steps_k) encoder padding mask.
+            encoder_mask (Tensor): shape(batch_size, 1, time_steps_k) encoder padding mask.
             decoder_mask (Tensor): shape(batch_size, time_steps_q, time_steps_q) or broadcastable shape, decoder padding mask.
 
         Returns:
@@ -214,21 +206,7 @@
             q (Tensor): shape(batch_size, time_steps_q, d_model), the decoded.
             self_attn_weights (Tensor), shape(batch_size, n_heads, time_steps_q, time_steps_q), decoder self attention.
             cross_attn_weights (Tensor), shape(batch_size, n_heads, time_steps_q, time_steps_k), decoder-encoder cross attention.
-        """
-        # # pre norm
-        # q_in = q
-        # q = self.layer_norm1(q)
-        # context_vector, self_attn_weights = self.self_mha(q, q, q, decoder_mask, drop_n_heads)
-        # q = q_in + context_vector
-
-        # # pre norm
-        # q_in = q
-        # q = self.layer_norm2(q)
-        # context_vector, cross_attn_weights = self.cross_mha(q, k, v, paddle.unsqueeze(encoder_mask, 1), drop_n_heads)
-        # q = q_in + context_vector
-
-        # # pre norm
-        # q = q + self.ffn(self.layer_norm3(q))
+        """
         q, self_attn_weights = self._forward_self_mha(q, decoder_mask, drop_n_heads)
         q, cross_attn_weights = self._forward_cross_mha(q, k, v, encoder_mask, drop_n_heads)
         q = self._forward_ffn(q)
@@ -245,7 +223,7 @@ class TransformerEncoder(nn.LayerList):
         """
         Args:
             x (Tensor): shape(batch_size, time_steps, feature_size), the input tensor.
-            mask (Tensor): shape(batch_size, time_steps), the mask.
+            mask (Tensor): shape(batch_size, 1, time_steps), the mask.
             drop_n_heads (int, optional): how many heads to drop. Defaults to 0.
 
         Returns:
@@ -273,7 +251,7 @@ class TransformerDecoder(nn.LayerList):
             q (Tensor): shape(batch_size, time_steps_q, d_model)
             k (Tensor): shape(batch_size, time_steps_k, d_encoder)
             v (Tensor): shape(batch_size, time_steps_k, d_encoder)
-            encoder_mask (Tensor): shape(batch_size, time_steps_k)
+            encoder_mask (Tensor): shape(batch_size, 1, time_steps_k)
             decoder_mask (Tensor): shape(batch_size, time_steps_q, time_steps_q)
             drop_n_heads (int, optional): [description]. Defaults to 0.
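The hunks above all serve one contract change: padding masks now arrive at the layers already shaped `(batch_size, 1, time_steps)`, so the per-layer `paddle.unsqueeze(mask, 1)` calls could be deleted. An illustrative sketch of building such a mask; the padding id 0 is an assumption here, and in the model the mask comes from `masking.id_mask` in `encode`, shown later in this diff:

```python
import paddle

text = paddle.to_tensor([[7, 3, 9, 4],
                         [5, 2, 0, 0]])  # second row padded with id 0
mask = (text != 0).astype("float32")     # (batch_size, time_steps)
mask = paddle.unsqueeze(mask, 1)         # (batch_size, 1, time_steps)

# The singleton axis lets one mask broadcast over every query position,
# so encoder self attention and decoder cross attention can share the
# tensor produced once in encode().
```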
@@ -290,21 +268,21 @@
 
 class MLPPreNet(nn.Layer):
-    def __init__(self, d_input, d_hidden, d_output):
+    def __init__(self, d_input, d_hidden, d_output, dropout):
         # (lin + relu + dropout) * n + last projection
         super(MLPPreNet, self).__init__()
         self.lin1 = nn.Linear(d_input, d_hidden)
         self.lin2 = nn.Linear(d_hidden, d_hidden)
-        # self.lin3 = nn.Linear(d_output, d_output)
+        self.lin3 = nn.Linear(d_output, d_output)
+        self.dropout = dropout
 
     def forward(self, x, dropout):
-        # the original code said also use dropout in inference
-        l1 = F.dropout(F.relu(self.lin1(x)), dropout, training=self.training)
-        l2 = F.dropout(F.relu(self.lin2(l1)), dropout, training=self.training)
-        #l3 = self.lin3(l2)
-        return l2
-
+        l1 = F.dropout(F.relu(self.lin1(x)), self.dropout, training=self.training)
+        l2 = F.dropout(F.relu(self.lin2(l1)), self.dropout, training=self.training)
+        l3 = self.lin3(l2)
+        return l3
 
+# NOTE: not used in
 class CNNPreNet(nn.Layer):
     def __init__(self, d_input, d_hidden, d_output, kernel_size, n_layers,
                  dropout=0.):
@@ -347,7 +325,6 @@ class CNNPostNet(nn.Layer):
         # NOTE: it can also be a non-causal conv
 
     def forward(self, x):
-        # why not use pre norms
         x_in = x
         for i, layer in enumerate(self.convs):
             x = layer(x)
@@ -358,33 +335,60 @@
 
 class TransformerTTS(nn.Layer):
-    def __init__(self, vocab_size, padding_idx, d_encoder, d_decoder, d_mel, n_heads, d_ffn,
-                 encoder_layers, decoder_layers, d_prenet, d_postnet, postnet_layers,
-                 postnet_kernel_size, max_reduction_factor, dropout):
+    def __init__(self,
+                 frontend: parakeet.frontend.Phonetics,
+                 d_encoder: int,
+                 d_decoder: int,
+                 d_mel: int,
+                 n_heads: int,
+                 d_ffn: int,
+                 encoder_layers: int,
+                 decoder_layers: int,
+                 d_prenet: int,
+                 d_postnet: int,
+                 postnet_layers: int,
+                 postnet_kernel_size: int,
+                 max_reduction_factor: int,
+                 decoder_prenet_dropout: float,
+                 dropout: float):
         super(TransformerTTS, self).__init__()
+
+        # text frontend (text normalization and g2p)
+        self.frontend = frontend
+
         # encoder
-        self.encoder_prenet = nn.Embedding(vocab_size, d_encoder, padding_idx, weight_attr=I.Uniform(-0.05, 0.05))
-        # self.encoder_prenet = CNNPreNet(d_encoder, d_encoder, d_encoder, 5, 3, dropout)
-        self.encoder_pe = pe.positional_encoding(0, 1000, d_encoder) # it may be extended later
-        self.encoder_pe_scalar = self.create_parameter([1], attr=I.Constant(1.))
-        self.encoder = TransformerEncoder(d_encoder, n_heads, d_ffn, encoder_layers, dropout)
+        self.encoder_prenet = nn.Embedding(
+            frontend.vocab_size, d_encoder,
+            padding_idx=frontend.vocab.padding_index,
+            weight_attr=I.Uniform(-0.05, 0.05))
+        # position encoding matrix may be extended later
+        self.encoder_pe = pe.positional_encoding(0, 1000, d_encoder)
+        self.encoder_pe_scalar = self.create_parameter(
+            [1], attr=I.Constant(1.))
+        self.encoder = TransformerEncoder(
+            d_encoder, n_heads, d_ffn, encoder_layers, dropout)
 
         # decoder
-        self.decoder_prenet = MLPPreNet(d_mel, d_prenet, d_decoder)
-        self.decoder_pe = pe.positional_encoding(0, 1000, d_decoder) # it may be extended later
-        self.decoder_pe_scalar = self.create_parameter([1], attr=I.Constant(1.))
-        self.decoder = TransformerDecoder(d_decoder, n_heads, d_ffn, decoder_layers, dropout, d_encoder=d_encoder)
+        self.decoder_prenet = MLPPreNet(d_mel, d_prenet, d_decoder, dropout)
+        self.decoder_pe = pe.positional_encoding(0, 1000, d_decoder)
+        self.decoder_pe_scalar = self.create_parameter(
+            [1], attr=I.Constant(1.))
+        self.decoder = TransformerDecoder(
+            d_decoder, n_heads, d_ffn, decoder_layers, dropout,
+            d_encoder=d_encoder)
         self.final_proj = nn.Linear(d_decoder, max_reduction_factor * d_mel)
-        self.decoder_postnet = CNNPostNet(d_mel, d_postnet, d_mel, postnet_kernel_size, postnet_layers)
+        self.decoder_postnet = CNNPostNet(
+            d_mel, d_postnet, d_mel, postnet_kernel_size, postnet_layers)
         self.stop_conditioner = nn.Linear(d_mel, 3)
 
         # specs
-        self.padding_idx = padding_idx
+        self.padding_idx = frontend.vocab.padding_index
         self.d_encoder = d_encoder
         self.d_decoder = d_decoder
         self.d_mel = d_mel
         self.max_r = max_reduction_factor
         self.dropout = dropout
+        self.decoder_prenet_dropout = decoder_prenet_dropout
 
         # start and end: though it is only used in predict
         # it can also be used in training
@@ -395,13 +399,19 @@
         # mutables
         self.r = max_reduction_factor # set it every call
-        self.decoder_prenet_dropout = 0.0
         self.drop_n_heads = 0
 
-    def forward(self, text, mel, stop):
+    def forward(self, text, mel):
         encoded, encoder_attention_weights, encoder_mask = self.encode(text)
         mel_output, mel_intermediate, cross_attention_weights, stop_logits = self.decode(encoded, mel, encoder_mask)
-        return mel_output, mel_intermediate, encoder_attention_weights, cross_attention_weights, stop_logits
+        outputs = {
+            "mel_output": mel_output,
+            "mel_intermediate": mel_intermediate,
+            "encoder_attention_weights": encoder_attention_weights,
+            "cross_attention_weights": cross_attention_weights,
+            "stop_logits": stop_logits,
+        }
+        return outputs
 
     def encode(self, text):
         T_enc = text.shape[-1]
@@ -414,7 +424,8 @@
         x = F.dropout(x, self.dropout, training=self.training)
 
         # TODO(chenfeiyu): unsqueeze a decoder_time_steps=1 for the mask
-        encoder_padding_mask = masking.id_mask(text, self.padding_idx, dtype=x.dtype)
+        encoder_padding_mask = paddle.unsqueeze(
+            masking.id_mask(text, self.padding_idx, dtype=x.dtype), 1)
         x, attention_weights = self.encoder(x, encoder_padding_mask, self.drop_n_heads)
         return x, attention_weights, encoder_padding_mask
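Since `forward` now returns a dict (and no longer takes the `stop` targets), call sites unpack results by key. A short consuming-side sketch; `model`, `text`, and `mel` are placeholders for a constructed `TransformerTTS` and a data batch:

```python
outputs = model(text, mel)   # previously a five-element tuple

mel_output = outputs["mel_output"]               # final mel prediction
mel_intermediate = outputs["mel_intermediate"]   # pre-postnet prediction
stop_logits = outputs["stop_logits"]             # per-frame stop logits
enc_attn = outputs["encoder_attention_weights"]  # e.g. for attention plots
cross_attn = outputs["cross_attention_weights"]
```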
@@ -453,21 +464,26 @@
         return mel_output, mel_intermediate, cross_attention_weights, stop_logits
 
-    def predict(self, input, max_length=1000, verbose=True):
-        """[summary]
+    def predict(self, input, raw_input=True, max_length=1000, verbose=True):
+        """Predict log scale magnitude mel spectrogram from text input.
 
         Args:
             input (Tensor): shape (T), dtype int, input text sequence.
             max_length (int, optional): max decoder steps. Defaults to 1000.
             verbose (bool, optional): display progress bar. Defaults to True.
         """
-        text_input = paddle.unsqueeze(input, 0) # (1, T)
+        if raw_input:
+            text_ids = paddle.to_tensor(self.frontend(input))
+            text_input = paddle.unsqueeze(text_ids, 0) # (1, T)
+        else:
+            text_input = input
+
         decoder_input = paddle.unsqueeze(self.start_vec, 0) # (B=1, T, C)
         decoder_output = paddle.unsqueeze(self.start_vec, 0) # (B=1, T, C)
 
         # encode the text sequence
         encoder_output, encoder_attentions, encoder_padding_mask = self.encode(text_input)
-        for _ in range(int(max_length // self.r) + 1):
+        for _ in trange(int(max_length // self.r) + 1):
             mel_output, _, cross_attention_weights, stop_logits = self.decode(
                 encoder_output, decoder_input, encoder_padding_mask)
@@ -477,19 +493,23 @@
             decoder_output = paddle.concat([decoder_output, mel_output[:, -self.r:, :]], 1)
 
             # stop condition?
-            if paddle.argmax(stop_logits[:, -1, :]) == self.stop_prob_index:
+            if paddle.any(paddle.argmax(stop_logits[0, :, :], axis=-1) == self.stop_prob_index):
                 if verbose:
                     print("Hits stop condition.")
                 break
 
-        return decoder_output[:, 1:, :], encoder_attentions, cross_attention_weights
+        mel_output = decoder_output[:, 1:, :]
+        outputs = {
+            "mel_output": mel_output,
+            "encoder_attention_weights": encoder_attentions,
+            "cross_attention_weights": cross_attention_weights,
+        }
+        return outputs
 
-    def set_constants(self, reduction_factor, drop_n_heads, decoder_prenet_dropout):
-        # TODO(chenfeiyu): make a good design for these hyperparameter settings
+    def set_constants(self, reduction_factor, drop_n_heads):
         self.r = reduction_factor
         self.drop_n_heads = drop_n_heads
-        self.decoder_prenet_dropout = decoder_prenet_dropout
-
+
 
 class TransformerTTSLoss(nn.Layer):
     def __init__(self, stop_loss_scale):
@@ -505,12 +525,14 @@
         mel_len = mask.shape[-1]
         last_position = F.one_hot(mask.sum(-1).astype("int64") - 1, num_classes=mel_len)
         mask2 = mask + last_position.scale(self.stop_loss_scale - 1).astype(mask.dtype)
-        stop_loss = L.masked_softmax_with_cross_entropy(stop_logits, stop_probs.unsqueeze(-1), mask2.unsqueeze(-1))
+        stop_loss = L.masked_softmax_with_cross_entropy(
+            stop_logits, stop_probs.unsqueeze(-1), mask2.unsqueeze(-1))
 
         loss = mel_loss1 + mel_loss2 + stop_loss
-        details = dict(
+        losses = dict(
+            loss=loss, # total loss
             mel_loss1=mel_loss1, # output mel loss
             mel_loss2=mel_loss2, # intermediate mel loss
             stop_loss=stop_loss # stop prob loss
         )
-        return loss, details
\ No newline at end of file
+        return losses
\ No newline at end of file
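With `predict` accepting raw text (run through `self.frontend`) and returning a dict, and `decoder_prenet_dropout` moved from `set_constants` into the constructor, inference looks roughly like this; the sentence and constant values are placeholders:

```python
model.set_constants(reduction_factor=1, drop_n_heads=0)
outputs = model.predict("Hello, world.", raw_input=True, max_length=1000)

mel = outputs["mel_output"]                      # (1, T_mel, d_mel) log magnitude mel
cross_attn = outputs["cross_attention_weights"]  # e.g. for alignment plots

# Pass raw_input=False to supply an already numericalized (1, T) id tensor
# instead of raw text.
```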
diff --git a/parakeet/utils/__init__.py b/parakeet/utils/__init__.py
index abf198b..9ef6d7a 100644
--- a/parakeet/utils/__init__.py
+++ b/parakeet/utils/__init__.py
@@ -11,3 +11,5 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+
+from . import io, layer_tools, scheduler, display
diff --git a/parakeet/utils/display.py b/parakeet/utils/display.py
index 314578b..2e25997 100644
--- a/parakeet/utils/display.py
+++ b/parakeet/utils/display.py
@@ -26,6 +26,14 @@ def pack_attention_images(attention_weights, rotate=False):
     img = np.block([[total[i, j] for j in range(cols)] for i in range(rows)])
     return img
 
+def add_attention_plots(writer, tag, attention_weights, global_step):
+    attns = [attn[0].numpy() for attn in attention_weights]
+    for i, attn in enumerate(attns):
+        img = pack_attention_images(attn)
+        writer.add_image(f"{tag}/{i}",
+                         cm.plasma(img),
+                         global_step=global_step,
+                         dataformats="HWC")
+
 def min_max_normalize(v):
     return (v - v.min()) / (v.max() - v.min())
diff --git a/parakeet/utils/io.py b/parakeet/utils/io.py
index ed78bcc..7f89593 100644
--- a/parakeet/utils/io.py
+++ b/parakeet/utils/io.py
@@ -132,13 +132,13 @@ def load_parameters(model,
                     k].dtype:
                 model_dict[k] = v.astype(state_dict[k].numpy().dtype)
 
-    model.set_dict(model_dict)
+    model.set_state_dict(model_dict)
     print("[checkpoint] Rank {}: loaded model from {}.pdparams".format(
         local_rank, checkpoint_path))
 
     if optimizer and optimizer_dict:
-        optimizer.set_dict(optimizer_dict)
+        optimizer.set_state_dict(optimizer_dict)
         print("[checkpoint] Rank {}: loaded optimizer state from {}.pdopt".
               format(local_rank, checkpoint_path))
diff --git a/parakeet/utils/mp_tools.py b/parakeet/utils/mp_tools.py
new file mode 100644
index 0000000..bc24726
--- /dev/null
+++ b/parakeet/utils/mp_tools.py
@@ -0,0 +1,18 @@
+import paddle
+from paddle import distributed as dist
+from functools import wraps
+
+def rank_zero_only(func):
+    local_rank = dist.get_rank()
+
+    @wraps(func)
+    def wrapper(*args, **kwargs):
+        if local_rank != 0:
+            return
+        result = func(*args, **kwargs)
+        return result
+
+    return wrapper
+
+
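A usage sketch for the new `rank_zero_only` helper: in data-parallel training it turns side-effecting calls (logging, checkpointing, visualization) into no-ops on every rank except 0. `log_progress` is a made-up example function, not parakeet code:

```python
from parakeet.utils.mp_tools import rank_zero_only

@rank_zero_only
def log_progress(step, loss):
    print(f"step {step}: loss = {loss:.4f}")

log_progress(100, 2.345)   # prints on rank 0; returns None on other ranks
```

Note that `dist.get_rank()` is evaluated once, at decoration time, so decorated functions should be defined after the distributed environment is initialized (in single-process runs the rank is simply 0).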