2020-10-10 15:51:54 +08:00
import paddle
from paddle import nn
from paddle . nn import functional as F
from paddle import distribution as D
from parakeet . models . wavenet import WaveNet , UpsampleNet , crop
2020-10-16 13:51:56 +08:00
# Public names exported by `from <module> import *`. Each entry must be the
# exact attribute name (no surrounding whitespace), otherwise the star-import
# raises AttributeError.
__all__ = ["Clarinet"]
2020-10-10 15:51:54 +08:00
class ParallelWaveNet(nn.LayerList):
    """An inverse autoregressive flow made of several stacked flows (WaveNets).

    The layer list itself holds the flows; ``forward`` applies them in order.
    """

    def __init__(self, n_loops, n_layers, residual_channels, condition_dim,
                 filter_size):
        """Build one WaveNet flow per ``(n_loop, n_layer)`` pair.

        Args:
            n_loops (List[int]): `n_loop` for each flow.
            n_layers (List[int]): `n_layer` for each flow.
            residual_channels (int): `residual_channels` for every flow.
            condition_dim (int): `condition_dim` for every flow.
            filter_size (int): `filter_size` for every flow.

        Note:
            ``zip`` stops at the shorter of ``n_loops`` and ``n_layers``, so
            surplus entries in the longer list are silently ignored.
        """
        super(ParallelWaveNet, self).__init__()
        for n_loop, n_layer in zip(n_loops, n_layers):
            # log_scale_min does not matter here; -100 is a dummy value.
            # The literal 3 presumably is the flow's output size, matching the
            # 3-way chunk (w, mu, log_std) done in forward() -- TODO confirm
            # against WaveNet's signature.
            # NOTE(review): the string literal is passed through verbatim,
            # including its padding spaces; confirm it matches a loss type
            # that WaveNet accepts.
            self.append(
                WaveNet(n_loop, n_layer, residual_channels, 3, condition_dim,
                        filter_size, " mog ", -100.0))

    def forward(self, z, condition=None):
        """Transform standard Gaussian noise into a sample from the target
        distribution, and also return the mean and log standard deviation of
        the resulting output distribution.

        Args:
            z (Variable): shape(B, T), random noise sampled from a standard
                Gaussian distribution.
            condition (Variable, optional): shape(B, F, T), dtype float, the
                upsampled condition. Defaults to None.

        Returns:
            (z, out_mu, out_log_std)
            z (Variable): shape(B, T), dtype float, transformed noise, it is
                the synthesized waveform.
            out_mu (Variable): shape(B, T), dtype float, means of the output
                distributions.
            out_log_std (Variable): shape(B, T), dtype float, log standard
                deviations of the output distributions.

        Raises:
            UnboundLocalError: if the layer list contains no flow, because
                ``out_mu`` / ``out_log_std`` are only assigned inside the loop.
        """
        for i, flow in enumerate(self):
            theta = flow(z, condition)  # packed (w, mu, log_std), [0: T]
            # The first chunk (w) is not needed by the flow transform.
            _, mu, log_std = paddle.chunk(theta, 3, axis=-1)  # (B, T, 1) each
            mu = paddle.squeeze(mu, -1)  # [0: T]
            log_std = paddle.squeeze(log_std, -1)  # [0: T]

            # Affine transform of the running sample by this flow.
            z = z * paddle.exp(log_std) + mu  # [0: T]

            if i == 0:
                out_mu = mu
                out_log_std = log_std
            else:
                # Compose this flow's affine transform with the accumulated
                # one: means are rescaled then shifted, log-stds simply add.
                out_mu = out_mu * paddle.exp(log_std) + mu
                # Out-of-place add on purpose: after the first iteration
                # out_log_std is the very same object as that flow's log_std,
                # so an in-place update could overwrite it.
                out_log_std = out_log_std + log_std
        return z, out_mu, out_log_std
# Gaussian IAF model
class Clarinet(nn.Layer):
    """Clarinet model: a conditional Parallel WaveNet (student) trained
    against a WaveNet teacher, with an extra spectrogram loss.
    """

    def __init__(self, encoder, teacher, student, stft,
                 min_log_scale=-6.0, lmd=4.0):
        """Store the sub-networks and the loss hyper-parameters.

        Args:
            encoder (UpsampleNet): an UpsampleNet to upsample mel spectrogram.
            teacher (WaveNet): a WaveNet, the teacher.
            student (ParallelWaveNet): a ParallelWaveNet model, the student.
            stft (STFT): a STFT model to perform differentiable stft
                transform.
            min_log_scale (float, optional): used only for computing loss, the
                lower clipping bound applied to the log standard deviations of
                both the teacher and the student before they parameterize the
                distributions of the KL term. Defaults to -6.0.
            lmd (float, optional): weight of the regularization term in the
                total loss (see ``forward``). Defaults to 4.0.
        """
        super(Clarinet, self).__init__()
        self.encoder = encoder
        self.teacher = teacher
        self.student = student
        self.stft = stft
        self.lmd = lmd
        self.min_log_scale = min_log_scale

    def forward(self, audio, mel, audio_start, clip_kl=True):
        """Compute loss of Clarinet model.

        The total loss is ``kl + lmd * regularization + stft_loss``.

        Args:
            audio (Variable): shape(B, T_audio), dtype float32, ground truth
                waveform.
            mel (Variable): shape(B, F, T_mel), dtype float32, condition (mel
                spectrogram here).
            audio_start (Variable): shape(B, ), dtype int64, audio starts
                positions.
            clip_kl (bool, optional): whether to clip the element-wise KL
                divergence to the range [-100, 10] before averaging.
                Defaults to True.

        Returns:
            Dict(str, Variable): the keys are spelled exactly as in the dict
            literal at the end of this method (padding spaces included).
                " loss ": shape(1, ), dtype float32, total loss.
                " kl_divergence ": shape(1, ), dtype float32, mean KL
                    divergence from the student's output distribution to the
                    teacher's output distribution.
                " regularization ": shape(1, ), dtype float32, mean squared
                    error between the teacher's and the student's (unclipped)
                    log standard deviations.
                " stft_loss ": shape(1, ), dtype float32, mean squared error
                    between the STFT magnitudes of the ground truth waveform
                    and of the synthesized waveform.
        """
        # Unpacking also enforces that audio is 2-D; only the length is used.
        _, audio_length = audio.shape  # audio clip's length

        z = paddle.randn(audio.shape)
        condition = self.encoder(mel)  # (B, C, T)
        condition_slice = crop(condition, audio_start, audio_length)

        # Despite the names, *_scales are log standard deviations: they are
        # exponentiated before being handed to D.Normal below.
        x, s_means, s_scales = self.student(z, condition_slice)  # all [0: T]
        s_means = s_means[:, 1:]  # (B, T-1), time steps [1: T]
        s_scales = s_scales[:, 1:]  # (B, T-1), time steps [1: T]
        s_clipped_scales = paddle.clip(s_scales, self.min_log_scale, 100.)

        # teacher outputs single gaussian; it sees the student's samples
        # [0: T-1] and scores time steps [1: T].
        y = self.teacher(x[:, :-1], condition_slice[:, :, 1:])
        _, t_means, t_scales = paddle.chunk(y, 3, axis=-1)  # time steps [1: T]
        t_means = paddle.squeeze(t_means, [-1])  # (B, T-1), time steps [1: T]
        t_scales = paddle.squeeze(t_scales, [-1])  # (B, T-1), time steps [1: T]
        t_clipped_scales = paddle.clip(t_scales, self.min_log_scale, 100.)

        s_distribution = D.Normal(s_means, paddle.exp(s_clipped_scales))
        t_distribution = D.Normal(t_means, paddle.exp(t_clipped_scales))

        # The KL divergence is taken from the distribution objects directly,
        # so no Monte Carlo sampling is involved.
        kl = s_distribution.kl_divergence(t_distribution)
        if clip_kl:
            # NOTE(review): an earlier revision of the docstring advertised a
            # maximum of 100; the code clips at 10. Confirm which is intended.
            kl = paddle.clip(kl, -100., 10.)

        # The first `context_size` steps are dropped from the KL and the
        # regularization terms.
        context_size = self.teacher.context_size
        # paddle.mean replaces the legacy paddle.reduce_mean alias.
        kl = paddle.mean(kl[:, context_size:])

        # major diff here: regularize on the unclipped log standard deviations
        regularization = F.mse_loss(t_scales[:, context_size:],
                                    s_scales[:, context_size:])

        # introduce information from real target
        spectrogram_frame_loss = F.mse_loss(
            self.stft.magnitude(audio), self.stft.magnitude(x))

        loss = kl + self.lmd * regularization + spectrogram_frame_loss
        loss_dict = {
            " loss ": loss,
            " kl_divergence ": kl,
            " regularization ": regularization,
            " stft_loss ": spectrogram_frame_loss
        }
        return loss_dict

    @paddle.no_grad()
    def synthesis(self, mel):
        """Synthesize waveform using the encoder and the student network.

        Runs under ``paddle.no_grad()``; the teacher is not involved.

        Args:
            mel (Variable): shape(B, F, T_mel), the condition (mel spectrogram
                here).

        Returns:
            Variable: shape(B, T_audio), the synthesized waveform.
                (T_audio = T_mel * upscale_factor, where upscale_factor is the
                `upscale_factor` of the encoder.)
        """
        condition = self.encoder(mel)
        # One noise value per (batch item, upsampled time step).
        samples_shape = (condition.shape[0], condition.shape[-1])
        z = paddle.randn(samples_shape)
        # Only the transformed noise is needed; the distribution parameters
        # returned by the student are discarded.
        x, _, _ = self.student(z, condition)
        return x
# TODO(chenfeiyu): ClariNetLoss