ParakeetEricRoss/parakeet/models/waveflow.py

845 lines
28 KiB
Python
Raw Normal View History

2020-10-10 15:51:54 +08:00
import math
import numpy as np
2020-12-19 18:33:07 +08:00
from typing import List, Union, Tuple
2020-10-10 15:51:54 +08:00
import paddle
from paddle import nn
from paddle.nn import functional as F
from paddle.nn import initializer as I
from parakeet.utils import checkpoint
2020-10-10 15:51:54 +08:00
from parakeet.modules import geometry as geo
2020-12-19 18:33:07 +08:00
__all__ = ["WaveFlow", "ConditionalWaveFlow", "WaveFlowLoss"]
2020-10-10 15:51:54 +08:00
def fold(x, n_group):
2020-12-19 18:33:07 +08:00
r"""Fold audio or spectrogram's temporal dimension in to groups.
2020-10-10 15:51:54 +08:00
2020-12-19 18:33:07 +08:00
Parameters
----------
x : Tensor [shape=(\*, time_steps)
The input tensor.
n_group : int
The size of a group.
2020-10-10 15:51:54 +08:00
2020-12-19 18:33:07 +08:00
Returns
---------
Tensor : [shape=(`*, time_steps // n_group, group)]
Folded tensor.
2020-10-10 15:51:54 +08:00
"""
*spatial_shape, time_steps = x.shape
new_shape = spatial_shape + [time_steps // n_group, n_group]
return paddle.reshape(x, new_shape)
class UpsampleNet(nn.LayerList):
2020-12-19 18:33:07 +08:00
"""Layer to upsample mel spectrogram to the same temporal resolution with
the corresponding waveform.
It consists of several conv2dtranspose layers which perform deconvolution
on mel and time dimension.
Parameters
----------
upscale_factors : List[int], optional
Time upsampling factors for each Conv2DTranspose Layer.
The ``UpsampleNet`` contains ``len(upscale_factor)`` Conv2DTranspose
Layers. Each upscale_factor is used as the ``stride`` for the
corresponding Conv2DTranspose. Defaults to [16, 16], this the default
upsampling factor is 256.
Notes
------
``np.prod(upscale_factors)`` should equals the ``hop_length`` of the stft
transformation used to extract spectrogram features from audio.
For example, ``16 * 16 = 256``, then the spectrogram extracted with a stft
transformation whose ``hop_length`` equals 256 is suitable.
See Also
---------
``librosa.core.stft``
2020-11-04 01:37:49 +08:00
"""
def __init__(self, upsample_factors):
2020-10-10 15:51:54 +08:00
super(UpsampleNet, self).__init__()
for factor in upsample_factors:
std = math.sqrt(1 / (3 * 2 * factor))
init = I.Uniform(-std, std)
self.append(
nn.utils.weight_norm(
2020-11-04 01:37:49 +08:00
nn.Conv2DTranspose(1, 1, (3, 2 * factor),
2020-10-10 15:51:54 +08:00
padding=(1, factor // 2),
stride=(1, factor),
weight_attr=init,
bias_attr=init)))
# upsample factors
self.upsample_factor = np.prod(upsample_factors)
self.upsample_factors = upsample_factors
def forward(self, x, trim_conv_artifact=False):
2020-12-19 18:33:07 +08:00
r"""Forward pass of the ``UpsampleNet``.
Parameters
-----------
x : Tensor [shape=(batch_size, input_channels, time_steps)]
The input spectrogram.
trim_conv_artifact : bool, optional
Trim deconvolution artifact at each layer. Defaults to False.
Returns
--------
Tensor: [shape=(batch_size, input_channels, time_steps \* upsample_factor)]
The upsampled spectrogram.
Notes
--------
If trim_conv_artifact is ``True``, the output time steps is less
than ``time_steps \* upsample_factors``.
2020-10-10 15:51:54 +08:00
"""
x = paddle.unsqueeze(x, 1) #(B, C, T) -> (B, 1, C, T)
2020-10-10 15:51:54 +08:00
for layer in self:
x = layer(x)
if trim_conv_artifact:
time_cutoff = layer._kernel_size[1] - layer._stride[1]
x = x[:, :, :, :-time_cutoff]
2020-10-10 15:51:54 +08:00
x = F.leaky_relu(x, 0.4)
x = paddle.squeeze(x, 1) # back to (B, C, T)
2020-10-10 15:51:54 +08:00
return x
2020-12-19 18:33:07 +08:00
#TODO write doc
2020-10-10 15:51:54 +08:00
class ResidualBlock(nn.Layer):
2020-12-19 18:33:07 +08:00
"""ResidualBlock, the basic unit of ResidualNet used in WaveFlow.
It has a conv2d layer, which has causal padding in height dimension and
same paddign in width dimension. It also has projection for the condition
and output.
Parameters
----------
channels : int
Feature size of the input.
cond_channels : int
Featuer size of the condition.
kernel_size : Tuple[int]
Kernel size of the Convolution2d applied to the input.
dilations : int
Dilations of the Convolution2d applied to the input.
2020-11-04 01:37:49 +08:00
"""
2020-10-10 15:51:54 +08:00
def __init__(self, channels, cond_channels, kernel_size, dilations):
super(ResidualBlock, self).__init__()
# input conv
std = math.sqrt(1 / channels * np.prod(kernel_size))
init = I.Uniform(-std, std)
2020-11-04 01:37:49 +08:00
receptive_field = [1 + (k - 1) * d for (k, d) in zip(kernel_size, dilations)]
rh, rw = receptive_field
paddings = [rh - 1, 0, rw // 2, (rw - 1) // 2] # causal & same
conv = nn.Conv2D(channels, 2 * channels, kernel_size,
padding=paddings,
dilation=dilations,
weight_attr=init,
bias_attr=init)
2020-10-10 15:51:54 +08:00
self.conv = nn.utils.weight_norm(conv)
self.rh = rh
self.rw = rw
self.dilations = dilations
2020-10-10 15:51:54 +08:00
# condition projection
std = math.sqrt(1 / cond_channels)
init = I.Uniform(-std, std)
2020-11-04 01:37:49 +08:00
condition_proj = nn.Conv2D(cond_channels, 2 * channels, (1, 1),
2020-10-10 15:51:54 +08:00
weight_attr=init, bias_attr=init)
self.condition_proj = nn.utils.weight_norm(condition_proj)
# parametric residual & skip connection
std = math.sqrt(1 / channels)
init = I.Uniform(-std, std)
2020-11-04 01:37:49 +08:00
out_proj = nn.Conv2D(channels, 2 * channels, (1, 1),
weight_attr=init, bias_attr=init)
2020-10-10 15:51:54 +08:00
self.out_proj = nn.utils.weight_norm(out_proj)
def forward(self, x, condition):
"""Compute output for a whole folded sequence.
2020-12-19 18:33:07 +08:00
Parameters
----------
x : Tensor [shape=(batch_size, channel, height, width)]
The input.
condition : Tensor [shape=(batch_size, condition_channel, height, width)]
The local condition.
2020-12-19 18:33:07 +08:00
Returns
-------
res : Tensor [shape=(batch_size, channel, height, width)]
The residual output.
skip : Tensor [shape=(batch_size, channel, height, width)]
The skip output.
"""
x_in = x
2020-11-04 01:37:49 +08:00
x = self.conv(x)
2020-10-10 15:51:54 +08:00
x += self.condition_proj(condition)
content, gate = paddle.chunk(x, 2, axis=1)
x = paddle.tanh(content) * F.sigmoid(gate)
x = self.out_proj(x)
res, skip = paddle.chunk(x, 2, axis=1)
2020-12-19 18:33:07 +08:00
res = x_in + res
return res, skip
2020-11-04 01:37:49 +08:00
def start_sequence(self):
2020-12-19 18:33:07 +08:00
"""Prepare the layer for incremental computation of causal
convolution. Reset the buffer for causal convolution.
Raises:
ValueError: If not in evaluation mode.
"""
if self.training:
raise ValueError("Only use start sequence at evaluation mode.")
self._conv_buffer = None
# NOTE: call self.conv's weight norm hook expliccitly since
# its weight will be visited directly in `add_input` without
# calling its `__call__` method. If we do not trigger the weight
# norm hook, the weight may be outdated. e.g. after loading from
# a saved checkpoint
# see also: https://github.com/pytorch/pytorch/issues/47588
for hook in self.conv._forward_pre_hooks.values():
hook(self.conv, None)
def add_input(self, x_row, condition_row):
"""Compute the output for a row and update the buffer.
2020-12-19 18:33:07 +08:00
Parameters
----------
x_row : Tensor [shape=(batch_size, channel, 1, width)]
A row of the input.
condition_row : Tensor [shape=(batch_size, condition_channel, 1, width)]
A row of the condition.
2020-12-19 18:33:07 +08:00
Returns
-------
res : Tensor [shape=(batch_size, channel, 1, width)]
A row of the the residual output.
res : Tensor [shape=(batch_size, channel, 1, width)]
A row of the skip output.
"""
x_row_in = x_row
if self._conv_buffer is None:
self._init_buffer(x_row)
self._update_buffer(x_row)
rw = self.rw
x_row = F.conv2d(
self._conv_buffer,
self.conv.weight,
self.conv.bias,
padding=[0, 0, rw // 2, (rw - 1) // 2],
dilation=self.dilations)
x_row += self.condition_proj(condition_row)
content, gate = paddle.chunk(x_row, 2, axis=1)
x_row = paddle.tanh(content) * F.sigmoid(gate)
x_row = self.out_proj(x_row)
res, skip = paddle.chunk(x_row, 2, axis=1)
2020-12-19 18:33:07 +08:00
res = x_row_in + res
return res, skip
def _init_buffer(self, input):
batch_size, channels, _, width = input.shape
self._conv_buffer = paddle.zeros(
[batch_size, channels, self.rh, width], dtype=input.dtype)
def _update_buffer(self, input):
self._conv_buffer = paddle.concat(
[self._conv_buffer[:, :, 1:, :], input], axis=2)
2020-11-04 01:37:49 +08:00
2020-10-10 15:51:54 +08:00
class ResidualNet(nn.LayerList):
2020-12-19 18:33:07 +08:00
"""A stack of several ResidualBlocks. It merges condition at each layer.
Parameters
----------
n_layer : int
Number of ResidualBlocks in the ResidualNet.
residual_channels : int
Feature size of each ResidualBlocks.
condition_channels : int
Feature size of the condition.
kernel_size : Tuple[int]
Kernel size of each ResidualBlock.
dilations_h : List[int]
Dilation in height dimension of every ResidualBlock.
Raises
------
ValueError
If the length of dilations_h does not equals n_layers.
2020-11-04 01:37:49 +08:00
"""
2020-12-19 18:33:07 +08:00
def __init__(self,
n_layer: int,
residual_channels: int,
condition_channels: int,
kernel_size: Tuple[int],
dilations_h: List[int]):
2020-10-10 15:51:54 +08:00
if len(dilations_h) != n_layer:
raise ValueError("number of dilations_h should equals num of layers")
super(ResidualNet, self).__init__()
for i in range(n_layer):
dilation = (dilations_h[i], 2 ** i)
layer = ResidualBlock(residual_channels, condition_channels, kernel_size, dilation)
self.append(layer)
def forward(self, x, condition):
"""Comput the output of given the input and the condition.
2020-12-19 18:33:07 +08:00
Parameters
-----------
x : Tensor [shape=(batch_size, channel, height, width)]
The input.
condition : Tensor [shape=(batch_size, condition_channel, height, width)]
The local condition.
2020-12-19 18:33:07 +08:00
Returns
--------
Tensor : [shape=(batch_size, channel, height, width)]
The output, which is an aggregation of all the skip outputs.
"""
2020-10-10 15:51:54 +08:00
skip_connections = []
for layer in self:
x, skip = layer(x, condition)
skip_connections.append(skip)
out = paddle.sum(paddle.stack(skip_connections, 0), 0)
return out
2020-11-04 01:37:49 +08:00
def start_sequence(self):
2020-12-19 18:33:07 +08:00
"""Prepare the layer for incremental computation.
"""
for layer in self:
layer.start_sequence()
def add_input(self, x_row, condition_row):
2020-12-19 18:33:07 +08:00
"""Compute the output for a row and update the buffers.
2020-12-19 18:33:07 +08:00
Parameters
----------
x_row : Tensor [shape=(batch_size, channel, 1, width)]
A row of the input.
condition_row : Tensor [shape=(batch_size, condition_channel, 1, width)]
A row of the condition.
2020-12-19 18:33:07 +08:00
Returns
-------
res : Tensor [shape=(batch_size, channel, 1, width)]
A row of the the residual output.
res : Tensor [shape=(batch_size, channel, 1, width)]
A row of the skip output.
"""
skip_connections = []
for layer in self:
x_row, skip = layer.add_input(x_row, condition_row)
skip_connections.append(skip)
out = paddle.sum(paddle.stack(skip_connections, 0), 0)
return out
2020-10-10 15:51:54 +08:00
class Flow(nn.Layer):
2020-12-19 18:33:07 +08:00
"""A bijection (Reversable layer) that transform a density of latent
variables p(Z) into a complex data distribution p(X).
2020-12-19 18:33:07 +08:00
It's an auto regressive flow. The `forward` method implements the
probability density estimation. The `inverse` method implements the
sampling.
Parameters
----------
n_layers : int
Number of ResidualBlocks in the Flow.
channels : int
Feature size of the ResidualBlocks.
mel_bands : int
Feature size of the mel spectrogram (mel bands).
kernel_size : Tuple[int]
Kernel size of each ResisualBlocks in the Flow.
n_group : int
Number of timesteps to the folded into a group.
2020-11-04 01:37:49 +08:00
"""
2020-10-10 15:51:54 +08:00
dilations_dict = {
8: [1, 1, 1, 1, 1, 1, 1, 1],
16: [1, 1, 1, 1, 1, 1, 1, 1],
32: [1, 2, 4, 1, 2, 4, 1, 2],
64: [1, 2, 4, 8, 16, 1, 2, 4],
128: [1, 2, 4, 8, 16, 32, 64, 1]
}
def __init__(self, n_layers, channels, mel_bands, kernel_size, n_group):
super(Flow, self).__init__()
# input projection
self.input_proj = nn.utils.weight_norm(
2020-11-04 01:37:49 +08:00
nn.Conv2D(1, channels, (1, 1),
2020-10-10 15:51:54 +08:00
weight_attr=I.Uniform(-1., 1.),
bias_attr=I.Uniform(-1., 1.)))
# residual net
self.resnet = ResidualNet(n_layers, channels, mel_bands, kernel_size,
self.dilations_dict[n_group])
# output projection
self.output_proj = nn.Conv2D(channels, 2, (1, 1),
weight_attr=I.Constant(0.),
bias_attr=I.Constant(0.))
# specs
self.n_group = n_group
def _predict_parameters(self, x, condition):
x = self.input_proj(x)
x = self.resnet(x, condition)
bijection_params = self.output_proj(x)
logs, b = paddle.chunk(bijection_params, 2, axis=1)
return logs, b
def _transform(self, x, logs, b):
z_0 = x[:, :, :1, :] # the first row, just copy it
z_out = x[:, :, 1:, :] * paddle.exp(logs) + b
z_out = paddle.concat([z_0, z_out], axis=2)
return z_out
2020-10-10 15:51:54 +08:00
def forward(self, x, condition):
2020-12-19 18:33:07 +08:00
"""Probability density estimation. It is done by inversely transform
a sample from p(X) into a sample from p(Z).
Parameters
-----------
x : Tensor [shape=(batch, 1, height, width)]
A input sample of the distribution p(X).
condition : Tensor [shape=(batch, condition_channel, height, width)]
The local condition.
Returns
--------
z (Tensor): shape(batch, 1, height, width), the transformed sample.
Tuple[Tensor, Tensor]
The parameter of the transformation.
logs (Tensor): shape(batch, 1, height - 1, width), the log scale
of the transformation from x to z.
b (Tensor): shape(batch, 1, height - 1, width), the shift of the
transformation from x to z.
"""
# (B, C, H-1, W)
logs, b = self._predict_parameters(
x[:, :, :-1, :], condition[:, :, 1:, :])
z = self._transform(x, logs, b)
return z, (logs, b)
2020-10-10 15:51:54 +08:00
def _predict_row_parameters(self, x_row, condition_row):
x_row = self.input_proj(x_row)
x_row = self.resnet.add_input(x_row, condition_row)
bijection_params = self.output_proj(x_row)
logs, b = paddle.chunk(bijection_params, 2, axis=1)
return logs, b
2020-10-10 15:51:54 +08:00
def _inverse_transform_row(self, z_row, logs, b):
x_row = (z_row - b) * paddle.exp(-logs)
return x_row
def _inverse_row(self, z_row, x_row, condition_row):
logs, b = self._predict_row_parameters(x_row, condition_row)
x_next_row = self._inverse_transform_row(z_row, logs, b)
return x_next_row, (logs, b)
def _start_sequence(self):
self.resnet.start_sequence()
def inverse(self, z, condition):
2020-12-19 18:33:07 +08:00
"""Sampling from the the distrition p(X). It is done by sample form
p(Z) and transform the sample. It is a auto regressive transformation.
Parameters
-----------
z : Tensor [shape=(batch, 1, height, width)]
A sample of the distribution p(Z).
condition : Tensor [shape=(batch, condition_channel, height, width)]
The local condition.
Returns
---------
x : Tensor [shape=(batch, 1, height, width)]
The transformed sample.
Tuple[Tensor, Tensor]
The parameter of the transformation.
logs (Tensor): shape(batch, 1, height - 1, width), the log scale
of the transformation from x to z.
b (Tensor): shape(batch, 1, height - 1, width), the shift of the
transformation from x to z.
"""
z_0 = z[:, :, :1, :]
x = []
logs_list = []
b_list = []
x.append(z_0)
self._start_sequence()
for i in range(1, self.n_group):
x_row = x[-1] # actuallt i-1:i
z_row = z[:, :, i:i+1, :]
condition_row = condition[:, :, i:i+1, :]
x_next_row, (logs, b) = self._inverse_row(z_row, x_row, condition_row)
x.append(x_next_row)
logs_list.append(logs)
b_list.append(b)
x = paddle.concat(x, 2)
logs = paddle.concat(logs_list, 2)
b = paddle.concat(b_list, 2)
return x, (logs, b)
2020-10-10 15:51:54 +08:00
class WaveFlow(nn.LayerList):
2020-12-19 18:33:07 +08:00
"""An Deep Reversible layer that is composed of severel auto regressive
flows.
Parameters
-----------
n_flows : int
Number of flows in the WaveFlow model.
n_layers : int
Number of ResidualBlocks in each Flow.
n_group : int
Number of timesteps to fold as a group.
channels : int
Feature size of each ResidualBlock.
mel_bands : int
Feature size of mel spectrogram (mel bands).
kernel_size : Union[int, List[int]]
Kernel size of the convolution layer in each ResidualBlock.
"""
2020-10-10 15:51:54 +08:00
def __init__(self, n_flows, n_layers, n_group, channels, mel_bands, kernel_size):
if n_group % 2 or n_flows % 2:
raise ValueError("number of flows and number of group must be even "
"since a permutation along group among flows is used.")
super(WaveFlow, self).__init__()
2020-11-04 01:37:49 +08:00
for _ in range(n_flows):
2020-10-10 15:51:54 +08:00
self.append(Flow(n_layers, channels, mel_bands, kernel_size, n_group))
# permutations in h
2020-11-04 01:37:49 +08:00
self.perms = self._create_perm(n_group, n_flows)
# specs
self.n_group = n_group
self.n_flows = n_flows
def _create_perm(self, n_group, n_flows):
2020-10-10 15:51:54 +08:00
indices = list(range(n_group))
half = n_group // 2
2020-11-04 01:37:49 +08:00
perms = []
2020-10-10 15:51:54 +08:00
for i in range(n_flows):
if i < n_flows // 2:
2020-11-04 01:37:49 +08:00
perms.append(indices[::-1])
2020-10-10 15:51:54 +08:00
else:
perm = list(reversed(indices[:half])) + list(reversed(indices[half:]))
2020-11-04 01:37:49 +08:00
perms.append(perm)
return perms
2020-10-10 15:51:54 +08:00
2020-11-04 01:37:49 +08:00
def _trim(self, x, condition):
2020-10-10 15:51:54 +08:00
assert condition.shape[-1] >= x.shape[-1]
pruned_len = int(x.shape[-1] // self.n_group * self.n_group)
if x.shape[-1] > pruned_len:
x = x[:, :pruned_len]
if condition.shape[-1] > pruned_len:
condition = condition[:, :, :pruned_len]
return x, condition
def forward(self, x, condition):
2020-12-19 18:33:07 +08:00
"""Probability density estimation of random variable x given the
condition.
2020-12-19 18:33:07 +08:00
Parameters
-----------
x : Tensor [shape=(batch_size, time_steps)]
The audio.
condition : Tensor [shape=(batch_size, condition channel, time_steps)]
The local condition (mel spectrogram here).
2020-12-19 18:33:07 +08:00
Returns
--------
z : Tensor [shape=(batch_size, time_steps)]
The transformed random variable.
log_det_jacobian: Tensor [shape=(1,)]
The log determinant of the jacobian of the transformation from x
to z.
"""
2020-10-10 15:51:54 +08:00
# x: (B, T)
# condition: (B, C, T) upsampled condition
2020-11-04 01:37:49 +08:00
x, condition = self._trim(x, condition)
2020-10-10 15:51:54 +08:00
# to (B, C, h, T//h) layout
2020-10-10 15:51:54 +08:00
x = paddle.unsqueeze(paddle.transpose(fold(x, self.n_group), [0, 2, 1]), 1)
condition = paddle.transpose(fold(condition, self.n_group), [0, 1, 3, 2])
# flows
logs_list = []
for i, layer in enumerate(self):
x, (logs, b) = layer(x, condition)
2020-10-10 15:51:54 +08:00
logs_list.append(logs)
# permute paddle has no shuffle dim
x = geo.shuffle_dim(x, 2, perm=self.perms[i])
condition = geo.shuffle_dim(condition, 2, perm=self.perms[i])
2020-11-04 01:37:49 +08:00
z = paddle.squeeze(x, 1) # (B, H, W)
batch_size = z.shape[0]
z = paddle.reshape(paddle.transpose(z, [0, 2, 1]), [batch_size, -1])
log_det_jacobian = paddle.sum(paddle.stack(logs_list))
return z, log_det_jacobian
2020-10-10 15:51:54 +08:00
def inverse(self, z, condition):
2020-12-19 18:33:07 +08:00
"""Sampling from the the distrition p(X).
It is done by sample a ``z`` form p(Z) and transform it into ``x``.
Each Flow transform .. math:: `z_{i-1}` to .. math:: `z_{i}` in an
autoregressive manner.
Parameters
----------
z : Tensor [shape=(batch, 1, time_steps]
A sample of the distribution p(Z).
condition : Tensor [shape=(batch, condition_channel, time_steps)]
The local condition.
2020-12-19 18:33:07 +08:00
Returns
--------
x : Tensor [shape=(batch_size, time_steps)]
The transformed sample (audio here).
"""
z, condition = self._trim(z, condition)
# to (B, C, h, T//h) layout
z = paddle.unsqueeze(paddle.transpose(fold(z, self.n_group), [0, 2, 1]), 1)
condition = paddle.transpose(fold(condition, self.n_group), [0, 1, 3, 2])
# reverse it flow by flow
for i in reversed(range(self.n_flows)):
z = geo.shuffle_dim(z, 2, perm=self.perms[i])
condition = geo.shuffle_dim(condition, 2, perm=self.perms[i])
z, (logs, b) = self[i].inverse(z, condition)
x = paddle.squeeze(z, 1) # (B, H, W)
batch_size = x.shape[0]
x = paddle.reshape(paddle.transpose(x, [0, 2, 1]), [batch_size, -1])
return x
2020-10-10 15:51:54 +08:00
class ConditionalWaveFlow(nn.LayerList):
2020-12-19 18:33:07 +08:00
"""ConditionalWaveFlow, a UpsampleNet with a WaveFlow model.
Parameters
----------
upsample_factors : List[int]
Upsample factors for the upsample net.
n_flows : int
Number of flows in the WaveFlow model.
n_layers : int
Number of ResidualBlocks in each Flow.
n_group : int
Number of timesteps to fold as a group.
channels : int
Feature size of each ResidualBlock.
n_mels : int
Feature size of mel spectrogram (mel bands).
kernel_size : Union[int, List[int]]
Kernel size of the convolution layer in each ResidualBlock.
"""
def __init__(self,
upsample_factors: List[int],
n_flows: int,
n_layers: int,
n_group: int,
channels: int,
n_mels: int,
kernel_size: Union[int, List[int]]):
super(ConditionalWaveFlow, self).__init__()
self.encoder = UpsampleNet(upsample_factors)
self.decoder = WaveFlow(
n_flows=n_flows,
n_layers=n_layers,
n_group=n_group,
channels=channels,
mel_bands=n_mels,
kernel_size=kernel_size)
def forward(self, audio, mel):
2020-12-19 18:33:07 +08:00
"""Compute the transformed random variable z (x to z) and the log of
the determinant of the jacobian of the transformation from x to z.
Parameters
----------
audio : Tensor [shape=(B, T)]
The audio.
mel : Tensor [shape=(B, C_mel, T_mel)]
The mel spectrogram.
Returns
-------
z : Tensor [shape=(B, T)]
The inversely transformed random variable z (x to z)
log_det_jacobian: Tensor [shape=(1,)]
the log of the determinant of the jacobian of the transformation
from x to z.
"""
condition = self.encoder(mel)
z, log_det_jacobian = self.decoder(audio, condition)
return z, log_det_jacobian
2020-12-11 19:45:49 +08:00
@paddle.no_grad()
def infer(self, mel):
2020-12-19 18:33:07 +08:00
r"""Generate raw audio given mel spectrogram.
Parameters
----------
mel : Tensor [shape=(B, C_mel, T_mel)]
Mel spectrogram (in log-magnitude).
Returns
-------
Tensor : [shape=(B, T)]
The synthesized audio, where``T <= T_mel \* upsample_factors``.
"""
condition = self.encoder(mel, trim_conv_artifact=True) #(B, C, T)
batch_size, _, time_steps = condition.shape
z = paddle.randn([batch_size, time_steps], dtype=mel.dtype)
x = self.decoder.inverse(z, condition)
return x
@paddle.no_grad()
def predict(self, mel):
2020-12-19 18:33:07 +08:00
"""Generate raw audio given mel spectrogram.
Parameters
----------
mel : np.ndarray [shape=(C_mel, T_mel)]
Mel spectrogram of an utterance(in log-magnitude).
Returns
-------
np.ndarray [shape=(T,)]
The synthesized audio.
"""
mel = paddle.to_tensor(mel)
mel = paddle.unsqueeze(mel, 0)
audio = self.infer(mel)
audio = audio[0].numpy()
return audio
@classmethod
def from_pretrained(cls, config, checkpoint_path):
2020-12-19 18:33:07 +08:00
"""Build a ConditionalWaveFlow model from a pretrained model.
Parameters
----------
config: yacs.config.CfgNode
model configs
checkpoint_path: Path or str
the path of pretrained model checkpoint, without extension name
Returns
-------
ConditionalWaveFlow
The model built from pretrained result.
"""
model = cls(
upsample_factors=config.model.upsample_factors,
n_flows=config.model.n_flows,
n_layers=config.model.n_layers,
n_group=config.model.n_group,
channels=config.model.channels,
n_mels=config.data.n_mels,
kernel_size=config.model.kernel_size)
checkpoint.load_parameters(model, checkpoint_path=checkpoint_path)
return model
class WaveFlowLoss(nn.Layer):
2020-12-19 18:33:07 +08:00
"""Criterion of a WaveFlow model.
Parameters
----------
sigma : float
The standard deviation of the gaussian noise used in WaveFlow, by
default 1.0.
"""
def __init__(self, sigma=1.0):
super(WaveFlowLoss, self).__init__()
self.sigma = sigma
self.const = 0.5 * np.log(2 * np.pi) + np.log(self.sigma)
2020-12-19 18:33:07 +08:00
def forward(self, z, log_det_jacobian):
"""Compute the loss given the transformed random variable z and the
log_det_jacobian of transformation from x to z.
2020-12-19 18:33:07 +08:00
Parameters
----------
z : Tensor [shape=(B, T)]
The transformed random variable (x to z).
log_det_jacobian : Tensor [shape=(1,)]
The log of the determinant of the jacobian matrix of the
transformation from x to z.
Returns
-------
Tensor [shape=(1,)]
The loss.
"""
loss = paddle.sum(z * z) / (2 * self.sigma * self.sigma) - log_det_jacobian
loss = loss / np.prod(z.shape)
return loss + self.const