update docstring for waveflow
This commit is contained in:
parent
f2a35a17d4
commit
b6efb43990
|
@ -80,10 +80,6 @@ class Experiment(ExperimentBase):
|
|||
z, log_det_jocobian = self.model(wav, mel)
|
||||
return z, log_det_jocobian
|
||||
|
||||
def compute_losses(self, outputs):
|
||||
loss = self.criterion(outputs)
|
||||
return loss
|
||||
|
||||
def train_batch(self):
|
||||
start = time.time()
|
||||
batch = self.read_batch()
|
||||
|
@ -92,8 +88,8 @@ class Experiment(ExperimentBase):
|
|||
self.model.train()
|
||||
self.optimizer.clear_grad()
|
||||
mel, wav = batch
|
||||
outputs = self.compute_outputs(mel, wav)
|
||||
loss = self.compute_losses(outputs)
|
||||
z, log_det_jocobian = self.compute_outputs(mel, wav)
|
||||
loss = self.criterion(z, log_det_jocobian)
|
||||
loss.backward()
|
||||
self.optimizer.step()
|
||||
iteration_time = time.time() - start
|
||||
|
@ -112,8 +108,8 @@ class Experiment(ExperimentBase):
|
|||
valid_iterator = iter(self.valid_loader)
|
||||
valid_losses = []
|
||||
mel, wav = next(valid_iterator)
|
||||
outputs = self.compute_outputs(mel, wav)
|
||||
loss = self.compute_losses(outputs)
|
||||
z, log_det_jocobian = self.compute_outputs(mel, wav)
|
||||
loss = self.criterion(z, log_det_jocobian)
|
||||
valid_losses.append(float(loss))
|
||||
valid_loss = np.mean(valid_losses)
|
||||
self.visualizer.add_scalar("valid/loss", valid_loss, global_step=self.iteration)
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
import math
|
||||
import numpy as np
|
||||
from typing import List, Union
|
||||
from typing import List, Union, Tuple
|
||||
import paddle
|
||||
from paddle import nn
|
||||
from paddle.nn import functional as F
|
||||
|
@ -9,27 +9,56 @@ from paddle.nn import initializer as I
|
|||
from parakeet.utils import checkpoint
|
||||
from parakeet.modules import geometry as geo
|
||||
|
||||
__all__ = ["UpsampleNet", "WaveFlow", "ConditionalWaveFlow", "WaveFlowLoss"]
|
||||
__all__ = ["WaveFlow", "ConditionalWaveFlow", "WaveFlowLoss"]
|
||||
|
||||
def fold(x, n_group):
|
||||
"""Fold audio or spectrogram's temporal dimension in to groups.
|
||||
r"""Fold audio or spectrogram's temporal dimension in to groups.
|
||||
|
||||
Args:
|
||||
x (Tensor): shape(*, time_steps), the input tensor
|
||||
n_group (int): the size of a group.
|
||||
Parameters
|
||||
----------
|
||||
x : Tensor [shape=(\*, time_steps)
|
||||
The input tensor.
|
||||
|
||||
n_group : int
|
||||
The size of a group.
|
||||
|
||||
Returns:
|
||||
Tensor: shape(*, time_steps // n_group, group), folded tensor.
|
||||
Returns
|
||||
---------
|
||||
Tensor : [shape=(`*, time_steps // n_group, group)]
|
||||
Folded tensor.
|
||||
"""
|
||||
*spatial_shape, time_steps = x.shape
|
||||
new_shape = spatial_shape + [time_steps // n_group, n_group]
|
||||
return paddle.reshape(x, new_shape)
|
||||
|
||||
class UpsampleNet(nn.LayerList):
|
||||
"""
|
||||
Layer to upsample mel spectrogram to the same temporal resolution with
|
||||
the corresponding waveform. It consists of several conv2dtranspose layers
|
||||
which perform de convolution on mel and time dimension.
|
||||
"""Layer to upsample mel spectrogram to the same temporal resolution with
|
||||
the corresponding waveform.
|
||||
|
||||
It consists of several conv2dtranspose layers which perform deconvolution
|
||||
on mel and time dimension.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
upscale_factors : List[int], optional
|
||||
Time upsampling factors for each Conv2DTranspose Layer.
|
||||
|
||||
The ``UpsampleNet`` contains ``len(upscale_factor)`` Conv2DTranspose
|
||||
Layers. Each upscale_factor is used as the ``stride`` for the
|
||||
corresponding Conv2DTranspose. Defaults to [16, 16], this the default
|
||||
upsampling factor is 256.
|
||||
|
||||
Notes
|
||||
------
|
||||
``np.prod(upscale_factors)`` should equals the ``hop_length`` of the stft
|
||||
transformation used to extract spectrogram features from audio.
|
||||
|
||||
For example, ``16 * 16 = 256``, then the spectrogram extracted with a stft
|
||||
transformation whose ``hop_length`` equals 256 is suitable.
|
||||
|
||||
See Also
|
||||
---------
|
||||
``librosa.core.stft``
|
||||
"""
|
||||
def __init__(self, upsample_factors):
|
||||
super(UpsampleNet, self).__init__()
|
||||
|
@ -49,17 +78,25 @@ class UpsampleNet(nn.LayerList):
|
|||
self.upsample_factors = upsample_factors
|
||||
|
||||
def forward(self, x, trim_conv_artifact=False):
|
||||
"""
|
||||
Args:
|
||||
x (Tensor): shape(batch_size, input_channels, time_steps), the input
|
||||
spectrogram.
|
||||
trim_conv_artifact (bool, optional): trim deconvolution artifact at
|
||||
each layer. Defaults to False.
|
||||
r"""Forward pass of the ``UpsampleNet``.
|
||||
|
||||
Parameters
|
||||
-----------
|
||||
x : Tensor [shape=(batch_size, input_channels, time_steps)]
|
||||
The input spectrogram.
|
||||
|
||||
trim_conv_artifact : bool, optional
|
||||
Trim deconvolution artifact at each layer. Defaults to False.
|
||||
|
||||
Returns:
|
||||
Tensor: shape(batch_size, input_channels, time_steps * upsample_factor).
|
||||
If trim_conv_artifact is True, the output time steps is less
|
||||
than time_steps * upsample_factors.
|
||||
Returns
|
||||
--------
|
||||
Tensor: [shape=(batch_size, input_channels, time_steps \* upsample_factor)]
|
||||
The upsampled spectrogram.
|
||||
|
||||
Notes
|
||||
--------
|
||||
If trim_conv_artifact is ``True``, the output time steps is less
|
||||
than ``time_steps \* upsample_factors``.
|
||||
"""
|
||||
x = paddle.unsqueeze(x, 1) #(B, C, T) -> (B, 1, C, T)
|
||||
for layer in self:
|
||||
|
@ -72,11 +109,27 @@ class UpsampleNet(nn.LayerList):
|
|||
return x
|
||||
|
||||
|
||||
#TODO write doc
|
||||
class ResidualBlock(nn.Layer):
|
||||
"""
|
||||
ResidualBlock, the basic unit of ResidualNet. It has a conv2d layer, which
|
||||
has causal padding in height dimension and same paddign in width dimension.
|
||||
It also has projection for the condition and output.
|
||||
"""ResidualBlock, the basic unit of ResidualNet used in WaveFlow.
|
||||
|
||||
It has a conv2d layer, which has causal padding in height dimension and
|
||||
same paddign in width dimension. It also has projection for the condition
|
||||
and output.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
channels : int
|
||||
Feature size of the input.
|
||||
|
||||
cond_channels : int
|
||||
Featuer size of the condition.
|
||||
|
||||
kernel_size : Tuple[int]
|
||||
Kernel size of the Convolution2d applied to the input.
|
||||
|
||||
dilations : int
|
||||
Dilations of the Convolution2d applied to the input.
|
||||
"""
|
||||
def __init__(self, channels, cond_channels, kernel_size, dilations):
|
||||
super(ResidualBlock, self).__init__()
|
||||
|
@ -113,14 +166,21 @@ class ResidualBlock(nn.Layer):
|
|||
def forward(self, x, condition):
|
||||
"""Compute output for a whole folded sequence.
|
||||
|
||||
Args:
|
||||
x (Tensor): shape(batch_size, channel, height, width), the input.
|
||||
condition (Tensor): shape(batch_size, condition_channel, height, width),
|
||||
the local condition.
|
||||
Parameters
|
||||
----------
|
||||
x : Tensor [shape=(batch_size, channel, height, width)]
|
||||
The input.
|
||||
|
||||
condition : Tensor [shape=(batch_size, condition_channel, height, width)]
|
||||
The local condition.
|
||||
|
||||
Returns:
|
||||
res (Tensor): shape(batch_size, channel, height, width), the residual output.
|
||||
res (Tensor): shape(batch_size, channel, height, width), the skip output.
|
||||
Returns
|
||||
-------
|
||||
res : Tensor [shape=(batch_size, channel, height, width)]
|
||||
The residual output.
|
||||
|
||||
skip : Tensor [shape=(batch_size, channel, height, width)]
|
||||
The skip output.
|
||||
"""
|
||||
x_in = x
|
||||
x = self.conv(x)
|
||||
|
@ -131,10 +191,12 @@ class ResidualBlock(nn.Layer):
|
|||
|
||||
x = self.out_proj(x)
|
||||
res, skip = paddle.chunk(x, 2, axis=1)
|
||||
return x_in + res, skip
|
||||
res = x_in + res
|
||||
return res, skip
|
||||
|
||||
def start_sequence(self):
|
||||
"""Prepare the layer for incremental computation of causal convolution. Reset the buffer for causal convolution.
|
||||
"""Prepare the layer for incremental computation of causal
|
||||
convolution. Reset the buffer for causal convolution.
|
||||
|
||||
Raises:
|
||||
ValueError: If not in evaluation mode.
|
||||
|
@ -155,13 +217,21 @@ class ResidualBlock(nn.Layer):
|
|||
def add_input(self, x_row, condition_row):
|
||||
"""Compute the output for a row and update the buffer.
|
||||
|
||||
Args:
|
||||
x_row (Tensor): shape(batch_size, channel, 1, width), a row of the input.
|
||||
condition_row (Tensor): shape(batch_size, condition_channel, 1, width), a row of the input.
|
||||
Parameters
|
||||
----------
|
||||
x_row : Tensor [shape=(batch_size, channel, 1, width)]
|
||||
A row of the input.
|
||||
|
||||
condition_row : Tensor [shape=(batch_size, condition_channel, 1, width)]
|
||||
A row of the condition.
|
||||
|
||||
Returns:
|
||||
res (Tensor): shape(batch_size, channel, 1, width), the residual output.
|
||||
res (Tensor): shape(batch_size, channel, 1, width), the skip output.
|
||||
Returns
|
||||
-------
|
||||
res : Tensor [shape=(batch_size, channel, 1, width)]
|
||||
A row of the the residual output.
|
||||
|
||||
res : Tensor [shape=(batch_size, channel, 1, width)]
|
||||
A row of the skip output.
|
||||
"""
|
||||
x_row_in = x_row
|
||||
if self._conv_buffer is None:
|
||||
|
@ -182,7 +252,8 @@ class ResidualBlock(nn.Layer):
|
|||
|
||||
x_row = self.out_proj(x_row)
|
||||
res, skip = paddle.chunk(x_row, 2, axis=1)
|
||||
return x_row_in + res, skip
|
||||
res = x_row_in + res
|
||||
return res, skip
|
||||
|
||||
def _init_buffer(self, input):
|
||||
batch_size, channels, _, width = input.shape
|
||||
|
@ -195,11 +266,36 @@ class ResidualBlock(nn.Layer):
|
|||
|
||||
|
||||
class ResidualNet(nn.LayerList):
|
||||
"""A stack of several ResidualBlocks. It merges condition at each layer.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
n_layer : int
|
||||
Number of ResidualBlocks in the ResidualNet.
|
||||
|
||||
residual_channels : int
|
||||
Feature size of each ResidualBlocks.
|
||||
|
||||
condition_channels : int
|
||||
Feature size of the condition.
|
||||
|
||||
kernel_size : Tuple[int]
|
||||
Kernel size of each ResidualBlock.
|
||||
|
||||
dilations_h : List[int]
|
||||
Dilation in height dimension of every ResidualBlock.
|
||||
|
||||
Raises
|
||||
------
|
||||
ValueError
|
||||
If the length of dilations_h does not equals n_layers.
|
||||
"""
|
||||
A stack of several ResidualBlocks. It merges condition at each layer. All
|
||||
skip outputs are collected.
|
||||
"""
|
||||
def __init__(self, n_layer, residual_channels, condition_channels, kernel_size, dilations_h):
|
||||
def __init__(self,
|
||||
n_layer: int,
|
||||
residual_channels: int,
|
||||
condition_channels: int,
|
||||
kernel_size: Tuple[int],
|
||||
dilations_h: List[int]):
|
||||
if len(dilations_h) != n_layer:
|
||||
raise ValueError("number of dilations_h should equals num of layers")
|
||||
super(ResidualNet, self).__init__()
|
||||
|
@ -211,14 +307,18 @@ class ResidualNet(nn.LayerList):
|
|||
def forward(self, x, condition):
|
||||
"""Comput the output of given the input and the condition.
|
||||
|
||||
Args:
|
||||
x (Tensor): shape(batch_size, channel, height, width), the input.
|
||||
condition (Tensor): shape(batch_size, condition_channel, height, width),
|
||||
the local condition.
|
||||
Parameters
|
||||
-----------
|
||||
x : Tensor [shape=(batch_size, channel, height, width)]
|
||||
The input.
|
||||
|
||||
condition : Tensor [shape=(batch_size, condition_channel, height, width)]
|
||||
The local condition.
|
||||
|
||||
Returns:
|
||||
Tensor: shape(batch_size, channel, height, width), the output, which
|
||||
is an aggregation of all the skip outputs.
|
||||
Returns
|
||||
--------
|
||||
Tensor : [shape=(batch_size, channel, height, width)]
|
||||
The output, which is an aggregation of all the skip outputs.
|
||||
"""
|
||||
skip_connections = []
|
||||
for layer in self:
|
||||
|
@ -228,20 +328,29 @@ class ResidualNet(nn.LayerList):
|
|||
return out
|
||||
|
||||
def start_sequence(self):
|
||||
"""Prepare the layer for incremental computation."""
|
||||
"""Prepare the layer for incremental computation.
|
||||
"""
|
||||
for layer in self:
|
||||
layer.start_sequence()
|
||||
|
||||
def add_input(self, x_row, condition_row):
|
||||
"""Compute the output for a row and update the buffer.
|
||||
"""Compute the output for a row and update the buffers.
|
||||
|
||||
Args:
|
||||
x_row (Tensor): shape(batch_size, channel, 1, width), a row of the input.
|
||||
condition_row (Tensor): shape(batch_size, condition_channel, 1, width), a row of the input.
|
||||
Parameters
|
||||
----------
|
||||
x_row : Tensor [shape=(batch_size, channel, 1, width)]
|
||||
A row of the input.
|
||||
|
||||
condition_row : Tensor [shape=(batch_size, condition_channel, 1, width)]
|
||||
A row of the condition.
|
||||
|
||||
Returns:
|
||||
Tensor: shape(batch_size, channel, 1, width), the output, which is
|
||||
an aggregation of all the skip outputs.
|
||||
Returns
|
||||
-------
|
||||
res : Tensor [shape=(batch_size, channel, 1, width)]
|
||||
A row of the the residual output.
|
||||
|
||||
res : Tensor [shape=(batch_size, channel, 1, width)]
|
||||
A row of the skip output.
|
||||
"""
|
||||
skip_connections = []
|
||||
for layer in self:
|
||||
|
@ -252,12 +361,29 @@ class ResidualNet(nn.LayerList):
|
|||
|
||||
|
||||
class Flow(nn.Layer):
|
||||
"""
|
||||
A bijection (Reversable layer) that transform a density of latent variables
|
||||
p(Z) into a complex data distribution p(X).
|
||||
"""A bijection (Reversable layer) that transform a density of latent
|
||||
variables p(Z) into a complex data distribution p(X).
|
||||
|
||||
It's a auto regressive flow. The `forward` method implements the probability
|
||||
density estimation. The `inverse` method implements the sampling.
|
||||
It's an auto regressive flow. The `forward` method implements the
|
||||
probability density estimation. The `inverse` method implements the
|
||||
sampling.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
n_layers : int
|
||||
Number of ResidualBlocks in the Flow.
|
||||
|
||||
channels : int
|
||||
Feature size of the ResidualBlocks.
|
||||
|
||||
mel_bands : int
|
||||
Feature size of the mel spectrogram (mel bands).
|
||||
|
||||
kernel_size : Tuple[int]
|
||||
Kernel size of each ResisualBlocks in the Flow.
|
||||
|
||||
n_group : int
|
||||
Number of timesteps to the folded into a group.
|
||||
"""
|
||||
dilations_dict = {
|
||||
8: [1, 1, 1, 1, 1, 1, 1, 1],
|
||||
|
@ -301,18 +427,29 @@ class Flow(nn.Layer):
|
|||
return z_out
|
||||
|
||||
def forward(self, x, condition):
|
||||
"""Probability density estimation. It is done by inversely transform a sample
|
||||
from p(X) back into a sample from p(Z).
|
||||
"""Probability density estimation. It is done by inversely transform
|
||||
a sample from p(X) into a sample from p(Z).
|
||||
|
||||
Args:
|
||||
x (Tensor): shape(batch, 1, height, width), a input sample of the distribution p(X).
|
||||
condition (Tensor): shape(batch, condition_channel, height, width), the local condition.
|
||||
Parameters
|
||||
-----------
|
||||
x : Tensor [shape=(batch, 1, height, width)]
|
||||
A input sample of the distribution p(X).
|
||||
|
||||
condition : Tensor [shape=(batch, condition_channel, height, width)]
|
||||
The local condition.
|
||||
|
||||
Returns:
|
||||
(z, (logs, b))
|
||||
z (Tensor): shape(batch, 1, height, width), the transformed sample.
|
||||
logs (Tensor): shape(batch, 1, height - 1, width), the log scale of the inverse transformation.
|
||||
b (Tensor): shape(batch, 1, height - 1, width), the shift of the inverse transformation.
|
||||
Returns
|
||||
--------
|
||||
z (Tensor): shape(batch, 1, height, width), the transformed sample.
|
||||
|
||||
Tuple[Tensor, Tensor]
|
||||
The parameter of the transformation.
|
||||
|
||||
logs (Tensor): shape(batch, 1, height - 1, width), the log scale
|
||||
of the transformation from x to z.
|
||||
|
||||
b (Tensor): shape(batch, 1, height - 1, width), the shift of the
|
||||
transformation from x to z.
|
||||
"""
|
||||
# (B, C, H-1, W)
|
||||
logs, b = self._predict_parameters(
|
||||
|
@ -340,18 +477,30 @@ class Flow(nn.Layer):
|
|||
self.resnet.start_sequence()
|
||||
|
||||
def inverse(self, z, condition):
|
||||
"""Sampling from the the distrition p(X). It is done by sample form p(Z)
|
||||
and transform the sample. It is a auto regressive transformation.
|
||||
"""Sampling from the the distrition p(X). It is done by sample form
|
||||
p(Z) and transform the sample. It is a auto regressive transformation.
|
||||
|
||||
Args:
|
||||
z (Tensor): shape(batch, 1, height, width), a input sample of the distribution p(Z).
|
||||
condition (Tensor): shape(batch, condition_channel, height, width), the local condition.
|
||||
Parameters
|
||||
-----------
|
||||
z : Tensor [shape=(batch, 1, height, width)]
|
||||
A sample of the distribution p(Z).
|
||||
|
||||
condition : Tensor [shape=(batch, condition_channel, height, width)]
|
||||
The local condition.
|
||||
|
||||
Returns:
|
||||
(x, (logs, b))
|
||||
x (Tensor): shape(batch, 1, height, width), the transformed sample.
|
||||
logs (Tensor): shape(batch, 1, height - 1, width), the log scale of the inverse transformation.
|
||||
b (Tensor): shape(batch, 1, height - 1, width), the shift of the inverse transformation.
|
||||
Returns
|
||||
---------
|
||||
x : Tensor [shape=(batch, 1, height, width)]
|
||||
The transformed sample.
|
||||
|
||||
Tuple[Tensor, Tensor]
|
||||
The parameter of the transformation.
|
||||
|
||||
logs (Tensor): shape(batch, 1, height - 1, width), the log scale
|
||||
of the transformation from x to z.
|
||||
|
||||
b (Tensor): shape(batch, 1, height - 1, width), the shift of the
|
||||
transformation from x to z.
|
||||
"""
|
||||
z_0 = z[:, :, :1, :]
|
||||
x = []
|
||||
|
@ -377,7 +526,29 @@ class Flow(nn.Layer):
|
|||
|
||||
|
||||
class WaveFlow(nn.LayerList):
|
||||
"""An Deep Reversible layer that is composed of a stack of auto regressive flows.s"""
|
||||
"""An Deep Reversible layer that is composed of severel auto regressive
|
||||
flows.
|
||||
|
||||
Parameters
|
||||
-----------
|
||||
n_flows : int
|
||||
Number of flows in the WaveFlow model.
|
||||
|
||||
n_layers : int
|
||||
Number of ResidualBlocks in each Flow.
|
||||
|
||||
n_group : int
|
||||
Number of timesteps to fold as a group.
|
||||
|
||||
channels : int
|
||||
Feature size of each ResidualBlock.
|
||||
|
||||
mel_bands : int
|
||||
Feature size of mel spectrogram (mel bands).
|
||||
|
||||
kernel_size : Union[int, List[int]]
|
||||
Kernel size of the convolution layer in each ResidualBlock.
|
||||
"""
|
||||
def __init__(self, n_flows, n_layers, n_group, channels, mel_bands, kernel_size):
|
||||
if n_group % 2 or n_flows % 2:
|
||||
raise ValueError("number of flows and number of group must be even "
|
||||
|
@ -416,15 +587,25 @@ class WaveFlow(nn.LayerList):
|
|||
return x, condition
|
||||
|
||||
def forward(self, x, condition):
|
||||
"""Probability density estimation.
|
||||
"""Probability density estimation of random variable x given the
|
||||
condition.
|
||||
|
||||
Args:
|
||||
x (Tensor): shape(batch_size, time_steps), the audio.
|
||||
condition (Tensor): shape(batch_size, condition channel, time_steps), the local condition.
|
||||
Parameters
|
||||
-----------
|
||||
x : Tensor [shape=(batch_size, time_steps)]
|
||||
The audio.
|
||||
|
||||
condition : Tensor [shape=(batch_size, condition channel, time_steps)]
|
||||
The local condition (mel spectrogram here).
|
||||
|
||||
Returns:
|
||||
z: (Tensor): shape(batch_size, time_steps), the transformed sample.
|
||||
log_det_jacobian: (Tensor), shape(1,), the log determinant of the jacobian of (dz/dx).
|
||||
Returns
|
||||
--------
|
||||
z : Tensor [shape=(batch_size, time_steps)]
|
||||
The transformed random variable.
|
||||
|
||||
log_det_jacobian: Tensor [shape=(1,)]
|
||||
The log determinant of the jacobian of the transformation from x
|
||||
to z.
|
||||
"""
|
||||
# x: (B, T)
|
||||
# condition: (B, C, T) upsampled condition
|
||||
|
@ -451,15 +632,24 @@ class WaveFlow(nn.LayerList):
|
|||
return z, log_det_jacobian
|
||||
|
||||
def inverse(self, z, condition):
|
||||
"""Sampling from the the distrition p(X). It is done by sample form p(Z)
|
||||
and transform the sample. It is a auto regressive transformation.
|
||||
"""Sampling from the the distrition p(X).
|
||||
|
||||
It is done by sample a ``z`` form p(Z) and transform it into ``x``.
|
||||
Each Flow transform .. math:: `z_{i-1}` to .. math:: `z_{i}` in an
|
||||
autoregressive manner.
|
||||
|
||||
Args:
|
||||
z (Tensor): shape(batch, 1, time_steps), a input sample of the distribution p(Z).
|
||||
condition (Tensor): shape(batch, condition_channel, time_steps), the local condition.
|
||||
Parameters
|
||||
----------
|
||||
z : Tensor [shape=(batch, 1, time_steps]
|
||||
A sample of the distribution p(Z).
|
||||
|
||||
condition : Tensor [shape=(batch, condition_channel, time_steps)]
|
||||
The local condition.
|
||||
|
||||
Returns:
|
||||
x: (Tensor): shape(batch_size, time_steps), the transformed sample.
|
||||
Returns
|
||||
--------
|
||||
x : Tensor [shape=(batch_size, time_steps)]
|
||||
The transformed sample (audio here).
|
||||
"""
|
||||
|
||||
z, condition = self._trim(z, condition)
|
||||
|
@ -480,6 +670,31 @@ class WaveFlow(nn.LayerList):
|
|||
|
||||
|
||||
class ConditionalWaveFlow(nn.LayerList):
|
||||
"""ConditionalWaveFlow, a UpsampleNet with a WaveFlow model.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
upsample_factors : List[int]
|
||||
Upsample factors for the upsample net.
|
||||
|
||||
n_flows : int
|
||||
Number of flows in the WaveFlow model.
|
||||
|
||||
n_layers : int
|
||||
Number of ResidualBlocks in each Flow.
|
||||
|
||||
n_group : int
|
||||
Number of timesteps to fold as a group.
|
||||
|
||||
channels : int
|
||||
Feature size of each ResidualBlock.
|
||||
|
||||
n_mels : int
|
||||
Feature size of mel spectrogram (mel bands).
|
||||
|
||||
kernel_size : Union[int, List[int]]
|
||||
Kernel size of the convolution layer in each ResidualBlock.
|
||||
"""
|
||||
def __init__(self,
|
||||
upsample_factors: List[int],
|
||||
n_flows: int,
|
||||
|
@ -499,12 +714,44 @@ class ConditionalWaveFlow(nn.LayerList):
|
|||
kernel_size=kernel_size)
|
||||
|
||||
def forward(self, audio, mel):
|
||||
"""Compute the transformed random variable z (x to z) and the log of
|
||||
the determinant of the jacobian of the transformation from x to z.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
audio : Tensor [shape=(B, T)]
|
||||
The audio.
|
||||
|
||||
mel : Tensor [shape=(B, C_mel, T_mel)]
|
||||
The mel spectrogram.
|
||||
|
||||
Returns
|
||||
-------
|
||||
z : Tensor [shape=(B, T)]
|
||||
The inversely transformed random variable z (x to z)
|
||||
|
||||
log_det_jacobian: Tensor [shape=(1,)]
|
||||
the log of the determinant of the jacobian of the transformation
|
||||
from x to z.
|
||||
"""
|
||||
condition = self.encoder(mel)
|
||||
z, log_det_jacobian = self.decoder(audio, condition)
|
||||
return z, log_det_jacobian
|
||||
|
||||
@paddle.no_grad()
|
||||
def infer(self, mel):
|
||||
r"""Generate raw audio given mel spectrogram.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
mel : Tensor [shape=(B, C_mel, T_mel)]
|
||||
Mel spectrogram (in log-magnitude).
|
||||
|
||||
Returns
|
||||
-------
|
||||
Tensor : [shape=(B, T)]
|
||||
The synthesized audio, where``T <= T_mel \* upsample_factors``.
|
||||
"""
|
||||
condition = self.encoder(mel, trim_conv_artifact=True) #(B, C, T)
|
||||
batch_size, _, time_steps = condition.shape
|
||||
z = paddle.randn([batch_size, time_steps], dtype=mel.dtype)
|
||||
|
@ -513,6 +760,18 @@ class ConditionalWaveFlow(nn.LayerList):
|
|||
|
||||
@paddle.no_grad()
|
||||
def predict(self, mel):
|
||||
"""Generate raw audio given mel spectrogram.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
mel : np.ndarray [shape=(C_mel, T_mel)]
|
||||
Mel spectrogram of an utterance(in log-magnitude).
|
||||
|
||||
Returns
|
||||
-------
|
||||
np.ndarray [shape=(T,)]
|
||||
The synthesized audio.
|
||||
"""
|
||||
mel = paddle.to_tensor(mel)
|
||||
mel = paddle.unsqueeze(mel, 0)
|
||||
audio = self.infer(mel)
|
||||
|
@ -521,6 +780,21 @@ class ConditionalWaveFlow(nn.LayerList):
|
|||
|
||||
@classmethod
|
||||
def from_pretrained(cls, config, checkpoint_path):
|
||||
"""Build a ConditionalWaveFlow model from a pretrained model.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
config: yacs.config.CfgNode
|
||||
model configs
|
||||
|
||||
checkpoint_path: Path or str
|
||||
the path of pretrained model checkpoint, without extension name
|
||||
|
||||
Returns
|
||||
-------
|
||||
ConditionalWaveFlow
|
||||
The model built from pretrained result.
|
||||
"""
|
||||
model = cls(
|
||||
upsample_factors=config.model.upsample_factors,
|
||||
n_flows=config.model.n_flows,
|
||||
|
@ -534,14 +808,37 @@ class ConditionalWaveFlow(nn.LayerList):
|
|||
|
||||
|
||||
class WaveFlowLoss(nn.Layer):
|
||||
"""Criterion of a WaveFlow model.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
sigma : float
|
||||
The standard deviation of the gaussian noise used in WaveFlow, by
|
||||
default 1.0.
|
||||
"""
|
||||
def __init__(self, sigma=1.0):
|
||||
super(WaveFlowLoss, self).__init__()
|
||||
self.sigma = sigma
|
||||
self.const = 0.5 * np.log(2 * np.pi) + np.log(self.sigma)
|
||||
|
||||
def forward(self, model_output):
|
||||
z, log_det_jacobian = model_output
|
||||
def forward(self, z, log_det_jacobian):
|
||||
"""Compute the loss given the transformed random variable z and the
|
||||
log_det_jacobian of transformation from x to z.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
z : Tensor [shape=(B, T)]
|
||||
The transformed random variable (x to z).
|
||||
|
||||
log_det_jacobian : Tensor [shape=(1,)]
|
||||
The log of the determinant of the jacobian matrix of the
|
||||
transformation from x to z.
|
||||
|
||||
Returns
|
||||
-------
|
||||
Tensor [shape=(1,)]
|
||||
The loss.
|
||||
"""
|
||||
loss = paddle.sum(z * z) / (2 * self.sigma * self.sigma) - log_det_jacobian
|
||||
loss = loss / np.prod(z.shape)
|
||||
return loss + self.const
|
||||
|
|
|
@ -28,6 +28,7 @@ from parakeet.modules.conv import Conv1dCell
|
|||
from parakeet.modules.audio import quantize, dequantize, STFT
|
||||
from parakeet.utils import checkpoint, layer_tools
|
||||
|
||||
__all__ = ["WaveNet", "ConditionalWaveNet"]
|
||||
|
||||
def crop(x, audio_start, audio_length):
|
||||
"""Crop the upsampled condition to match audio_length.
|
||||
|
@ -285,21 +286,35 @@ class ResidualBlock(nn.Layer):
|
|||
|
||||
|
||||
class ResidualNet(nn.LayerList):
|
||||
"""The residual network in wavenet.
|
||||
|
||||
It consists of ``n_stack`` stacks, each of which consists of ``n_loop``
|
||||
ResidualBlocks.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
n_stack : int
|
||||
Number of stacks in the ``ResidualNet``.
|
||||
|
||||
n_loop : int
|
||||
Number of ResidualBlocks in a stack.
|
||||
|
||||
residual_channels : int
|
||||
Input feature size of each ``ResidualBlock``'s input.
|
||||
|
||||
condition_dim : int
|
||||
Feature size of the condition.
|
||||
|
||||
filter_size : int
|
||||
Kernel size of the internal ``Conv1dCell`` of each ``ResidualBlock``.
|
||||
|
||||
"""
|
||||
def __init__(self,
|
||||
n_stack: int,
|
||||
n_loop: int,
|
||||
residual_channels: int,
|
||||
condition_dim: int,
|
||||
filter_size: int):
|
||||
"""The residual network in wavenet. It consists of `n_layer` stacks, each of which consists of `n_loop` ResidualBlocks.
|
||||
|
||||
Args:
|
||||
n_stack (int): number of stacks in the `ResidualNet`.
|
||||
n_loop (int): number of ResidualBlocks in a stack.
|
||||
residual_channels (int): channels of each `ResidualBlock`'s input.
|
||||
condition_dim (int): channels of the condition.
|
||||
filter_size (int): filter size of the internal Conv1DCell of each `ResidualBlock`.
|
||||
"""
|
||||
super(ResidualNet, self).__init__()
|
||||
# double the dilation at each layer in a stack
|
||||
dilations = [2**i for i in range(n_loop)] * n_stack
|
||||
|
@ -308,13 +323,21 @@ class ResidualNet(nn.LayerList):
|
|||
self.append(ResidualBlock(residual_channels, condition_dim, filter_size, dilation))
|
||||
|
||||
def forward(self, x, condition=None):
|
||||
"""
|
||||
Args:
|
||||
x (Tensor): shape(B, C_res, T), dtype float32, the input. (B stands for batch_size, C_res stands for residual channels, T stands for time steps.)
|
||||
condition (Tensor, optional): shape(B, C_cond, T), dtype float32, the condition, it has been upsampled in time steps, so it has the same time steps as the input does.(C_cond stands for the condition's channels) Defaults to None.
|
||||
"""Forward pass of ``ResidualNet``.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
x : Tensor [shape=(B, C, T)]
|
||||
The input.
|
||||
|
||||
condition : Tensor, optional [shape=(B, C_cond, T)]
|
||||
The condition, it has been upsampled in time steps, so it has the
|
||||
same time steps as the input does. Defaults to None.
|
||||
|
||||
Returns:
|
||||
skip_connection (Tensor): shape(B, C_res, T), dtype float32, the output.
|
||||
Returns
|
||||
--------
|
||||
Tensor [shape=(B, C, T)]
|
||||
The output.
|
||||
"""
|
||||
for i, func in enumerate(self):
|
||||
x, skip = func(x, condition)
|
||||
|
@ -326,22 +349,32 @@ class ResidualNet(nn.LayerList):
|
|||
return skip_connections
|
||||
|
||||
def start_sequence(self):
|
||||
"""Prepare the ResidualNet to generate a new sequence. This method should be called before starting calling `add_input` multiple times.
|
||||
"""Prepare the ResidualNet to generate a new sequence. This method
|
||||
should be called before starting calling `add_input` multiple times.
|
||||
"""
|
||||
for block in self:
|
||||
block.start_sequence()
|
||||
|
||||
def add_input(self, x, condition=None):
|
||||
"""Add a step input. This method works similarily with `forward` but in a `step-in-step-out` fashion.
|
||||
"""Take a step input and return a step output.
|
||||
|
||||
This method works similarily with ``forward`` but in a
|
||||
``step-in-step-out`` fashion.
|
||||
|
||||
Args:
|
||||
x (Tensor): shape(B, C_res), dtype float32, input for a step.
|
||||
condition (Tensor, optional): shape(B, C_cond), dtype float32, condition for a step. Defaults to None.
|
||||
Parameters
|
||||
----------
|
||||
x : Tensor [shape=(B, C)]
|
||||
Input for a step.
|
||||
|
||||
condition : Tensor, optional [shape=(B, C_cond)]
|
||||
Condition for a step. Defaults to None.
|
||||
|
||||
Returns:
|
||||
skip_connection (Tensor): shape(B, C_res), dtype float32, the output for a step.
|
||||
Returns
|
||||
----------
|
||||
Tensor [shape=(B, C)]
|
||||
T he skip connection for a step. This output is accumulated with
|
||||
that of other ResidualBlocks.
|
||||
"""
|
||||
|
||||
for i, func in enumerate(self):
|
||||
x, skip = func.add_input(x, condition)
|
||||
if i == 0:
|
||||
|
@ -353,20 +386,49 @@ class ResidualNet(nn.LayerList):
|
|||
|
||||
|
||||
class WaveNet(nn.Layer):
|
||||
"""Wavenet that transform upsampled mel spectrogram into waveform.
|
||||
|
||||
Parameters
|
||||
-----------
|
||||
n_stack : int
|
||||
``n_stack`` for the internal ``ResidualNet``.
|
||||
|
||||
n_loop : int
|
||||
``n_loop`` for the internal ``ResidualNet``.
|
||||
|
||||
residual_channels : int
|
||||
Feature size of the input.
|
||||
|
||||
output_dim : int
|
||||
Feature size of the input.
|
||||
|
||||
condition_dim : int
|
||||
Feature size of the condition (mel spectrogram bands).
|
||||
|
||||
filter_size : int
|
||||
Kernel size of the internal ``ResidualNet``.
|
||||
|
||||
loss_type : str, optional ["mog" or "softmax"]
|
||||
The output type and loss type of the model, by default "mog".
|
||||
|
||||
If "softmax", the model input is first quantized audio and the model
|
||||
outputs a discret categorical distribution.
|
||||
|
||||
If "mog", the model input is audio in floating point format, and the
|
||||
model outputs parameters for a mixture of gaussian distributions.
|
||||
Namely, the weight, mean and log scale of each gaussian distribution.
|
||||
Thus, the ``output_size`` should be a multiple of 3.
|
||||
|
||||
log_scale_min : float, optional
|
||||
Minimum value of the log scale of gaussian distributions, by default
|
||||
-9.0.
|
||||
|
||||
This is only used for computing loss when ``loss_type`` is "mog", If
|
||||
the predicted log scale is less than -9.0, it is clipped at -9.0.
|
||||
"""
|
||||
def __init__(self, n_stack, n_loop, residual_channels, output_dim,
|
||||
condition_dim, filter_size, loss_type, log_scale_min):
|
||||
"""Wavenet that transform upsampled mel spectrogram into waveform.
|
||||
|
||||
Args:
|
||||
n_stack (int): n_stack for the internal ResidualNet.
|
||||
n_loop (int): n_loop for the internal ResidualNet.
|
||||
residual_channels (int): the channel of the input.
|
||||
output_dim (int): the channel of the output distribution.
|
||||
condition_dim (int): the channel of the condition.
|
||||
filter_size (int): the filter size of the internal ResidualNet.
|
||||
loss_type (str): loss type of the wavenet. Possible values are 'softmax' and 'mog'. If `loss_type` is 'softmax', the output is the logits of the catrgotical(multinomial) distribution, `output_dim` means the number of classes of the categorical distribution. If `loss_type` is mog(mixture of gaussians), the output is the parameters of a mixture of gaussians, which consists of weight(in the form of logit) of each gaussian distribution and its mean and log standard deviaton. So when `loss_type` is 'mog', `output_dim` should be perfectly divided by 3.
|
||||
log_scale_min (int): the minimum value of log standard deviation of the output gaussian distributions. Note that this value is only used for computing loss if `loss_type` is 'mog', values less than `log_scale_min` is clipped when computing loss.
|
||||
"""
|
||||
super(WaveNet, self).__init__()
|
||||
if loss_type not in ["softmax", "mog"]:
|
||||
raise ValueError("loss_type {} is not supported".format(loss_type))
|
||||
|
@ -396,14 +458,19 @@ class WaveNet(nn.Layer):
|
|||
self.log_scale_min = log_scale_min
|
||||
|
||||
def forward(self, x, condition=None):
|
||||
"""compute the output distribution (represented by its parameters).
|
||||
"""Forward pass of ``WaveNet``.
|
||||
|
||||
Args:
|
||||
x (Tensor): shape(B, T), dtype float32, the input waveform.
|
||||
condition (Tensor, optional): shape(B, C_cond, T), dtype float32, the upsampled condition. Defaults to None.
|
||||
Parameters
|
||||
-----------
|
||||
x : Tensor [shape=(B, T)]
|
||||
The input waveform.
|
||||
condition : Tensor, optional [shape=(B, C_cond, T)]
|
||||
the upsampled condition. Defaults to None.
|
||||
|
||||
Returns:
|
||||
Tensor: shape(B, T, C_output), dtype float32, the parameter of the output distributions.
|
||||
Returns
|
||||
-------
|
||||
Tensor: [shape=(B, T, C_output)]
|
||||
The parameters of the output distributions.
|
||||
"""
|
||||
|
||||
# Causal Conv
|
||||
|
@ -426,19 +493,28 @@ class WaveNet(nn.Layer):
|
|||
return y
|
||||
|
||||
def start_sequence(self):
|
||||
"""Prepare the WaveNet to generate a new sequence. This method should be called before starting calling `add_input` multiple times.
|
||||
"""Prepare the WaveNet to generate a new sequence. This method should
|
||||
be called before starting calling ``add_input`` multiple times.
|
||||
"""
|
||||
self.resnet.start_sequence()
|
||||
|
||||
def add_input(self, x, condition=None):
|
||||
"""compute the output distribution (represented by its parameters) for a step. It works similarily with the `forward` method but in a `step-in-step-out` fashion.
|
||||
"""Compute the output distribution (represented by its parameters) for
|
||||
a step. It works similarily with the ``forward`` method but in a
|
||||
``step-in-step-out`` fashion.
|
||||
|
||||
Args:
|
||||
x (Tensor): shape(B,), dtype float32, a step of the input waveform.
|
||||
condition (Tensor, optional): shape(B, C_cond, ), dtype float32, a step of the upsampled condition. Defaults to None.
|
||||
Parameters
|
||||
-----------
|
||||
x : Tensor [shape=(B,)]
|
||||
A step of the input waveform.
|
||||
|
||||
condition : Tensor, optional [shape=(B, C_cond)]
|
||||
A step of the upsampled condition. Defaults to None.
|
||||
|
||||
Returns:
|
||||
Tensor: shape(B, C_output), dtype float32, the parameter of the output distributions.
|
||||
Returns
|
||||
--------
|
||||
Tensor: [shape=(B, C_output)]
|
||||
A steo of the parameters of the output distributions.
|
||||
"""
|
||||
# Causal Conv
|
||||
if self.loss_type == "softmax":
|
||||
|
@ -458,14 +534,28 @@ class WaveNet(nn.Layer):
|
|||
return y
|
||||
|
||||
def compute_softmax_loss(self, y, t):
|
||||
"""compute the loss where output distribution is a categorial distribution.
|
||||
"""Compute the loss when output distributions are categorial
|
||||
distributions.
|
||||
|
||||
Args:
|
||||
y (Tensor): shape(B, T, C_output), dtype float32, the logits of the output distribution.
|
||||
t (Tensor): shape(B, T), dtype float32, the target audio. Note that the target's corresponding time index is one step ahead of the output distribution. And output distribution whose input contains padding is neglected in loss computation.
|
||||
Parameters
|
||||
----------
|
||||
y : Tensor [shape=(B, T, C_output)]
|
||||
The logits of the output distributions.
|
||||
|
||||
t : Tensor [shape=(B, T)]
|
||||
The target audio. The audio is first quantized then used as the
|
||||
target.
|
||||
|
||||
Notes
|
||||
-------
|
||||
Output distributions whose input contains padding is neglected in
|
||||
loss computation. So the first ``context_size`` steps does not
|
||||
contribute to the loss.
|
||||
|
||||
Returns:
|
||||
Tensor: shape(1, ), dtype float32, the loss.
|
||||
Returns
|
||||
--------
|
||||
Tensor: [shape=(1,)]
|
||||
The loss.
|
||||
"""
|
||||
# context size is not taken into account
|
||||
y = y[:, self.context_size:, :]
|
||||
|
@ -479,13 +569,18 @@ class WaveNet(nn.Layer):
|
|||
return reduced_loss
|
||||
|
||||
def sample_from_softmax(self, y):
|
||||
"""Sample from the output distribution where the output distribution is a categorical distriobution.
|
||||
"""Sample from the output distribution when the output distributions
|
||||
are categorical distriobutions.
|
||||
|
||||
Args:
|
||||
y (Tensor): shape(B, T, C_output), the logits of the output distribution
|
||||
Parameters
|
||||
----------
|
||||
y : Tensor [shape=(B, T, C_output)]
|
||||
The logits of the output distributions.
|
||||
|
||||
Returns:
|
||||
Tensor: shape(B, T), waveform sampled from the output distribution.
|
||||
Returns
|
||||
--------
|
||||
Tensor [shape=(B, T)]
|
||||
Waveform sampled from the output distribution.
|
||||
"""
|
||||
# dequantize
|
||||
batch_size, time_steps, output_dim, = y.shape
|
||||
|
@ -497,14 +592,32 @@ class WaveNet(nn.Layer):
|
|||
return samples
|
||||
|
||||
def compute_mog_loss(self, y, t):
|
||||
"""compute the loss where output distribution is a mixture of Gaussians.
|
||||
"""Compute the loss where output distributions is a mixture of
|
||||
Gaussians distributions.
|
||||
|
||||
Args:
|
||||
y (Tensor): shape(B, T, C_output), dtype float32, the parameterd of the output distribution. It is the concatenation of 3 parts, the logits of every distribution, the mean of each distribution and the log standard deviation of each distribution. Each part's shape is (B, T, n_mixture), where `n_mixture` means the number of Gaussians in the mixture.
|
||||
t (Tensor): shape(B, T), dtype float32, the target audio. Note that the target's corresponding time index is one step ahead of the output distribution. And output distribution whose input contains padding is neglected in loss computation.
|
||||
Parameters
|
||||
-----------
|
||||
y : Tensor [shape=(B, T, C_output)]
|
||||
The parameterd of the output distribution. It is the concatenation
|
||||
of 3 parts, the logits of every distribution, the mean of each
|
||||
distribution and the log standard deviation of each distribution.
|
||||
|
||||
Each part's shape is (B, T, n_mixture), where ``n_mixture`` means
|
||||
the number of Gaussians in the mixture.
|
||||
|
||||
t : Tensor [shape=(B, T)]
|
||||
The target audio.
|
||||
|
||||
Notes
|
||||
-------
|
||||
Output distributions whose input contains padding is neglected in
|
||||
loss computation. So the first ``context_size`` steps does not
|
||||
contribute to the loss.
|
||||
|
||||
Returns:
|
||||
Tensor: shape(1, ), dtype float32, the loss.
|
||||
Returns
|
||||
--------
|
||||
Tensor: [shape=(1,)]
|
||||
The loss.
|
||||
"""
|
||||
n_mixture = self.output_dim // 3
|
||||
|
||||
|
@ -536,12 +649,23 @@ class WaveNet(nn.Layer):
|
|||
return loss
|
||||
|
||||
def sample_from_mog(self, y):
|
||||
"""Sample from the output distribution where the output distribution is a mixture of Gaussians.
|
||||
Args:
|
||||
y (Tensor): shape(B, T, C_output), dtype float32, the parameterd of the output distribution. It is the concatenation of 3 parts, the logits of every distribution, the mean of each distribution and the log standard deviation of each distribution. Each part's shape is (B, T, n_mixture), where `n_mixture` means the number of Gaussians in the mixture.
|
||||
"""Sample from the output distribution when the output distribution
|
||||
is a mixture of Gaussian distributions.
|
||||
|
||||
Parameters
|
||||
------------
|
||||
y : Tensor [shape=(B, T, C_output)]
|
||||
The parameterd of the output distribution. It is the concatenation
|
||||
of 3 parts, the logits of every distribution, the mean of each
|
||||
distribution and the log standard deviation of each distribution.
|
||||
|
||||
Each part's shape is (B, T, n_mixture), where ``n_mixture`` means
|
||||
the number of Gaussians in the mixture.
|
||||
|
||||
Returns:
|
||||
Tensor: shape(B, T), waveform sampled from the output distribution.
|
||||
Returns
|
||||
--------
|
||||
Tensor: [shape=(B, T)]
|
||||
Waveform sampled from the output distribution.
|
||||
"""
|
||||
batch_size, time_steps, output_dim = y.shape
|
||||
n_mixture = output_dim // 3
|
||||
|
@ -568,11 +692,16 @@ class WaveNet(nn.Layer):
|
|||
|
||||
def sample(self, y):
|
||||
"""Sample from the output distribution.
|
||||
Args:
|
||||
y (Tensor): shape(B, T, C_output), dtype float32, the parameterd of the output distribution.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
y : Tensor [shape=(B, T, C_output)]
|
||||
The parameterd of the output distribution.
|
||||
|
||||
Returns:
|
||||
Tensor: shape(B, T), waveform sampled from the output distribution.
|
||||
Returns
|
||||
--------
|
||||
Tensor [shape=(B, T)]
|
||||
Waveform sampled from the output distribution.
|
||||
"""
|
||||
if self.loss_type == "softmax":
|
||||
return self.sample_from_softmax(y)
|
||||
|
@ -580,14 +709,20 @@ class WaveNet(nn.Layer):
|
|||
return self.sample_from_mog(y)
|
||||
|
||||
def loss(self, y, t):
|
||||
"""compute the loss where output distribution is a mixture of Gaussians.
|
||||
"""Compute the loss given the output distribution and the target.
|
||||
|
||||
Args:
|
||||
y (Tensor): shape(B, T, C_output), dtype float32, the parameterd of the output distribution.
|
||||
t (Tensor): shape(B, T), dtype float32, the target audio. Note that the target's corresponding time index is one step ahead of the output distribution. And output distribution whose input contains padding is neglected in loss computation.
|
||||
Parameters
|
||||
----------
|
||||
y : Tensor [shape=(B, T, C_output)]
|
||||
The parameterd of the output distribution.
|
||||
|
||||
t : Tensor [shape=(B, T)]
|
||||
The target audio.
|
||||
|
||||
Returns:
|
||||
Tensor: shape(1, ), dtype float32, the loss.
|
||||
Returns
|
||||
---------
|
||||
Tensor: [shape=(1,)]
|
||||
The loss.
|
||||
"""
|
||||
if self.loss_type == "softmax":
|
||||
return self.compute_softmax_loss(y, t)
|
||||
|
@ -640,9 +775,11 @@ class ConditionalWaveNet(nn.Layer):
|
|||
Thus, the ``output_size`` should be a multiple of 3.
|
||||
|
||||
log_scale_min : float, optional
|
||||
Minimum value of the log probability density, by default -9.0.
|
||||
Minimum value of the log scale of gaussian distributions, by default
|
||||
-9.0.
|
||||
|
||||
This is only used for computing loss when ``loss_type`` is "mog", If the
|
||||
This is only used for computing loss when ``loss_type`` is "mog", If
|
||||
the predicted log scale is less than -9.0, it is clipped at -9.0.
|
||||
"""
|
||||
def __init__(self,
|
||||
upsample_factors: List[int],
|
||||
|
|
Loading…
Reference in New Issue