ParakeetRebeccaRosario/parakeet/models/clarinet.py

158 lines
7.2 KiB
Python

import paddle
from paddle import nn
from paddle.nn import functional as F
from paddle import distribution as D
from parakeet.models.wavenet import WaveNet, UpsampleNet, crop
__all__ = ["Clarinet"]
class ParallelWaveNet(nn.LayerList):
def __init__(self, n_loops, n_layers, residual_channels, condition_dim,
filter_size):
"""ParallelWaveNet, an inverse autoregressive flow model, it contains several flows(WaveNets).
Args:
n_loops (List[int]): `n_loop` for each flow.
n_layers (List[int]): `n_layer` for each flow.
residual_channels (int): `residual_channels` for every flow.
condition_dim (int): `condition_dim` for every flow.
filter_size (int): `filter_size` for every flow.
"""
super(ParallelWaveNet, self).__init__()
for n_loop, n_layer in zip(n_loops, n_layers):
# teacher's log_scale_min does not matter herem, -100 is a dummy value
self.append(
WaveNet(n_loop, n_layer, residual_channels, 3, condition_dim,
filter_size, "mog", -100.0))
def forward(self, z, condition=None):
"""Transform a random noise sampled from a standard Gaussian distribution into sample from the target distribution. And output the mean and log standard deviation of the output distribution.
Args:
z (Variable): shape(B, T), random noise sampled from a standard gaussian disribution.
condition (Variable, optional): shape(B, F, T), dtype float, the upsampled condition. Defaults to None.
Returns:
(z, out_mu, out_log_std)
z (Variable): shape(B, T), dtype float, transformed noise, it is the synthesized waveform.
out_mu (Variable): shape(B, T), dtype float, means of the output distributions.
out_log_std (Variable): shape(B, T), dtype float, log standard deviations of the output distributions.
"""
for i, flow in enumerate(self):
theta = flow(z, condition) # w, mu, log_std [0: T]
w, mu, log_std = paddle.chunk(theta, 3, axis=-1) # (B, T, 1) for each
mu = paddle.squeeze(mu, -1) #[0: T]
log_std = paddle.squeeze(log_std, -1) #[0: T]
z = z * paddle.exp(log_std) + mu #[0: T]
if i == 0:
out_mu = mu
out_log_std = log_std
else:
out_mu = out_mu * paddle.exp(log_std) + mu
out_log_std += log_std
return z, out_mu, out_log_std
# Gaussian IAF model
class Clarinet(nn.Layer):
def __init__(self, encoder, teacher, student, stft,
min_log_scale=-6.0, lmd=4.0):
"""Clarinet model. Conditional Parallel WaveNet.
Args:
encoder (UpsampleNet): an UpsampleNet to upsample mel spectrogram.
teacher (WaveNet): a WaveNet, the teacher.
student (ParallelWaveNet): a ParallelWaveNet model, the student.
stft (STFT): a STFT model to perform differentiable stft transform.
min_log_scale (float, optional): used only for computing loss, the minimal value of log standard deviation of the output distribution of both the teacher and the student . Defaults to -6.0.
lmd (float, optional): weight for stft loss. Defaults to 4.0.
"""
super(Clarinet, self).__init__()
self.encoder = encoder
self.teacher = teacher
self.student = student
self.stft = stft
self.lmd = lmd
self.min_log_scale = min_log_scale
def forward(self, audio, mel, audio_start, clip_kl=True):
"""Compute loss of Clarinet model.
Args:
audio (Variable): shape(B, T_audio), dtype flaot32, ground truth waveform.
mel (Variable): shape(B, F, T_mel), dtype flaot32, condition(mel spectrogram here).
audio_start (Variable): shape(B, ), dtype int64, audio starts positions.
clip_kl (bool, optional): whether to clip kl_loss by maximum=100. Defaults to True.
Returns:
Dict(str, Variable)
loss (Variable): shape(1, ), dtype flaot32, total loss.
kl (Variable): shape(1, ), dtype flaot32, kl divergence between the teacher's output distribution and student's output distribution.
regularization (Variable): shape(1, ), dtype flaot32, a regularization term of the KL divergence.
spectrogram_frame_loss (Variable): shape(1, ), dytpe: float, stft loss, the L1-distance of the magnitudes of the spectrograms of the ground truth waveform and synthesized waveform.
"""
batch_size, audio_length = audio.shape # audio clip's length
z = paddle.randn(audio.shape)
condition = self.encoder(mel) # (B, C, T)
condition_slice = crop(condition, audio_start, audio_length)
x, s_means, s_scales = self.student(z, condition_slice) # all [0: T]
s_means = s_means[:, 1:] # (B, T-1), time steps [1: T]
s_scales = s_scales[:, 1:] # (B, T-1), time steps [1: T]
s_clipped_scales = paddle.clip(s_scales, self.min_log_scale, 100.)
# teacher outputs single gaussian
y = self.teacher(x[:, :-1], condition_slice[:, :, 1:])
_, t_means, t_scales = paddle.chunk(y, 3, axis=-1) # time steps [1: T]
t_means = paddle.squeeze(t_means, [-1]) # (B, T-1), time steps [1: T]
t_scales = paddle.squeeze(t_scales, [-1]) # (B, T-1), time steps [1: T]
t_clipped_scales = paddle.clip(t_scales, self.min_log_scale, 100.)
s_distribution = D.Normal(s_means, paddle.exp(s_clipped_scales))
t_distribution = D.Normal(t_means, paddle.exp(t_clipped_scales))
# kl divergence loss, so we only need to sample once? no MC
kl = s_distribution.kl_divergence(t_distribution)
if clip_kl:
kl = paddle.clip(kl, -100., 10.)
# context size dropped
kl = paddle.reduce_mean(kl[:, self.teacher.context_size:])
# major diff here
regularization = F.mse_loss(t_scales[:, self.teacher.context_size:],
s_scales[:, self.teacher.context_size:])
# introduce information from real target
spectrogram_frame_loss = F.mse_loss(
self.stft.magnitude(audio), self.stft.magnitude(x))
loss = kl + self.lmd * regularization + spectrogram_frame_loss
loss_dict = {
"loss": loss,
"kl_divergence": kl,
"regularization": regularization,
"stft_loss": spectrogram_frame_loss
}
return loss_dict
@paddle.no_grad()
def synthesis(self, mel):
"""Synthesize waveform using the encoder and the student network.
Args:
mel (Variable): shape(B, F, T_mel), the condition(mel spectrogram here).
Returns:
Variable: shape(B, T_audio), the synthesized waveform. (T_audio = T_mel * upscale_factor, where upscale_factor is the `upscale_factor` of the encoder.)
"""
condition = self.encoder(mel)
samples_shape = (condition.shape[0], condition.shape[-1])
z = paddle.randn(samples_shape)
x, s_means, s_scales = self.student(z, condition)
return x
# TODO(chenfeiyu): ClariNetLoss