292 lines
12 KiB
Python
292 lines
12 KiB
Python
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
from __future__ import division
|
|
import numpy as np
|
|
from numba import jit
|
|
|
|
from paddle import fluid
|
|
import paddle.fluid.layers as F
|
|
import paddle.fluid.dygraph as dg
|
|
|
|
|
|
def masked_mean(inputs, mask):
|
|
"""
|
|
Args:
|
|
inputs (Variable): shape(B, T, C), dtype float32, the input.
|
|
mask (Variable): shape(B, T), dtype float32, a mask.
|
|
Returns:
|
|
loss (Variable): shape(1, ), dtype float32, masked mean.
|
|
"""
|
|
channels = inputs.shape[-1]
|
|
masked_inputs = F.elementwise_mul(inputs, mask, axis=0)
|
|
loss = F.reduce_sum(masked_inputs) / (channels * F.reduce_sum(mask))
|
|
return loss
|
|
|
|
|
|
@jit(nopython=True)
|
|
def guided_attention(N, max_N, T, max_T, g):
|
|
"""Generate an diagonal attention guide.
|
|
|
|
Args:
|
|
N (int): valid length of encoder.
|
|
max_N (int): max length of encoder.
|
|
T (int): valid length of decoder.
|
|
max_T (int): max length of decoder.
|
|
g (float): sigma to adjust the degree of diagonal guide.
|
|
|
|
Returns:
|
|
np.ndarray: shape(max_N, max_T), dtype float32, the diagonal guide.
|
|
"""
|
|
W = np.zeros((max_N, max_T), dtype=np.float32)
|
|
for n in range(N):
|
|
for t in range(T):
|
|
W[n, t] = 1 - np.exp(-(n / N - t / T)**2 / (2 * g * g))
|
|
return W
|
|
|
|
|
|
def guided_attentions(encoder_lengths, decoder_lengths, max_decoder_len,
|
|
g=0.2):
|
|
"""Generate a diagonal attention guide for a batch.
|
|
|
|
Args:
|
|
encoder_lengths (np.ndarray): shape(B, ), dtype: int64, encoder valid lengths.
|
|
decoder_lengths (np.ndarray): shape(B, ), dtype: int64, decoder valid lengths.
|
|
max_decoder_len (int): max length of decoder.
|
|
g (float, optional): sigma to adjust the degree of diagonal guide.. Defaults to 0.2.
|
|
|
|
Returns:
|
|
np.ndarray: shape(B, max_T, max_N), dtype float32, the diagonal guide. (max_N: max encoder length, max_T: max decoder length.)
|
|
"""
|
|
B = len(encoder_lengths)
|
|
max_input_len = encoder_lengths.max()
|
|
W = np.zeros((B, max_decoder_len, max_input_len), dtype=np.float32)
|
|
for b in range(B):
|
|
W[b] = guided_attention(encoder_lengths[b], max_input_len,
|
|
decoder_lengths[b], max_decoder_len, g).T
|
|
return W
|
|
|
|
|
|
class TTSLoss(object):
|
|
def __init__(self,
|
|
masked_weight=0.0,
|
|
priority_bin=None,
|
|
priority_weight=0.0,
|
|
binary_divergence_weight=0.0,
|
|
guided_attention_sigma=0.2,
|
|
downsample_factor=4,
|
|
r=1):
|
|
"""Compute loss for Deep Voice 3 model.
|
|
|
|
Args:
|
|
masked_weight (float, optional): the weight of masked loss. Defaults to 0.0.
|
|
priority_bin ([type], optional): frequency bands for linear spectrogram loss to be prioritized. Defaults to None.
|
|
priority_weight (float, optional): weight for the prioritized frequency bands. Defaults to 0.0.
|
|
binary_divergence_weight (float, optional): weight for binary cross entropy (used for spectrogram loss). Defaults to 0.0.
|
|
guided_attention_sigma (float, optional): `sigma` for attention guide. Defaults to 0.2.
|
|
downsample_factor (int, optional): the downsample factor for mel spectrogram. Defaults to 4.
|
|
r (int, optional): frames per decoder step. Defaults to 1.
|
|
"""
|
|
self.masked_weight = masked_weight
|
|
self.priority_bin = priority_bin # only used for lin-spec loss
|
|
self.priority_weight = priority_weight # only used for lin-spec loss
|
|
self.binary_divergence_weight = binary_divergence_weight
|
|
self.guided_attention_sigma = guided_attention_sigma
|
|
|
|
self.time_shift = r
|
|
self.r = r
|
|
self.downsample_factor = downsample_factor
|
|
|
|
def l1_loss(self, prediction, target, mask, priority_bin=None):
|
|
"""L1 loss for spectrogram.
|
|
|
|
Args:
|
|
prediction (Variable): shape(B, T, C), dtype float32, predicted spectrogram.
|
|
target (Variable): shape(B, T, C), dtype float32, target spectrogram.
|
|
mask (Variable): shape(B, T), mask.
|
|
priority_bin (int, optional): frequency bands for linear spectrogram loss to be prioritized. Defaults to None.
|
|
|
|
Returns:
|
|
Variable: shape(1,), dtype float32, l1 loss(with mask and possibly priority bin applied.)
|
|
"""
|
|
abs_diff = F.abs(prediction - target)
|
|
|
|
# basic mask-weighted l1 loss
|
|
w = self.masked_weight
|
|
if w > 0 and mask is not None:
|
|
base_l1_loss = w * masked_mean(abs_diff, mask) \
|
|
+ (1 - w) * F.reduce_mean(abs_diff)
|
|
else:
|
|
base_l1_loss = F.reduce_mean(abs_diff)
|
|
|
|
if self.priority_weight > 0 and priority_bin is not None:
|
|
# mask-weighted priority channels' l1-loss
|
|
priority_abs_diff = abs_diff[:, :, :priority_bin]
|
|
if w > 0 and mask is not None:
|
|
priority_loss = w * masked_mean(priority_abs_diff, mask) \
|
|
+ (1 - w) * F.reduce_mean(priority_abs_diff)
|
|
else:
|
|
priority_loss = F.reduce_mean(priority_abs_diff)
|
|
|
|
# priority weighted sum
|
|
p = self.priority_weight
|
|
loss = p * priority_loss + (1 - p) * base_l1_loss
|
|
else:
|
|
loss = base_l1_loss
|
|
return loss
|
|
|
|
def binary_divergence(self, prediction, target, mask):
|
|
"""Binary cross entropy loss for spectrogram. All the values in the spectrogram are treated as logits in a logistic regression.
|
|
|
|
Args:
|
|
prediction (Variable): shape(B, T, C), dtype float32, predicted spectrogram.
|
|
target (Variable): shape(B, T, C), dtype float32, target spectrogram.
|
|
mask (Variable): shape(B, T), mask.
|
|
|
|
Returns:
|
|
Variable: shape(1,), dtype float32, binary cross entropy loss.
|
|
"""
|
|
flattened_prediction = F.reshape(prediction, [-1, 1])
|
|
flattened_target = F.reshape(target, [-1, 1])
|
|
flattened_loss = F.log_loss(
|
|
flattened_prediction, flattened_target, epsilon=1e-8)
|
|
bin_div = fluid.layers.reshape(flattened_loss, prediction.shape)
|
|
|
|
w = self.masked_weight
|
|
if w > 0 and mask is not None:
|
|
loss = w * masked_mean(bin_div, mask) \
|
|
+ (1 - w) * F.reduce_mean(bin_div)
|
|
else:
|
|
loss = F.reduce_mean(bin_div)
|
|
return loss
|
|
|
|
@staticmethod
|
|
def done_loss(done_hat, done):
|
|
"""Compute done loss
|
|
|
|
Args:
|
|
done_hat (Variable): shape(B, T), dtype float32, predicted done probability(the probability that the final frame has been generated.)
|
|
done (Variable): shape(B, T), dtype float32, ground truth done probability(the probability that the final frame has been generated.)
|
|
|
|
Returns:
|
|
Variable: shape(1, ), dtype float32, done loss.
|
|
"""
|
|
flat_done_hat = F.reshape(done_hat, [-1, 1])
|
|
flat_done = F.reshape(done, [-1, 1])
|
|
loss = F.log_loss(flat_done_hat, flat_done, epsilon=1e-8)
|
|
loss = F.reduce_mean(loss)
|
|
return loss
|
|
|
|
def attention_loss(self, predicted_attention, input_lengths,
|
|
target_lengths):
|
|
"""
|
|
Given valid encoder_lengths and decoder_lengths, compute a diagonal guide, and compute loss from the predicted attention and the guide.
|
|
|
|
Args:
|
|
predicted_attention (Variable): shape(*, B, T_dec, T_enc), dtype float32, the alignment tensor, where B means batch size, T_dec means number of time steps of the decoder, T_enc means the number of time steps of the encoder, * means other possible dimensions.
|
|
input_lengths (numpy.ndarray): shape(B,), dtype:int64, valid lengths (time steps) of encoder outputs.
|
|
target_lengths (numpy.ndarray): shape(batch_size,), dtype:int64, valid lengths (time steps) of decoder outputs.
|
|
|
|
Returns:
|
|
loss (Variable): shape(1, ), dtype float32, attention loss.
|
|
"""
|
|
n_attention, batch_size, max_target_len, max_input_len = (
|
|
predicted_attention.shape)
|
|
soft_mask = guided_attentions(input_lengths, target_lengths,
|
|
max_target_len,
|
|
self.guided_attention_sigma)
|
|
soft_mask_ = dg.to_variable(soft_mask)
|
|
loss = fluid.layers.reduce_mean(predicted_attention * soft_mask_)
|
|
return loss
|
|
|
|
def __call__(self, outputs, inputs):
|
|
"""Total loss
|
|
|
|
Args:
|
|
outpus is a tuple of (mel_hyp, lin_hyp, attn_hyp, done_hyp).
|
|
mel_hyp (Variable): shape(B, T, C_mel), dtype float32, predicted mel spectrogram.
|
|
lin_hyp (Variable): shape(B, T, C_lin), dtype float32, predicted linear spectrogram.
|
|
done_hyp (Variable): shape(B, T), dtype float32, predicted done probability.
|
|
attn_hyp (Variable): shape(N, B, T_dec, T_enc), dtype float32, predicted attention.
|
|
|
|
inputs is a tuple of (mel_ref, lin_ref, done_ref, input_lengths, n_frames)
|
|
mel_ref (Variable): shape(B, T, C_mel), dtype float32, ground truth mel spectrogram.
|
|
lin_ref (Variable): shape(B, T, C_lin), dtype float32, ground truth linear spectrogram.
|
|
done_ref (Variable): shape(B, T), dtype float32, ground truth done flag.
|
|
input_lengths (Variable): shape(B, ), dtype: int, encoder valid lengths.
|
|
n_frames (Variable): shape(B, ), dtype: int, decoder valid lengths.
|
|
|
|
Returns:
|
|
Dict(str, Variable): details of loss.
|
|
"""
|
|
total_loss = 0.
|
|
|
|
mel_hyp, lin_hyp, attn_hyp, done_hyp = outputs
|
|
mel_ref, lin_ref, done_ref, input_lengths, n_frames = inputs
|
|
|
|
# n_frames # mel_lengths # decoder_lengths
|
|
max_frames = lin_hyp.shape[1]
|
|
max_mel_steps = max_frames // self.downsample_factor
|
|
# max_decoder_steps = max_mel_steps // self.r
|
|
# decoder_mask = F.sequence_mask(n_frames // self.downsample_factor //
|
|
# self.r,
|
|
# max_decoder_steps,
|
|
# dtype="float32")
|
|
mel_mask = F.sequence_mask(
|
|
n_frames // self.downsample_factor, max_mel_steps, dtype="float32")
|
|
lin_mask = F.sequence_mask(n_frames, max_frames, dtype="float32")
|
|
|
|
lin_hyp = lin_hyp[:, :-self.time_shift, :]
|
|
lin_ref = lin_ref[:, self.time_shift:, :]
|
|
lin_mask = lin_mask[:, self.time_shift:]
|
|
lin_l1_loss = self.l1_loss(
|
|
lin_hyp, lin_ref, lin_mask, priority_bin=self.priority_bin)
|
|
lin_bce_loss = self.binary_divergence(lin_hyp, lin_ref, lin_mask)
|
|
lin_loss = self.binary_divergence_weight * lin_bce_loss \
|
|
+ (1 - self.binary_divergence_weight) * lin_l1_loss
|
|
total_loss += lin_loss
|
|
|
|
mel_hyp = mel_hyp[:, :-self.time_shift, :]
|
|
mel_ref = mel_ref[:, self.time_shift:, :]
|
|
mel_mask = mel_mask[:, self.time_shift:]
|
|
mel_l1_loss = self.l1_loss(mel_hyp, mel_ref, mel_mask)
|
|
mel_bce_loss = self.binary_divergence(mel_hyp, mel_ref, mel_mask)
|
|
# print("=====>", mel_l1_loss.numpy()[0], mel_bce_loss.numpy()[0])
|
|
mel_loss = self.binary_divergence_weight * mel_bce_loss \
|
|
+ (1 - self.binary_divergence_weight) * mel_l1_loss
|
|
total_loss += mel_loss
|
|
|
|
attn_loss = self.attention_loss(attn_hyp,
|
|
input_lengths.numpy(),
|
|
n_frames.numpy() //
|
|
(self.downsample_factor * self.r))
|
|
total_loss += attn_loss
|
|
|
|
done_loss = self.done_loss(done_hyp, done_ref)
|
|
total_loss += done_loss
|
|
|
|
losses = {
|
|
"loss": total_loss,
|
|
"mel/mel_loss": mel_loss,
|
|
"mel/l1_loss": mel_l1_loss,
|
|
"mel/bce_loss": mel_bce_loss,
|
|
"lin/lin_loss": lin_loss,
|
|
"lin/l1_loss": lin_l1_loss,
|
|
"lin/bce_loss": lin_bce_loss,
|
|
"done": done_loss,
|
|
"attn": attn_loss,
|
|
}
|
|
|
|
return losses
|