import paddle
from paddle import nn
from paddle.nn import functional as F
from paddle import distribution as D
from parakeet.models.wavenet import WaveNet, UpsampleNet, crop
__all__ = ["Clarinet"]


class ParallelWaveNet(nn.LayerList):
    def __init__(self, n_loops, n_layers, residual_channels, condition_dim,
                 filter_size):
        """ParallelWaveNet, an inverse autoregressive flow model built by
        stacking several WaveNet flows.

        Args:
            n_loops (List[int]): `n_loop` for each flow.
            n_layers (List[int]): `n_layer` for each flow.
            residual_channels (int): `residual_channels` shared by every flow.
            condition_dim (int): `condition_dim` shared by every flow.
            filter_size (int): `filter_size` shared by every flow.
        """
        super(ParallelWaveNet, self).__init__()
        # The teacher's log_scale_min does not matter here; -100 is only a
        # dummy value passed to each student flow.
        for loop_count, layer_count in zip(n_loops, n_layers):
            flow = WaveNet(loop_count, layer_count, residual_channels, 3,
                           condition_dim, filter_size, "mog", -100.0)
            self.append(flow)

    def forward(self, z, condition=None):
        """Transform random noise sampled from a standard Gaussian distribution
        into a sample from the target distribution, and also output the mean
        and log standard deviation of the resulting output distribution.

        Args:
            z (Variable): shape(B, T), random noise sampled from a standard gaussian disribution.
            condition (Variable, optional): shape(B, F, T), dtype float, the upsampled condition. Defaults to None.

        Returns:
            (z, out_mu, out_log_std)
            z (Variable): shape(B, T), dtype float, transformed noise, it is the synthesized waveform.
            out_mu (Variable): shape(B, T), dtype float, means of the output distributions.
            out_log_std (Variable): shape(B, T), dtype float, log standard deviations of the output distributions.
        """
        for flow_index, flow in enumerate(self):
            # Each flow predicts (w, mu, log_std) per time step; the mixture
            # weight w is discarded for this single-Gaussian IAF.
            params = flow(z, condition)  # [0: T]
            _, mean, log_scale = paddle.chunk(params, 3, axis=-1)  # (B, T, 1) each
            mean = paddle.squeeze(mean, -1)  # [0: T]
            log_scale = paddle.squeeze(log_scale, -1)  # [0: T]

            # Apply this flow's affine transform to the running sample.
            z = z * paddle.exp(log_scale) + mean  # [0: T]

            # Fold this flow's affine parameters into the composed transform.
            if flow_index == 0:
                out_mu = mean
                out_log_std = log_scale
            else:
                out_mu = out_mu * paddle.exp(log_scale) + mean
                out_log_std = out_log_std + log_scale

        return z, out_mu, out_log_std


# Gaussian IAF model
class Clarinet(nn.Layer):
    def __init__(self, encoder, teacher, student, stft,
                 min_log_scale=-6.0, lmd=4.0):
        """Clarinet model. Conditional Parallel WaveNet.

        Args:
            encoder (UpsampleNet): an UpsampleNet to upsample mel spectrogram.
            teacher (WaveNet): a WaveNet, the teacher.
            student (ParallelWaveNet): a ParallelWaveNet model, the student.
            stft (STFT): a STFT model to perform differentiable stft transform.
            min_log_scale (float, optional): used only for computing loss, the minimal value of log standard deviation of the output distribution of both the teacher and the student. Defaults to -6.0.
            lmd (float, optional): weight for stft loss. Defaults to 4.0.
        """
        super(Clarinet, self).__init__()
        self.encoder = encoder
        self.teacher = teacher
        self.student = student
        self.stft = stft

        self.lmd = lmd
        self.min_log_scale = min_log_scale

    def forward(self, audio, mel, audio_start, clip_kl=True):
        """Compute loss of Clarinet model.

        Args:
            audio (Variable): shape(B, T_audio), dtype float32, ground truth waveform.
            mel (Variable): shape(B, F, T_mel), dtype float32, condition (mel spectrogram here).
            audio_start (Variable): shape(B, ), dtype int64, audio start positions.
            clip_kl (bool, optional): whether to clip the elementwise kl
                divergence into the range [-100, 10] before averaging.
                Defaults to True.

        Returns:
            Dict(str, Variable)
            loss (Variable): shape(1, ), dtype float32, total loss.
            kl (Variable): shape(1, ), dtype float32, kl divergence between the teacher's output distribution and student's output distribution.
            regularization (Variable): shape(1, ), dtype float32, a regularization term of the KL divergence.
            spectrogram_frame_loss (Variable): shape(1, ), dtype float32, stft loss, the L1-distance of the magnitudes of the spectrograms of the ground truth waveform and synthesized waveform.
        """
        batch_size, audio_length = audio.shape  # audio clip's length

        # Draw the latent noise the student transforms into a waveform.
        z = paddle.randn(audio.shape)
        condition = self.encoder(mel)  # (B, C, T)
        condition_slice = crop(condition, audio_start, audio_length)

        x, s_means, s_scales = self.student(z, condition_slice)  # all [0: T]
        s_means = s_means[:, 1:]  # (B, T-1), time steps [1: T]
        s_scales = s_scales[:, 1:]  # (B, T-1), time steps [1: T]
        # Clip log-std from below so the loss stays numerically stable.
        s_clipped_scales = paddle.clip(s_scales, self.min_log_scale, 100.)

        # teacher outputs single gaussian per time step
        y = self.teacher(x[:, :-1], condition_slice[:, :, 1:])
        _, t_means, t_scales = paddle.chunk(y, 3, axis=-1)  # time steps [1: T]
        t_means = paddle.squeeze(t_means, [-1])  # (B, T-1), time steps [1: T]
        t_scales = paddle.squeeze(t_scales, [-1])  # (B, T-1), time steps [1: T]
        t_clipped_scales = paddle.clip(t_scales, self.min_log_scale, 100.)

        s_distribution = D.Normal(s_means, paddle.exp(s_clipped_scales))
        t_distribution = D.Normal(t_means, paddle.exp(t_clipped_scales))

        # kl divergence loss: closed-form KL between two Gaussians, so a
        # single noise sample suffices — no MC estimate over z needed.
        kl = s_distribution.kl_divergence(t_distribution)
        if clip_kl:
            kl = paddle.clip(kl, -100., 10.)
        # Drop the teacher's receptive-field warm-up region from the loss.
        # BUGFIX: `paddle.reduce_mean` is the legacy 1.x/fluid API name and is
        # absent from the paddle 2.x namespace used throughout this file;
        # `paddle.mean` is the 2.x equivalent.
        kl = paddle.mean(kl[:, self.teacher.context_size:])
        # major diff here: also pull the student's (unclipped) log-stds
        # towards the teacher's as a regularizer
        regularization = F.mse_loss(t_scales[:, self.teacher.context_size:],
                                    s_scales[:, self.teacher.context_size:])

        # introduce information from real target
        spectrogram_frame_loss = F.mse_loss(
            self.stft.magnitude(audio), self.stft.magnitude(x))
        loss = kl + self.lmd * regularization + spectrogram_frame_loss
        loss_dict = {
            "loss": loss,
            "kl_divergence": kl,
            "regularization": regularization,
            "stft_loss": spectrogram_frame_loss
        }
        return loss_dict

    @paddle.no_grad()
    def synthesis(self, mel):
        """Synthesize waveform using the encoder and the student network.

        Args:
            mel (Variable): shape(B, F, T_mel), the condition (mel spectrogram here).

        Returns:
            Variable: shape(B, T_audio), the synthesized waveform. (T_audio = T_mel * upscale_factor, where upscale_factor is the `upscale_factor` of the encoder.)
        """
        condition = self.encoder(mel)
        # One noise sample per upsampled condition frame.
        samples_shape = (condition.shape[0], condition.shape[-1])
        z = paddle.randn(samples_shape)
        # The student's output means/log-stds are not needed for synthesis.
        x, _, _ = self.student(z, condition)
        return x


# TODO(chenfeiyu): ClariNetLoss