add an adaptive loss to balance stop prediction classes

This commit is contained in:
chenfeiyu 2020-12-05 14:12:10 +08:00
parent a4a0bd8c98
commit d3761683e1
1 changed files with 28 additions and 0 deletions

View File

@ -535,3 +535,31 @@ class TransformerTTSLoss(nn.Layer):
stop_loss=stop_loss # stop prob loss stop_loss=stop_loss # stop prob loss
) )
return losses return losses
class AdaptiveTransformerTTSLoss(nn.Layer):
def __init__(self):
super(AdaptiveTransformerTTSLoss, self).__init__()
def forward(self, mel_output, mel_intermediate, mel_target, stop_logits, stop_probs):
mask = masking.feature_mask(mel_target, axis=-1, dtype=mel_target.dtype)
mask1 = paddle.unsqueeze(mask, -1)
mel_loss1 = L.masked_l1_loss(mel_output, mel_target, mask1)
mel_loss2 = L.masked_l1_loss(mel_intermediate, mel_target, mask1)
batch_size, mel_len = mask.shape
valid_lengths = mask.sum(-1).astype("int64")
last_position = F.one_hot(valid_lengths - 1, num_classes=mel_len)
stop_loss_scale = valid_lengths.sum() / batch_size - 1
mask2 = mask + last_position.scale(stop_loss_scale - 1).astype(mask.dtype)
stop_loss = L.masked_softmax_with_cross_entropy(
stop_logits, stop_probs.unsqueeze(-1), mask2.unsqueeze(-1))
loss = mel_loss1 + mel_loss2 + stop_loss
losses = dict(
loss=loss, # total loss
mel_loss1=mel_loss1, # ouput mel loss
mel_loss2=mel_loss2, # intermediate mel loss
stop_loss=stop_loss # stop prob loss
)
return losses