add an adaptive loss to balance stop prediction classes
This commit is contained in:
parent
a4a0bd8c98
commit
d3761683e1
|
@ -535,3 +535,31 @@ class TransformerTTSLoss(nn.Layer):
|
||||||
stop_loss=stop_loss # stop prob loss
|
stop_loss=stop_loss # stop prob loss
|
||||||
)
|
)
|
||||||
return losses
|
return losses
|
||||||
|
|
||||||
|
|
||||||
|
class AdaptiveTransformerTTSLoss(nn.Layer):
|
||||||
|
def __init__(self):
|
||||||
|
super(AdaptiveTransformerTTSLoss, self).__init__()
|
||||||
|
|
||||||
|
def forward(self, mel_output, mel_intermediate, mel_target, stop_logits, stop_probs):
|
||||||
|
mask = masking.feature_mask(mel_target, axis=-1, dtype=mel_target.dtype)
|
||||||
|
mask1 = paddle.unsqueeze(mask, -1)
|
||||||
|
mel_loss1 = L.masked_l1_loss(mel_output, mel_target, mask1)
|
||||||
|
mel_loss2 = L.masked_l1_loss(mel_intermediate, mel_target, mask1)
|
||||||
|
|
||||||
|
batch_size, mel_len = mask.shape
|
||||||
|
valid_lengths = mask.sum(-1).astype("int64")
|
||||||
|
last_position = F.one_hot(valid_lengths - 1, num_classes=mel_len)
|
||||||
|
stop_loss_scale = valid_lengths.sum() / batch_size - 1
|
||||||
|
mask2 = mask + last_position.scale(stop_loss_scale - 1).astype(mask.dtype)
|
||||||
|
stop_loss = L.masked_softmax_with_cross_entropy(
|
||||||
|
stop_logits, stop_probs.unsqueeze(-1), mask2.unsqueeze(-1))
|
||||||
|
|
||||||
|
loss = mel_loss1 + mel_loss2 + stop_loss
|
||||||
|
losses = dict(
|
||||||
|
loss=loss, # total loss
|
||||||
|
mel_loss1=mel_loss1, # ouput mel loss
|
||||||
|
mel_loss2=mel_loss2, # intermediate mel loss
|
||||||
|
stop_loss=stop_loss # stop prob loss
|
||||||
|
)
|
||||||
|
return losses
|
Loading…
Reference in New Issue