diff --git a/parakeet/models/__init__.py b/parakeet/models/__init__.py index d8521da..0a32a9d 100644 --- a/parakeet/models/__init__.py +++ b/parakeet/models/__init__.py @@ -12,10 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -from parakeet.models.clarinet import * +#from parakeet.models.clarinet import * from parakeet.models.waveflow import * -from parakeet.models.wavenet import * +#from parakeet.models.wavenet import * from parakeet.models.transformer_tts import * -from parakeet.models.deepvoice3 import * +#from parakeet.models.deepvoice3 import * # from parakeet.models.fastspeech import * diff --git a/parakeet/models/transformer_tts.py b/parakeet/models/transformer_tts.py index e39404c..4cc3df3 100644 --- a/parakeet/models/transformer_tts.py +++ b/parakeet/models/transformer_tts.py @@ -273,12 +273,14 @@ class MLPPreNet(nn.Layer): super(MLPPreNet, self).__init__() self.lin1 = nn.Linear(d_input, d_hidden) self.lin2 = nn.Linear(d_hidden, d_hidden) + self.lin3 = nn.Linear(d_hidden, d_hidden) self.dropout = dropout def forward(self, x, dropout): l1 = F.dropout(F.relu(self.lin1(x)), self.dropout, training=self.training) l2 = F.dropout(F.relu(self.lin2(l1)), self.dropout, training=self.training) - return l2 + l3 = self.lin3(l2) + return l3 # NOTE: not used in class CNNPreNet(nn.Layer): @@ -317,6 +319,7 @@ class CNNPostNet(nn.Layer): Conv1dBatchNorm(c_in, c_out, kernel_size, weight_attr=I.XavierUniform(), padding=padding)) + self.last_bn = nn.BatchNorm1D(d_output) # for a layer that ends with a normalization layer that is targeted to # output a non zero-central output, it may take a long time to # train the scale and bias @@ -328,7 +331,7 @@ class CNNPostNet(nn.Layer): x = layer(x) if i != (len(self.convs) - 1): x = F.tanh(x) - x = x_in + x + x = self.last_bn(x_in + x) return x @@ -491,7 +494,7 @@ class TransformerTTS(nn.Layer): decoder_output = paddle.concat([decoder_output, mel_output[:, -self.r:, :]], 1) # stop condition: (if any ouput frame of the output multiframes hits the stop condition) - if paddle.any(paddle.argmax(stop_logits[0, :, :], axis=-1) == self.stop_prob_index): + if paddle.any(paddle.argmax(stop_logits[0, -self.r:, :], axis=-1) == self.stop_prob_index): if verbose: print("Hits stop condition.") break @@ -526,6 +529,34 @@ class TransformerTTSLoss(nn.Layer): stop_loss = L.masked_softmax_with_cross_entropy( stop_logits, stop_probs.unsqueeze(-1), mask2.unsqueeze(-1)) + loss = mel_loss1 + mel_loss2 + stop_loss + losses = dict( + loss=loss, # total loss + mel_loss1=mel_loss1, # ouput mel loss + mel_loss2=mel_loss2, # intermediate mel loss + stop_loss=stop_loss # stop prob loss + ) + return losses + + +class AdaptiveTransformerTTSLoss(nn.Layer): + def __init__(self): + super(AdaptiveTransformerTTSLoss, self).__init__() + + def forward(self, mel_output, mel_intermediate, mel_target, stop_logits, stop_probs): + mask = masking.feature_mask(mel_target, axis=-1, dtype=mel_target.dtype) + mask1 = paddle.unsqueeze(mask, -1) + mel_loss1 = L.masked_l1_loss(mel_output, mel_target, mask1) + mel_loss2 = L.masked_l1_loss(mel_intermediate, mel_target, mask1) + + batch_size, mel_len = mask.shape + valid_lengths = mask.sum(-1).astype("int64") + last_position = F.one_hot(valid_lengths - 1, num_classes=mel_len) + stop_loss_scale = valid_lengths.sum() / batch_size - 1 + mask2 = mask + last_position.scale(stop_loss_scale - 1).astype(mask.dtype) + stop_loss = L.masked_softmax_with_cross_entropy( + stop_logits, stop_probs.unsqueeze(-1), mask2.unsqueeze(-1)) + loss = mel_loss1 + mel_loss2 + stop_loss losses = dict( loss=loss, # total loss diff --git a/parakeet/modules/losses.py b/parakeet/modules/losses.py index e7187a8..b8bc945 100644 --- a/parakeet/modules/losses.py +++ b/parakeet/modules/losses.py @@ -1,3 +1,5 @@ +import numba +import numpy as np import paddle from paddle import nn from paddle.nn import functional as F @@ -12,7 +14,7 @@ def weighted_mean(input, weight): Returns: Tensor: shape(1,), weighted mean tensor with the same dtype as input. """ - weight = paddle.cast(weight, input.dtype) + weight = paddle.cast(weight, input.dtype) return paddle.mean(input * weight) def masked_l1_loss(prediction, target, mask): @@ -22,3 +24,32 @@ def masked_l1_loss(prediction, target, mask): def masked_softmax_with_cross_entropy(logits, label, mask, axis=-1): ce = F.softmax_with_cross_entropy(logits, label, axis=axis) return weighted_mean(ce, mask) + +def diagonal_loss(attentions, input_lengths, target_lengths, g=0.2, multihead=False): + """A metric to evaluate how diagonal a attention distribution is.""" + W = guided_attentions(input_lengths, target_lengths, g) + W_tensor = paddle.to_tensor(W) + if not multihead: + return paddle.mean(attentions * W_tensor) + else: + return paddle.mean(attentions * paddle.unsqueeze(W_tensor, 1)) + +@numba.jit(nopython=True) +def guided_attention(N, max_N, T, max_T, g): + W = np.zeros((max_T, max_N), dtype=np.float32) + for t in range(T): + for n in range(N): + W[t, n] = 1 - np.exp(-(n / N - t / T)**2 / (2 * g * g)) + # (T_dec, T_enc) + return W + +def guided_attentions(input_lengths, target_lengths, g=0.2): + B = len(input_lengths) + max_input_len = input_lengths.max() + max_target_len = target_lengths.max() + W = np.zeros((B, max_target_len, max_input_len), dtype=np.float32) + for b in range(B): + W[b] = guided_attention(input_lengths[b], max_input_len, + target_lengths[b], max_target_len, g) + # (B, T_dec, T_enc) + return W \ No newline at end of file diff --git a/parakeet/modules/multihead_attention.py b/parakeet/modules/multihead_attention.py deleted file mode 100644 index 2d4792e..0000000 --- a/parakeet/modules/multihead_attention.py +++ /dev/null @@ -1,202 +0,0 @@ -# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import math -import numpy as np -import paddle.fluid as fluid -import paddle.fluid.dygraph as dg -import paddle.fluid.layers as layers - - -class Linear(dg.Layer): - def __init__(self, - in_features, - out_features, - is_bias=True, - dtype="float32"): - super(Linear, self).__init__() - self.in_features = in_features - self.out_features = out_features - self.dtype = dtype - self.weight = fluid.ParamAttr( - initializer=fluid.initializer.XavierInitializer()) - self.bias = is_bias - - if is_bias is not False: - k = math.sqrt(1.0 / in_features) - self.bias = fluid.ParamAttr(initializer=fluid.initializer.Uniform( - low=-k, high=k)) - - self.linear = dg.Linear( - in_features, - out_features, - param_attr=self.weight, - bias_attr=self.bias, ) - - def forward(self, x): - x = self.linear(x) - return x - - -class ScaledDotProductAttention(dg.Layer): - def __init__(self, d_key): - """Scaled dot product attention module. - - Args: - d_key (int): the dim of key in multihead attention. - """ - super(ScaledDotProductAttention, self).__init__() - - self.d_key = d_key - - # please attention this mask is diff from pytorch - def forward(self, - key, - value, - query, - mask=None, - query_mask=None, - dropout=0.1): - """ - Compute scaled dot product attention. - - Args: - key (Variable): shape(B, T, C), dtype float32, the input key of scaled dot product attention. - value (Variable): shape(B, T, C), dtype float32, the input value of scaled dot product attention. - query (Variable): shape(B, T, C), dtype float32, the input query of scaled dot product attention. - mask (Variable, optional): shape(B, T_q, T_k), dtype float32, the mask of key. Defaults to None. - query_mask (Variable, optional): shape(B, T_q, T_q), dtype float32, the mask of query. Defaults to None. - dropout (float32, optional): the probability of dropout. Defaults to 0.1. - Returns: - result (Variable): shape(B, T, C), the result of mutihead attention. - attention (Variable): shape(n_head * B, T, C), the attention of key. - """ - # Compute attention score - attention = layers.matmul( - query, key, transpose_y=True, alpha=self.d_key - **-0.5) #transpose the last dim in y - - # Mask key to ignore padding - if mask is not None: - attention = attention + mask - attention = layers.softmax(attention, use_cudnn=True) - attention = layers.dropout( - attention, dropout, dropout_implementation='upscale_in_train') - - # Mask query to ignore padding - if query_mask is not None: - attention = attention * query_mask - - result = layers.matmul(attention, value) - return result, attention - - -class MultiheadAttention(dg.Layer): - def __init__(self, - num_hidden, - d_k, - d_q, - num_head=4, - is_bias=False, - dropout=0.1, - is_concat=True): - """Multihead Attention. - - Args: - num_hidden (int): the number of hidden layer in network. - d_k (int): the dim of key in multihead attention. - d_q (int): the dim of query in multihead attention. - num_head (int, optional): the head number of multihead attention. Defaults to 4. - is_bias (bool, optional): whether have bias in linear layers. Default to False. - dropout (float, optional): dropout probability of FFTBlock. Defaults to 0.1. - is_concat (bool, optional): whether concat query and result. Default to True. - """ - super(MultiheadAttention, self).__init__() - self.num_hidden = num_hidden - self.num_head = num_head - self.d_k = d_k - self.d_q = d_q - self.dropout = dropout - self.is_concat = is_concat - - self.key = Linear(num_hidden, num_head * d_k, is_bias=is_bias) - self.value = Linear(num_hidden, num_head * d_k, is_bias=is_bias) - self.query = Linear(num_hidden, num_head * d_q, is_bias=is_bias) - - self.scal_attn = ScaledDotProductAttention(d_k) - - if self.is_concat: - self.fc = Linear(num_head * d_q * 2, num_hidden) - else: - self.fc = Linear(num_head * d_q, num_hidden) - - self.layer_norm = dg.LayerNorm(num_hidden) - - def forward(self, key, value, query_input, mask=None, query_mask=None): - """ - Compute attention. - - Args: - key (Variable): shape(B, T, C), dtype float32, the input key of attention. - value (Variable): shape(B, T, C), dtype float32, the input value of attention. - query_input (Variable): shape(B, T, C), dtype float32, the input query of attention. - mask (Variable, optional): shape(B, T_query, T_key), dtype float32, the mask of key. Defaults to None. - query_mask (Variable, optional): shape(B, T_query, T_key), dtype float32, the mask of query. Defaults to None. - - Returns: - result (Variable): shape(B, T, C), the result of mutihead attention. - attention (Variable): shape(num_head * B, T, C), the attention of key and query. - """ - - batch_size = key.shape[0] - seq_len_key = key.shape[1] - seq_len_query = query_input.shape[1] - - # Make multihead attention - key = layers.reshape( - self.key(key), [batch_size, seq_len_key, self.num_head, self.d_k]) - value = layers.reshape( - self.value(value), - [batch_size, seq_len_key, self.num_head, self.d_k]) - query = layers.reshape( - self.query(query_input), - [batch_size, seq_len_query, self.num_head, self.d_q]) - - key = layers.reshape( - layers.transpose(key, [2, 0, 1, 3]), [-1, seq_len_key, self.d_k]) - value = layers.reshape( - layers.transpose(value, [2, 0, 1, 3]), - [-1, seq_len_key, self.d_k]) - query = layers.reshape( - layers.transpose(query, [2, 0, 1, 3]), - [-1, seq_len_query, self.d_q]) - - result, attention = self.scal_attn( - key, value, query, mask=mask, query_mask=query_mask) - - # concat all multihead result - result = layers.reshape( - result, [self.num_head, batch_size, seq_len_query, self.d_q]) - result = layers.reshape( - layers.transpose(result, [1, 2, 0, 3]), - [batch_size, seq_len_query, -1]) - if self.is_concat: - result = layers.concat([query_input, result], axis=-1) - result = layers.dropout( - self.fc(result), - self.dropout, - dropout_implementation='upscale_in_train') - result = result + query_input - - result = self.layer_norm(result) - return result, attention diff --git a/parakeet/training/cli.py b/parakeet/training/cli.py new file mode 100644 index 0000000..800c3a3 --- /dev/null +++ b/parakeet/training/cli.py @@ -0,0 +1,21 @@ +import argparse + +def default_argument_parser(): + parser = argparse.ArgumentParser() + + # data and outpu + parser.add_argument("--config", metavar="FILE", help="path of the config file to overwrite to default config with.") + parser.add_argument("--data", metavar="DATA_DIR", help="path to the datatset.") + parser.add_argument("--output", metavar="OUTPUT_DIR", help="path to save checkpoint and log. If not provided, a directory is created in runs/ to save outputs.") + + # load from saved checkpoint + parser.add_argument("--checkpoint_path", type=str, help="path of the checkpoint to load") + + # running + parser.add_argument("--device", type=str, choices=["cpu", "gpu"], help="device type to use, cpu and gpu are supported.") + parser.add_argument("--nprocs", type=int, default=1, help="number of parallel processes to use.") + + # overwrite extra config and default config + parser.add_argument("--opts", nargs=argparse.REMAINDER, help="options to overwrite --config file and the default config, passing in KEY VALUE pairs") + + return parser diff --git a/parakeet/utils/checkpoint.py b/parakeet/utils/checkpoint.py new file mode 100644 index 0000000..ce07639 --- /dev/null +++ b/parakeet/utils/checkpoint.py @@ -0,0 +1,137 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import time + +import numpy as np +import paddle +from paddle import distributed as dist +from parakeet.utils import mp_tools + + +def _load_latest_checkpoint(checkpoint_dir): + """Get the iteration number corresponding to the latest saved checkpoint + + Args: + checkpoint_dir (str): the directory where checkpoint is saved. + + Returns: + int: the latest iteration number. + """ + checkpoint_record = os.path.join(checkpoint_dir, "checkpoint") + # Create checkpoint index file if not exist. + if (not os.path.isfile(checkpoint_record)): + return 0 + + # Fetch the latest checkpoint index. + with open(checkpoint_record, "r") as handle: + latest_checkpoint = handle.readline().split()[-1] + iteration = int(latest_checkpoint.split("-")[-1]) + + return iteration + +def _save_checkpoint(checkpoint_dir, iteration): + """Save the iteration number of the latest model to be checkpointed. + + Args: + checkpoint_dir (str): the directory where checkpoint is saved. + iteration (int): the latest iteration number. + + Returns: + None + """ + checkpoint_record = os.path.join(checkpoint_dir, "checkpoint") + # Update the latest checkpoint index. + with open(checkpoint_record, "w") as handle: + handle.write("model_checkpoint_path: step-{}".format(iteration)) + +def load_parameters(model, + optimizer=None, + checkpoint_dir=None, + checkpoint_path=None): + """Load a specific model checkpoint from disk. + + Args: + model (obj): model to load parameters. + optimizer (obj, optional): optimizer to load states if needed. + Defaults to None. + checkpoint_dir (str, optional): the directory where checkpoint is saved. + checkpoint_path (str, optional): if specified, load the checkpoint + stored in the checkpoint_path and the argument 'checkpoint_dir' will + be ignored. Defaults to None. + + Returns: + iteration (int): number of iterations that the loaded checkpoint has + been trained. + """ + if checkpoint_path is not None: + iteration = int(os.path.basename(checkpoint_path).split("-")[-1]) + elif checkpoint_dir is not None: + iteration = _load_latest_checkpoint(checkpoint_dir) + if iteration == 0: + return iteration + checkpoint_path = os.path.join(checkpoint_dir, + "step-{}".format(iteration)) + else: + raise ValueError( + "At least one of 'checkpoint_dir' and 'checkpoint_path' should be specified!" + ) + + local_rank = dist.get_rank() + + params_path = checkpoint_path + ".pdparams" + model_dict = paddle.load(params_path) + model.set_state_dict(model_dict) + print("[checkpoint] Rank {}: loaded model from {}".format( + local_rank, params_path)) + + optimizer_path = checkpoint_path + ".pdopt" + if optimizer and os.path.isfile(optimizer_path): + optimizer_dict = paddle.load(optimizer_path) + optimizer.set_state_dict(optimizer_dict) + print("[checkpoint] Rank {}: loaded optimizer state from {}". + format(local_rank, optimizer_path)) + + return iteration + +@mp_tools.rank_zero_only +def save_parameters(checkpoint_dir, iteration, model, optimizer=None): + """Checkpoint the latest trained model parameters. + + Args: + checkpoint_dir (str): the directory where checkpoint is saved. + iteration (int): the latest iteration number. + model (obj): model to be checkpointed. + optimizer (obj, optional): optimizer to be checkpointed. + Defaults to None. + + Returns: + None + """ + checkpoint_path = os.path.join(checkpoint_dir, "step-{}".format(iteration)) + + model_dict = model.state_dict() + params_path = checkpoint_path + ".pdparams" + paddle.save(model_dict, params_path) + print("[checkpoint] Saved model to {}".format(params_path)) + + if optimizer: + opt_dict = optimizer.state_dict() + optimizer_path = checkpoint_path + ".pdopt" + paddle.save(opt_dict, optimizer_path) + print("[checkpoint] Saved optimzier state to {}".format( + optimizer_path)) + + _save_checkpoint(checkpoint_dir, iteration)