2020-10-10 15:51:54 +08:00
|
|
|
import math
|
2020-11-09 15:46:27 +08:00
|
|
|
import numpy as np
|
2020-12-19 18:33:07 +08:00
|
|
|
from typing import List, Union, Tuple
|
2020-10-10 15:51:54 +08:00
|
|
|
import paddle
|
|
|
|
from paddle import nn
|
|
|
|
from paddle.nn import functional as F
|
|
|
|
from paddle.nn import initializer as I
|
|
|
|
|
2020-12-12 18:21:20 +08:00
|
|
|
from parakeet.utils import checkpoint
|
2020-10-10 15:51:54 +08:00
|
|
|
from parakeet.modules import geometry as geo
|
|
|
|
|
2020-12-19 18:33:07 +08:00
|
|
|
# Public names exported by ``from <this module> import *``.
__all__ = [
    "WaveFlow",
    "ConditionalWaveFlow",
    "WaveFlowLoss",
]
|
2020-10-10 15:51:54 +08:00
|
|
|
|
|
|
|
def fold(x, n_group):
    r"""Fold the temporal (last) dimension of audio or a spectrogram into
    groups of ``n_group`` consecutive steps.

    Parameters
    ----------
    x : Tensor [shape=(\*, time_steps)]
        The input tensor. ``time_steps`` should be a multiple of
        ``n_group``; otherwise the element count does not match the target
        shape and the reshape fails.

    n_group : int
        The size of a group.

    Returns
    ---------
    Tensor : [shape=(\*, time_steps // n_group, n_group)]
        Folded tensor.
    """
    leading_shape = list(x.shape[:-1])
    n_steps = x.shape[-1]
    folded_shape = leading_shape + [n_steps // n_group, n_group]
    return paddle.reshape(x, folded_shape)
|
|
|
|
|
|
|
|
class UpsampleNet(nn.LayerList):
    """Layer to upsample a mel spectrogram to the same temporal resolution
    as the corresponding waveform.

    It consists of several Conv2DTranspose layers which perform
    deconvolution over the mel and time dimensions.

    Parameters
    ----------
    upsample_factors : List[int]
        Time upsampling factors, one per Conv2DTranspose layer. The
        ``UpsampleNet`` contains ``len(upsample_factors)`` Conv2DTranspose
        layers, and each factor is used as the time ``stride`` of the
        corresponding layer. The total upsampling factor is
        ``np.prod(upsample_factors)``; e.g. ``[16, 16]`` gives 256.

    Notes
    ------
    ``np.prod(upsample_factors)`` should equal the ``hop_length`` of the
    STFT used to extract spectrogram features from audio.

    For example, ``16 * 16 = 256``, so a spectrogram extracted with an STFT
    whose ``hop_length`` equals 256 is suitable.

    See Also
    ---------
    ``librosa.core.stft``
    """

    def __init__(self, upsample_factors):
        super(UpsampleNet, self).__init__()
        for stride in upsample_factors:
            kernel_width = 2 * stride
            # uniform init bounded by 1 / sqrt(fan_in); fan_in = 1 * 3 * kernel_width
            bound = math.sqrt(1 / (3 * kernel_width))
            uniform = I.Uniform(-bound, bound)
            deconv = nn.Conv2DTranspose(1, 1, (3, kernel_width),
                                        padding=(1, stride // 2),
                                        stride=(1, stride),
                                        weight_attr=uniform,
                                        bias_attr=uniform)
            self.append(nn.utils.weight_norm(deconv))

        # total and per-layer upsampling factors
        self.upsample_factor = np.prod(upsample_factors)
        self.upsample_factors = upsample_factors

    def forward(self, x, trim_conv_artifact=False):
        r"""Forward pass of the ``UpsampleNet``.

        Parameters
        -----------
        x : Tensor [shape=(batch_size, input_channels, time_steps)]
            The input spectrogram.

        trim_conv_artifact : bool, optional
            Trim the deconvolution artifact at each layer. Defaults to False.

        Returns
        --------
        Tensor: [shape=(batch_size, input_channels, time_steps \* upsample_factor)]
            The upsampled spectrogram.

        Notes
        --------
        If ``trim_conv_artifact`` is ``True``, each layer drops its last
        ``kernel_width - stride`` (i.e. ``factor``) frames, so the output has
        fewer than ``time_steps \* upsample_factor`` time steps.
        """
        # (B, C, T) -> (B, 1, C, T): treat the spectrogram as a 1-channel image
        out = paddle.unsqueeze(x, 1)
        for deconv in self:
            out = deconv(out)
            if trim_conv_artifact:
                # frames produced because the kernel is wider than the stride
                n_trim = deconv._kernel_size[1] - deconv._stride[1]
                out = out[:, :, :, :-n_trim]
            out = F.leaky_relu(out, 0.4)
        # back to (B, C, T)
        return paddle.squeeze(out, 1)
|
|
|
|
|
|
|
|
|
2020-12-19 18:33:07 +08:00
|
|
|
class ResidualBlock(nn.Layer):
    """ResidualBlock, the basic unit of ResidualNet used in WaveFlow.

    It has a Conv2D layer with causal padding in the height dimension and
    "same" padding in the width dimension. It also has 1x1 projections for
    the condition and for the output (residual and skip).

    Parameters
    ----------
    channels : int
        Feature size of the input.

    cond_channels : int
        Feature size of the condition.

    kernel_size : Tuple[int]
        Kernel size (height, width) of the Conv2D applied to the input.

    dilations : Tuple[int]
        Dilations (height, width) of the Conv2D applied to the input.
    """

    def __init__(self, channels, cond_channels, kernel_size, dilations):
        super(ResidualBlock, self).__init__()
        # input conv. Uniform init bounded by 1 / sqrt(fan_in) with
        # fan_in = channels * kh * kw, the same fan-in rule used by the
        # projections below and by UpsampleNet. The whole fan-in must be the
        # divisor: ``1 / channels * np.prod(kernel_size)`` would evaluate to
        # prod(kernel_size) / channels.
        std = math.sqrt(1 / (channels * np.prod(kernel_size)))
        init = I.Uniform(-std, std)
        receptive_field = [1 + (k - 1) * d for (k, d) in zip(kernel_size, dilations)]
        rh, rw = receptive_field
        # [top, bottom, left, right]: pad only above in height (causal) and
        # on both sides in width ("same")
        paddings = [rh - 1, 0, rw // 2, (rw - 1) // 2]
        conv = nn.Conv2D(channels, 2 * channels, kernel_size,
                         padding=paddings,
                         dilation=dilations,
                         weight_attr=init,
                         bias_attr=init)
        self.conv = nn.utils.weight_norm(conv)
        # receptive field and dilations are reused by incremental decoding
        self.rh = rh
        self.rw = rw
        self.dilations = dilations

        # condition projection
        std = math.sqrt(1 / cond_channels)
        init = I.Uniform(-std, std)
        condition_proj = nn.Conv2D(cond_channels, 2 * channels, (1, 1),
                                   weight_attr=init, bias_attr=init)
        self.condition_proj = nn.utils.weight_norm(condition_proj)

        # parametric residual & skip connection
        std = math.sqrt(1 / channels)
        init = I.Uniform(-std, std)
        out_proj = nn.Conv2D(channels, 2 * channels, (1, 1),
                             weight_attr=init, bias_attr=init)
        self.out_proj = nn.utils.weight_norm(out_proj)

    def forward(self, x, condition):
        """Compute output for a whole folded sequence.

        Parameters
        ----------
        x : Tensor [shape=(batch_size, channel, height, width)]
            The input.

        condition : Tensor [shape=(batch_size, condition_channel, height, width)]
            The local condition.

        Returns
        -------
        res : Tensor [shape=(batch_size, channel, height, width)]
            The residual output.

        skip : Tensor [shape=(batch_size, channel, height, width)]
            The skip output.
        """
        x_in = x
        x = self.conv(x)
        x += self.condition_proj(condition)

        # gated activation unit
        content, gate = paddle.chunk(x, 2, axis=1)
        x = paddle.tanh(content) * F.sigmoid(gate)

        x = self.out_proj(x)
        res, skip = paddle.chunk(x, 2, axis=1)
        res = x_in + res
        return res, skip

    def start_sequence(self):
        """Prepare the layer for incremental computation of causal
        convolution. Reset the buffer for causal convolution.

        Raises:
            ValueError: If not in evaluation mode.
        """
        if self.training:
            raise ValueError("Only use start sequence at evaluation mode.")
        self._conv_buffer = None

        # NOTE: call self.conv's weight norm hook explicitly since
        # its weight will be visited directly in `add_input` without
        # calling its `__call__` method. If we do not trigger the weight
        # norm hook, the weight may be outdated. e.g. after loading from
        # a saved checkpoint
        # see also: https://github.com/pytorch/pytorch/issues/47588
        for hook in self.conv._forward_pre_hooks.values():
            hook(self.conv, None)

    def add_input(self, x_row, condition_row):
        """Compute the output for a row and update the buffer.

        Parameters
        ----------
        x_row : Tensor [shape=(batch_size, channel, 1, width)]
            A row of the input.

        condition_row : Tensor [shape=(batch_size, condition_channel, 1, width)]
            A row of the condition.

        Returns
        -------
        res : Tensor [shape=(batch_size, channel, 1, width)]
            A row of the residual output.

        skip : Tensor [shape=(batch_size, channel, 1, width)]
            A row of the skip output.
        """
        x_row_in = x_row
        if self._conv_buffer is None:
            self._init_buffer(x_row)
        self._update_buffer(x_row)

        # the buffer holds exactly `rh` rows, so convolving it without height
        # padding yields a single output row
        rw = self.rw
        x_row = F.conv2d(
            self._conv_buffer,
            self.conv.weight,
            self.conv.bias,
            padding=[0, 0, rw // 2, (rw - 1) // 2],
            dilation=self.dilations)
        x_row += self.condition_proj(condition_row)

        content, gate = paddle.chunk(x_row, 2, axis=1)
        x_row = paddle.tanh(content) * F.sigmoid(gate)

        x_row = self.out_proj(x_row)
        res, skip = paddle.chunk(x_row, 2, axis=1)
        res = x_row_in + res
        return res, skip

    def _init_buffer(self, input):
        # zero buffer holding the last `rh` input rows
        batch_size, channels, _, width = input.shape
        self._conv_buffer = paddle.zeros(
            [batch_size, channels, self.rh, width], dtype=input.dtype)

    def _update_buffer(self, input):
        # drop the oldest row and append `input` as the newest one
        self._conv_buffer = paddle.concat(
            [self._conv_buffer[:, :, 1:, :], input], axis=2)
|
|
|
|
|
2020-11-04 01:37:49 +08:00
|
|
|
|
2020-10-10 15:51:54 +08:00
|
|
|
class ResidualNet(nn.LayerList):
    """A stack of several ResidualBlocks. It merges the condition at each
    layer.

    Parameters
    ----------
    n_layer : int
        Number of ResidualBlocks in the ResidualNet.

    residual_channels : int
        Feature size of each ResidualBlock.

    condition_channels : int
        Feature size of the condition.

    kernel_size : Tuple[int]
        Kernel size of each ResidualBlock.

    dilations_h : List[int]
        Dilation in the height dimension of every ResidualBlock. The width
        dilation of layer ``i`` is ``2 ** i``.

    Raises
    ------
    ValueError
        If the length of ``dilations_h`` does not equal ``n_layer``.
    """

    def __init__(self,
                 n_layer: int,
                 residual_channels: int,
                 condition_channels: int,
                 kernel_size: Tuple[int],
                 dilations_h: List[int]):
        if len(dilations_h) != n_layer:
            raise ValueError("number of dilations_h should equals num of layers")
        super(ResidualNet, self).__init__()
        for depth, dilation_h in enumerate(dilations_h):
            # configurable height dilation; width dilation doubles per layer
            block_dilation = (dilation_h, 2 ** depth)
            self.append(
                ResidualBlock(residual_channels, condition_channels,
                              kernel_size, block_dilation))

    def forward(self, x, condition):
        """Compute the output given the input and the condition.

        Parameters
        -----------
        x : Tensor [shape=(batch_size, channel, height, width)]
            The input.

        condition : Tensor [shape=(batch_size, condition_channel, height, width)]
            The local condition.

        Returns
        --------
        Tensor : [shape=(batch_size, channel, height, width)]
            The output, which is the sum of all the skip outputs.
        """
        skips = []
        for block in self:
            x, skip = block(x, condition)
            skips.append(skip)
        return paddle.sum(paddle.stack(skips, 0), 0)

    def start_sequence(self):
        """Prepare every ResidualBlock for incremental computation.
        """
        for block in self:
            block.start_sequence()

    def add_input(self, x_row, condition_row):
        """Compute the output for a row and update the buffers.

        Parameters
        ----------
        x_row : Tensor [shape=(batch_size, channel, 1, width)]
            A row of the input.

        condition_row : Tensor [shape=(batch_size, condition_channel, 1, width)]
            A row of the condition.

        Returns
        -------
        Tensor : [shape=(batch_size, channel, 1, width)]
            A row of the output, which is the sum of all the skip outputs.
        """
        skips = []
        for block in self:
            x_row, skip = block.add_input(x_row, condition_row)
            skips.append(skip)
        return paddle.sum(paddle.stack(skips, 0), 0)
|
|
|
|
|
2020-10-10 15:51:54 +08:00
|
|
|
|
|
|
|
class Flow(nn.Layer):
    """A bijection (reversible layer) that transforms a density of latent
    variables p(Z) into a complex data distribution p(X).

    It's an autoregressive flow. The ``forward`` method implements the
    probability density estimation. The ``inverse`` method implements the
    sampling.

    Parameters
    ----------
    n_layers : int
        Number of ResidualBlocks in the Flow.

    channels : int
        Feature size of the ResidualBlocks.

    mel_bands : int
        Feature size of the mel spectrogram (mel bands).

    kernel_size : Tuple[int]
        Kernel size of each ResidualBlock in the Flow.

    n_group : int
        Number of timesteps to be folded into a group. It must be a key of
        ``dilations_dict`` (8, 16, 32, 64 or 128).
    """
    # height dilations of the ResidualNet, keyed by n_group
    dilations_dict = {
        8: [1, 1, 1, 1, 1, 1, 1, 1],
        16: [1, 1, 1, 1, 1, 1, 1, 1],
        32: [1, 2, 4, 1, 2, 4, 1, 2],
        64: [1, 2, 4, 8, 16, 1, 2, 4],
        128: [1, 2, 4, 8, 16, 32, 64, 1]
    }

    def __init__(self, n_layers, channels, mel_bands, kernel_size, n_group):
        super(Flow, self).__init__()
        # input projection: 1 -> channels
        self.input_proj = nn.utils.weight_norm(
            nn.Conv2D(1, channels, (1, 1),
                      weight_attr=I.Uniform(-1., 1.),
                      bias_attr=I.Uniform(-1., 1.)))

        # residual net
        self.resnet = ResidualNet(n_layers, channels, mel_bands, kernel_size,
                                  self.dilations_dict[n_group])

        # output projection to (log scale, shift); zero-initialized, so the
        # flow starts as the identity transform (logs = 0, b = 0)
        self.output_proj = nn.Conv2D(channels, 2, (1, 1),
                                     weight_attr=I.Constant(0.),
                                     bias_attr=I.Constant(0.))

        # specs
        self.n_group = n_group

    def _predict_parameters(self, x, condition):
        hidden = self.resnet(self.input_proj(x), condition)
        logs, b = paddle.chunk(self.output_proj(hidden), 2, axis=1)
        return logs, b

    def _transform(self, x, logs, b):
        first_row = x[:, :, :1, :]  # the first row is copied unchanged
        rest = x[:, :, 1:, :] * paddle.exp(logs) + b
        return paddle.concat([first_row, rest], axis=2)

    def forward(self, x, condition):
        """Probability density estimation. It is done by inversely
        transforming a sample from p(X) into a sample from p(Z).

        Parameters
        -----------
        x : Tensor [shape=(batch, 1, height, width)]
            An input sample of the distribution p(X).

        condition : Tensor [shape=(batch, condition_channel, height, width)]
            The local condition.

        Returns
        --------
        z (Tensor): shape(batch, 1, height, width), the transformed sample.

        Tuple[Tensor, Tensor]
            The parameters of the transformation.

            logs (Tensor): shape(batch, 1, height - 1, width), the log scale
            of the transformation from x to z.

            b (Tensor): shape(batch, 1, height - 1, width), the shift of the
            transformation from x to z.
        """
        # rows 0..H-2 of x (with condition rows 1..H-1) predict the
        # parameters for rows 1..H-1; shapes are (B, C, H-1, W)
        logs, b = self._predict_parameters(
            x[:, :, :-1, :], condition[:, :, 1:, :])
        z = self._transform(x, logs, b)
        return z, (logs, b)

    def _predict_row_parameters(self, x_row, condition_row):
        hidden = self.resnet.add_input(self.input_proj(x_row), condition_row)
        logs, b = paddle.chunk(self.output_proj(hidden), 2, axis=1)
        return logs, b

    def _inverse_transform_row(self, z_row, logs, b):
        # inverse of z = x * exp(logs) + b
        return (z_row - b) * paddle.exp(-logs)

    def _inverse_row(self, z_row, x_row, condition_row):
        logs, b = self._predict_row_parameters(x_row, condition_row)
        x_next_row = self._inverse_transform_row(z_row, logs, b)
        return x_next_row, (logs, b)

    def _start_sequence(self):
        self.resnet.start_sequence()

    def inverse(self, z, condition):
        """Sampling from the distribution p(X). It is done by sampling from
        p(Z) and transforming the sample. It is an autoregressive
        transformation.

        Parameters
        -----------
        z : Tensor [shape=(batch, 1, height, width)]
            A sample of the distribution p(Z).

        condition : Tensor [shape=(batch, condition_channel, height, width)]
            The local condition.

        Returns
        ---------
        x : Tensor [shape=(batch, 1, height, width)]
            The transformed sample.

        Tuple[Tensor, Tensor]
            The parameters of the transformation.

            logs (Tensor): shape(batch, 1, height - 1, width), the log scale
            of the transformation from x to z.

            b (Tensor): shape(batch, 1, height - 1, width), the shift of the
            transformation from x to z.
        """
        rows = [z[:, :, :1, :]]  # the first row is copied unchanged
        logs_rows = []
        b_rows = []

        self._start_sequence()
        for i in range(1, self.n_group):
            z_row = z[:, :, i:i + 1, :]
            condition_row = condition[:, :, i:i + 1, :]
            # the previously generated row (i - 1) drives row i
            next_row, (logs, b) = self._inverse_row(z_row, rows[-1], condition_row)
            rows.append(next_row)
            logs_rows.append(logs)
            b_rows.append(b)

        x = paddle.concat(rows, 2)
        logs = paddle.concat(logs_rows, 2)
        b = paddle.concat(b_rows, 2)
        return x, (logs, b)
|
|
|
|
|
2020-11-09 15:46:27 +08:00
|
|
|
|
2020-10-10 15:51:54 +08:00
|
|
|
class WaveFlow(nn.LayerList):
    """A deep reversible layer composed of several autoregressive flows.

    Parameters
    -----------
    n_flows : int
        Number of flows in the WaveFlow model.

    n_layers : int
        Number of ResidualBlocks in each Flow.

    n_group : int
        Number of timesteps to fold as a group.

    channels : int
        Feature size of each ResidualBlock.

    mel_bands : int
        Feature size of mel spectrogram (mel bands).

    kernel_size : List[int]
        Kernel size (height, width) of the convolution layer in each
        ResidualBlock.

    Raises
    ------
    ValueError
        If ``n_group`` or ``n_flows`` is odd.
    """

    def __init__(self, n_flows, n_layers, n_group, channels, mel_bands, kernel_size):
        if n_group % 2 or n_flows % 2:
            raise ValueError("number of flows and number of group must be even "
                             "since a permutation along group among flows is used.")
        super(WaveFlow, self).__init__()
        for _ in range(n_flows):
            self.append(Flow(n_layers, channels, mel_bands, kernel_size, n_group))

        # permutations in h
        self.perms = self._create_perm(n_group, n_flows)

        # specs
        self.n_group = n_group
        self.n_flows = n_flows

    def _create_perm(self, n_group, n_flows):
        # Flows in the first half reverse the whole group axis; flows in the
        # second half reverse each half of it separately. Every permutation
        # is its own inverse.
        indices = list(range(n_group))
        half = n_group // 2
        perms = []
        for i in range(n_flows):
            if i < n_flows // 2:
                perms.append(indices[::-1])
            else:
                perm = list(reversed(indices[:half])) + list(reversed(indices[half:]))
                perms.append(perm)
        return perms

    def _trim(self, x, condition):
        # cut both signals down to a multiple of n_group time steps
        assert condition.shape[-1] >= x.shape[-1]
        pruned_len = int(x.shape[-1] // self.n_group * self.n_group)

        if x.shape[-1] > pruned_len:
            x = x[:, :pruned_len]
        if condition.shape[-1] > pruned_len:
            condition = condition[:, :, :pruned_len]
        return x, condition

    def forward(self, x, condition):
        """Probability density estimation of random variable x given the
        condition.

        Parameters
        -----------
        x : Tensor [shape=(batch_size, time_steps)]
            The audio.

        condition : Tensor [shape=(batch_size, condition channel, time_steps)]
            The local condition (mel spectrogram here).

        Returns
        --------
        z : Tensor [shape=(batch_size, time_steps)]
            The transformed random variable.

        log_det_jacobian: Tensor [shape=(1,)]
            The log determinant of the jacobian of the transformation from x
            to z.
        """
        # x: (B, T)
        # condition: (B, C, T) upsampled condition
        x, condition = self._trim(x, condition)

        # to (B, C, h, T//h) layout
        x = paddle.unsqueeze(paddle.transpose(fold(x, self.n_group), [0, 2, 1]), 1)
        condition = paddle.transpose(fold(condition, self.n_group), [0, 1, 3, 2])

        # flows
        logs_list = []
        for i, layer in enumerate(self):
            x, (logs, b) = layer(x, condition)
            logs_list.append(logs)
            # permute paddle has no shuffle dim
            x = geo.shuffle_dim(x, 2, perm=self.perms[i])
            condition = geo.shuffle_dim(condition, 2, perm=self.perms[i])

        z = paddle.squeeze(x, 1)  # (B, H, W)
        batch_size = z.shape[0]
        z = paddle.reshape(paddle.transpose(z, [0, 2, 1]), [batch_size, -1])

        log_det_jacobian = paddle.sum(paddle.stack(logs_list))
        return z, log_det_jacobian

    def inverse(self, z, condition):
        """Sampling from the distribution p(X).

        It is done by sampling a ``z`` from p(Z) and transforming it into
        ``x``. Each Flow transforms :math:`z_{i-1}` to :math:`z_{i}` in an
        autoregressive manner.

        Parameters
        ----------
        z : Tensor [shape=(batch, time_steps)]
            A sample of the distribution p(Z).

        condition : Tensor [shape=(batch, condition_channel, time_steps)]
            The local condition.

        Returns
        --------
        x : Tensor [shape=(batch_size, time_steps)]
            The transformed sample (audio here).
        """
        z, condition = self._trim(z, condition)
        # to (B, C, h, T//h) layout
        z = paddle.unsqueeze(paddle.transpose(fold(z, self.n_group), [0, 2, 1]), 1)
        condition = paddle.transpose(fold(condition, self.n_group), [0, 1, 3, 2])

        # In `forward`, flow i sees the condition permuted by perms[0..i-1],
        # and the condition is permuted once more after the last flow. Apply
        # all the permutations here, so that undoing them one at a time in the
        # loop below gives every flow exactly the condition it saw in
        # `forward`. The composition of all perms is the identity only when
        # n_flows is a multiple of 4; without this step other even values of
        # n_flows would feed mismatched conditions to the flows.
        for perm in self.perms:
            condition = geo.shuffle_dim(condition, 2, perm=perm)

        # reverse it flow by flow (each perm is its own inverse)
        for i in reversed(range(self.n_flows)):
            z = geo.shuffle_dim(z, 2, perm=self.perms[i])
            condition = geo.shuffle_dim(condition, 2, perm=self.perms[i])
            z, _ = self[i].inverse(z, condition)

        x = paddle.squeeze(z, 1)  # (B, H, W)
        batch_size = x.shape[0]
        x = paddle.reshape(paddle.transpose(x, [0, 2, 1]), [batch_size, -1])
        return x
|
|
|
|
|
2020-10-10 15:51:54 +08:00
|
|
|
|
2020-11-09 15:46:27 +08:00
|
|
|
class ConditionalWaveFlow(nn.LayerList):
    """ConditionalWaveFlow, an UpsampleNet combined with a WaveFlow model.

    Parameters
    ----------
    upsample_factors : List[int]
        Upsample factors for the upsample net.

    n_flows : int
        Number of flows in the WaveFlow model.

    n_layers : int
        Number of ResidualBlocks in each Flow.

    n_group : int
        Number of timesteps to fold as a group.

    channels : int
        Feature size of each ResidualBlock.

    n_mels : int
        Feature size of mel spectrogram (mel bands).

    kernel_size : Union[int, List[int]]
        Kernel size of the convolution layer in each ResidualBlock. It is
        zipped with a (height, width) pair of dilations there, so a
        (height, width) pair is expected.
    """

    def __init__(self,
                 upsample_factors: List[int],
                 n_flows: int,
                 n_layers: int,
                 n_group: int,
                 channels: int,
                 n_mels: int,
                 kernel_size: Union[int, List[int]]):
        super(ConditionalWaveFlow, self).__init__()
        # mel spectrogram -> condition at audio resolution
        self.encoder = UpsampleNet(upsample_factors)
        # condition + audio <-> latent
        self.decoder = WaveFlow(n_flows=n_flows,
                                n_layers=n_layers,
                                n_group=n_group,
                                channels=channels,
                                mel_bands=n_mels,
                                kernel_size=kernel_size)

    def forward(self, audio, mel):
        """Compute the transformed random variable z (x to z) and the log of
        the determinant of the jacobian of the transformation from x to z.

        Parameters
        ----------
        audio : Tensor [shape=(B, T)]
            The audio.

        mel : Tensor [shape=(B, C_mel, T_mel)]
            The mel spectrogram.

        Returns
        -------
        z : Tensor [shape=(B, T)]
            The inversely transformed random variable z (x to z).

        log_det_jacobian: Tensor [shape=(1,)]
            The log of the determinant of the jacobian of the transformation
            from x to z.
        """
        upsampled_mel = self.encoder(mel)
        return self.decoder(upsampled_mel and audio if False else audio, upsampled_mel)

    @paddle.no_grad()
    def infer(self, mel):
        r"""Generate raw audio given a mel spectrogram.

        Parameters
        ----------
        mel : Tensor [shape=(B, C_mel, T_mel)]
            Mel spectrogram (in log-magnitude).

        Returns
        -------
        Tensor : [shape=(B, T)]
            The synthesized audio, where ``T <= T_mel \* upsample_factors``.
        """
        condition = self.encoder(mel, trim_conv_artifact=True)  # (B, C, T)
        batch_size, _, time_steps = condition.shape
        noise = paddle.randn([batch_size, time_steps], dtype=mel.dtype)
        return self.decoder.inverse(noise, condition)

    @paddle.no_grad()
    def predict(self, mel):
        """Generate raw audio given a mel spectrogram.

        Parameters
        ----------
        mel : np.ndarray [shape=(C_mel, T_mel)]
            Mel spectrogram of an utterance (in log-magnitude).

        Returns
        -------
        np.ndarray [shape=(T,)]
            The synthesized audio.
        """
        batched_mel = paddle.unsqueeze(paddle.to_tensor(mel), 0)
        waveforms = self.infer(batched_mel)
        return waveforms[0].numpy()

    @classmethod
    def from_pretrained(cls, config, checkpoint_path):
        """Build a ConditionalWaveFlow model from a pretrained model.

        Parameters
        ----------
        config: yacs.config.CfgNode
            Model configs.

        checkpoint_path: Path or str
            The path of the pretrained model checkpoint, without extension
            name.

        Returns
        -------
        ConditionalWaveFlow
            The model built from the pretrained result.
        """
        model_config = config.model
        model = cls(upsample_factors=model_config.upsample_factors,
                    n_flows=model_config.n_flows,
                    n_layers=model_config.n_layers,
                    n_group=model_config.n_group,
                    channels=model_config.channels,
                    n_mels=config.data.n_mels,
                    kernel_size=model_config.kernel_size)
        checkpoint.load_parameters(model, checkpoint_path=checkpoint_path)
        return model
|
2020-11-09 15:46:27 +08:00
|
|
|
|
|
|
|
|
2020-11-04 23:22:45 +08:00
|
|
|
class WaveFlowLoss(nn.Layer):
    """Criterion of a WaveFlow model: the negative log-likelihood of the
    data under an isotropic gaussian prior on z, averaged per element.

    Parameters
    ----------
    sigma : float
        The standard deviation of the gaussian noise used in WaveFlow, by
        default 1.0.
    """

    def __init__(self, sigma=1.0):
        super(WaveFlowLoss, self).__init__()
        self.sigma = sigma
        # per-element constant term of the gaussian log density
        self.const = 0.5 * np.log(2 * np.pi) + np.log(self.sigma)

    def forward(self, z, log_det_jacobian):
        """Compute the loss given the transformed random variable z and the
        log_det_jacobian of the transformation from x to z.

        Parameters
        ----------
        z : Tensor [shape=(B, T)]
            The transformed random variable (x to z).

        log_det_jacobian : Tensor [shape=(1,)]
            The log of the determinant of the jacobian matrix of the
            transformation from x to z.

        Returns
        -------
        Tensor [shape=(1,)]
            The loss.
        """
        two_variance = 2 * self.sigma * self.sigma
        negative_log_likelihood = paddle.sum(z * z) / two_variance - log_det_jacobian
        per_element = negative_log_likelihood / np.prod(z.shape)
        return per_element + self.const
|