From d3761683e1ad9cad3915160f10032aae1fb4edd2 Mon Sep 17 00:00:00 2001
From: chenfeiyu
Date: Sat, 5 Dec 2020 14:12:10 +0800
Subject: [PATCH] add an adaptive loss to balance stop prediction classes

---
 parakeet/models/transformer_tts.py | 28 ++++++++++++++++++++++++++++
 1 file changed, 28 insertions(+)

diff --git a/parakeet/models/transformer_tts.py b/parakeet/models/transformer_tts.py
index fb4c72a..d49a199 100644
--- a/parakeet/models/transformer_tts.py
+++ b/parakeet/models/transformer_tts.py
@@ -527,6 +527,34 @@ class TransformerTTSLoss(nn.Layer):
         stop_loss = L.masked_softmax_with_cross_entropy(
             stop_logits, stop_probs.unsqueeze(-1), mask2.unsqueeze(-1))
 
+        loss = mel_loss1 + mel_loss2 + stop_loss
+        losses = dict(
+            loss=loss,  # total loss
+            mel_loss1=mel_loss1,  # output mel loss
+            mel_loss2=mel_loss2,  # intermediate mel loss
+            stop_loss=stop_loss  # stop prob loss
+        )
+        return losses
+
+
+class AdaptiveTransformerTTSLoss(nn.Layer):
+    def __init__(self):
+        super(AdaptiveTransformerTTSLoss, self).__init__()
+
+    def forward(self, mel_output, mel_intermediate, mel_target, stop_logits, stop_probs):
+        mask = masking.feature_mask(mel_target, axis=-1, dtype=mel_target.dtype)
+        mask1 = paddle.unsqueeze(mask, -1)
+        mel_loss1 = L.masked_l1_loss(mel_output, mel_target, mask1)
+        mel_loss2 = L.masked_l1_loss(mel_intermediate, mel_target, mask1)
+
+        batch_size, mel_len = mask.shape
+        valid_lengths = mask.sum(-1).astype("int64")
+        last_position = F.one_hot(valid_lengths - 1, num_classes=mel_len)
+        stop_loss_scale = valid_lengths.sum() / batch_size - 1
+        mask2 = mask + last_position.scale(stop_loss_scale - 1).astype(mask.dtype)
+        stop_loss = L.masked_softmax_with_cross_entropy(
+            stop_logits, stop_probs.unsqueeze(-1), mask2.unsqueeze(-1))
+
         loss = mel_loss1 + mel_loss2 + stop_loss
         losses = dict(
             loss=loss,  # total loss
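
Note (illustration, not part of the patch): the adaptive loss above gives the single "stop" frame of each utterance a cross-entropy weight of roughly the average number of valid frames per utterance, so the one positive stop label is not drowned out by the many negative ones. The NumPy sketch below mirrors the stop_loss_scale / mask2 computation on a toy batch; the array sizes and lengths are made-up values chosen only for the demonstration.

# Toy stand-in for the mask2 computation in AdaptiveTransformerTTSLoss.forward.
import numpy as np

# 2 utterances padded to 6 frames, with valid lengths 4 and 6 (toy values)
mask = np.array([[1, 1, 1, 1, 0, 0],
                 [1, 1, 1, 1, 1, 1]], dtype="float32")
batch_size, mel_len = mask.shape

valid_lengths = mask.sum(-1).astype("int64")            # [4, 6]
last_position = np.eye(mel_len)[valid_lengths - 1]      # one-hot of each last valid frame
stop_loss_scale = valid_lengths.sum() / batch_size - 1  # average count of non-stop frames = 4.0

# mask already puts weight 1 on the last valid frame, so only (scale - 1) is added there:
# stop frames end up with weight stop_loss_scale, other valid frames keep weight 1
mask2 = mask + (stop_loss_scale - 1) * last_position
print(mask2)
# [[1. 1. 1. 4. 0. 0.]
#  [1. 1. 1. 1. 1. 4.]]

Because mask already contributes weight 1 at the last valid frame, the patch adds only stop_loss_scale - 1 via last_position.scale(stop_loss_scale - 1), which is exactly the (stop_loss_scale - 1) * last_position term in the sketch.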