2020-10-10 15:51:54 +08:00
|
|
|
import math
|
|
|
|
import paddle
|
|
|
|
from paddle import nn
|
|
|
|
from paddle.nn import functional as F
|
|
|
|
from paddle.nn import initializer as I
|
|
|
|
|
|
|
|
from typing import Sequence
|
|
|
|
from parakeet.modules import geometry as geo
|
|
|
|
|
|
|
|
import itertools
|
|
|
|
import numpy as np
|
|
|
|
import paddle.fluid.dygraph as dg
|
|
|
|
from paddle import fluid
|
2020-10-16 13:51:56 +08:00
|
|
|
|
|
|
|
# Names exported by `from <module> import *`.
__all__ = [
    "WaveFlow",
]
|
2020-10-10 15:51:54 +08:00
|
|
|
|
|
|
|
def fold(x, n_group):
    """Split the last (time) axis of a tensor into groups of ``n_group``.

    Args:
        x (Tensor): shape(*, time_steps), the input tensor (audio or
            spectrogram).
        n_group (int): the size of a group. ``time_steps`` is expected to be a
            multiple of it; otherwise the reshape fails.

    Returns:
        Tensor: shape(*, time_steps // n_group, n_group), the folded tensor.
    """
    *leading_dims, n_steps = x.shape
    folded_shape = [*leading_dims, n_steps // n_group, n_group]
    return paddle.reshape(x, folded_shape)
|
|
|
|
|
|
|
|
class UpsampleNet(nn.LayerList):
    """
    Layer to upsample mel spectrogram to the same temporal resolution with
    the corresponding waveform. It consists of several weight-normalized
    Conv2DTranspose (deconvolution) layers, one per upsample factor, which
    operate on the (channel, time) plane of the spectrogram treated as a
    single-channel image.

    Args:
        upsample_factors (Sequence[int]): time-axis stride of each layer. Their
            product is stored as ``upsample_factor``.
    """

    def __init__(self, upsample_factors: Sequence[int]):
        super(UpsampleNet, self).__init__()
        for factor in upsample_factors:
            # uniform init with bound sqrt(1 / fan_in); fan_in = 1 * 3 * (2 * factor)
            std = math.sqrt(1 / (3 * 2 * factor))
            init = I.Uniform(-std, std)
            # kernel 3 / padding 1 / stride 1 keeps the channel-axis length;
            # kernel 2*factor / stride factor stretches the time axis.
            self.append(
                nn.utils.weight_norm(
                    nn.Conv2DTranspose(1, 1, (3, 2 * factor),
                                       padding=(1, factor // 2),
                                       stride=(1, factor),
                                       weight_attr=init,
                                       bias_attr=init)))

        # upsample factors
        self.upsample_factor = np.prod(upsample_factors)
        self.upsample_factors = upsample_factors

    def forward(self, x, trim_conv_artifact=False):
        """
        Args:
            x (Tensor): shape(batch_size, input_channels, time_steps), the input
                spectrogram.
            trim_conv_artifact (bool, optional): trim deconvolution artifact at
                each layer, i.e. drop the last ``kernel_size - stride`` time
                steps of every layer's output. Defaults to False.

        Returns:
            Tensor: shape(batch_size, input_channels, T_out). Without trimming,
                T_out is time_steps * upsample_factor when every factor is even
                (each odd factor adds one extra step). If trim_conv_artifact is
                True, the output time steps is less than that.
        """
        # (B, C, T) -> (B, 1, C, T): a single-channel "image" for Conv2DTranspose
        x = paddle.unsqueeze(x, 1)
        for layer in self:
            x = layer(x)
            if trim_conv_artifact:
                time_cutoff = layer._kernel_size[1] - layer._stride[1]
                # Slice off the trailing `time_cutoff` steps (keeping the time
                # axis); skip when there is nothing to trim, since `:-0` would
                # produce an empty tensor.
                if time_cutoff > 0:
                    x = x[:, :, :, :-time_cutoff]
            x = F.leaky_relu(x, 0.4)
        x = paddle.squeeze(x, 1)
        return x
|
|
|
|
|
|
|
|
|
|
|
|
class ResidualBlock(nn.Layer):
    """
    ResidualBlock that merges information from the condition and outputs residual
    and skip. It has a conv2d layer, which has causal padding in height dimension
    and same padding in width dimension.

    Args:
        channels (int): channels of the input; the residual and skip outputs
            have this many channels as well.
        cond_channels (int): channels of the condition.
        kernel_size (Sequence[int]): (height, width) kernel size of the input conv.
        dilations (Sequence[int]): (height, width) dilations of the input conv.
    """

    def __init__(self, channels, cond_channels, kernel_size, dilations):
        super(ResidualBlock, self).__init__()
        # input conv
        # Uniform init with bound sqrt(1 / fan_in), fan_in = channels * kh * kw,
        # consistent with the 1x1 projections below. The denominator is
        # parenthesized on purpose: `1 / channels * prod` means prod / channels.
        std = math.sqrt(1 / (channels * np.prod(kernel_size)))
        init = I.Uniform(-std, std)
        receptive_field = [1 + (k - 1) * d for (k, d) in zip(kernel_size, dilations)]
        rh, rw = receptive_field
        # [top, bottom, left, right]: pad only above in height (causal) and on
        # both sides in width (same), so the output keeps the input's H and W.
        paddings = [rh - 1, 0, rw // 2, (rw - 1) // 2]
        conv = nn.Conv2D(channels, 2 * channels, kernel_size,
                         padding=paddings,
                         dilation=dilations,
                         weight_attr=init,
                         bias_attr=init)
        self.conv = nn.utils.weight_norm(conv)
        # receptive field and dilations, reused by the incremental `add_input`
        self.rh = rh
        self.rw = rw
        self.dilations = dilations

        # condition projection
        std = math.sqrt(1 / cond_channels)
        init = I.Uniform(-std, std)
        condition_proj = nn.Conv2D(cond_channels, 2 * channels, (1, 1),
                                   weight_attr=init, bias_attr=init)
        self.condition_proj = nn.utils.weight_norm(condition_proj)

        # parametric residual & skip connection
        std = math.sqrt(1 / channels)
        init = I.Uniform(-std, std)
        out_proj = nn.Conv2D(channels, 2 * channels, (1, 1),
                             weight_attr=init, bias_attr=init)
        self.out_proj = nn.utils.weight_norm(out_proj)

    def forward(self, x, condition):
        """
        Args:
            x (Tensor): shape(B, channels, H, W), the input.
            condition (Tensor): shape(B, cond_channels, H, W), expected to match
                x in height and width.

        Returns:
            (res, skip): res is ``x + residual`` and skip is the skip output,
                both shape(B, channels, H, W).
        """
        x_in = x
        x = self.conv(x)
        x += self.condition_proj(condition)

        # gated activation
        content, gate = paddle.chunk(x, 2, axis=1)
        x = paddle.tanh(content) * F.sigmoid(gate)

        x = self.out_proj(x)
        res, skip = paddle.chunk(x, 2, axis=1)
        # residual connection: the block's input is added to its residual output
        res = x_in + res
        return res, skip

    def start_sequence(self):
        """Reset the incremental-inference buffer. Only allowed in eval mode."""
        if self.training:
            raise ValueError("Only use start sequence at evaluation mode.")
        self._conv_buffer = None

    def add_input(self, x_row, condition_row):
        """Incremental counterpart of ``forward`` for one new row.

        The last ``rh`` input rows are kept in a buffer (zero-initialized), so a
        convolution over it with no height padding yields one output row.
        ``start_sequence`` must be called first.

        Args:
            x_row (Tensor): shape(B, channels, 1, W), the newest input row
                (height 1 is assumed by the buffer logic).
            condition_row (Tensor): shape(B, cond_channels, 1, W).

        Returns:
            (res, skip): each shape(B, channels, 1, W).
        """
        if self._conv_buffer is None:
            self._init_buffer(x_row)
        self._update_buffer(x_row)

        rw = self.rw
        x_row_in = x_row
        # NOTE(review): this reads self.conv.weight directly instead of calling
        # self.conv. With nn.utils.weight_norm that attribute is recomputed by
        # the layer's forward pre-hook, so it may be stale if parameters were
        # loaded and self.conv has not been called since — confirm.
        x_row = F.conv2d(
            self._conv_buffer,
            self.conv.weight,
            self.conv.bias,
            padding=[0, 0, rw // 2, (rw - 1) // 2],
            dilation=self.dilations)
        x_row += self.condition_proj(condition_row)

        content, gate = paddle.chunk(x_row, 2, axis=1)
        x_row = paddle.tanh(content) * F.sigmoid(gate)

        x_row = self.out_proj(x_row)
        res, skip = paddle.chunk(x_row, 2, axis=1)
        # same residual connection as in `forward`
        res = x_row_in + res
        return res, skip

    def _init_buffer(self, x_row):
        # zero history of height rh (plays the role of the causal top padding)
        batch_size, channels, _, width = x_row.shape
        self._conv_buffer = paddle.zeros(
            [batch_size, channels, self.rh, width], dtype=x_row.dtype)

    def _update_buffer(self, x_row):
        # drop the oldest row and append the newest one along the height axis
        self._conv_buffer = paddle.concat(
            [self._conv_buffer[:, :, 1:, :], x_row], axis=2)
|
|
|
|
|
2020-11-04 01:37:49 +08:00
|
|
|
|
2020-10-10 15:51:54 +08:00
|
|
|
class ResidualNet(nn.LayerList):
    """
    A stack of several ResidualBlocks. Every block receives the condition, and
    the skip outputs of all blocks are summed to form the output.

    Args:
        n_layer (int): number of residual blocks.
        residual_channels (int): channels of each block.
        condition_channels (int): channels of the condition.
        kernel_size (Sequence[int]): kernel size of each block.
        dilations_h (Sequence[int]): height dilation of each block; its length
            must equal ``n_layer``. Block i uses width dilation ``2 ** i``.
    """

    def __init__(self, n_layer, residual_channels, condition_channels, kernel_size, dilations_h):
        if len(dilations_h) != n_layer:
            raise ValueError("number of dilations_h should equals num of layers")
        super(ResidualNet, self).__init__()
        for depth in range(n_layer):
            block_dilation = (dilations_h[depth], 2 ** depth)
            self.append(
                ResidualBlock(residual_channels, condition_channels,
                              kernel_size, block_dilation))

    def forward(self, x, condition):
        """Feed x through every block in order and return the sum of skips."""
        skips = []
        hidden = x
        for block in self:
            hidden, skip = block(hidden, condition)
            skips.append(skip)
        return paddle.sum(paddle.stack(skips, 0), 0)

    def start_sequence(self):
        """Reset the incremental-inference state of every block."""
        for block in self:
            block.start_sequence()

    def add_input(self, x_row, condition_row):
        """Incremental counterpart of ``forward`` for a single new row.

        Blocks are visited in the same (forward) order as in ``forward``.
        """
        skips = []
        hidden_row = x_row
        for block in self:
            hidden_row, skip = block.add_input(hidden_row, condition_row)
            skips.append(skip)
        return paddle.sum(paddle.stack(skips, 0), 0)
|
|
|
|
|
2020-10-10 15:51:54 +08:00
|
|
|
|
|
|
|
class Flow(nn.Layer):
    """
    A Layer that merges the condition and predicts scale and bias given a random
    variable X.

    Along the height (group) axis, the first row is copied unchanged and every
    later row i becomes ``x_i * exp(logs_i) + b_i``. The parameters for row i
    come from the causal residual net applied to rows ``0..i-1`` of x together
    with row ``i`` of the condition.

    Args:
        n_layers (int): number of residual blocks.
        channels (int): residual channels.
        mel_bands (int): channels of the condition.
        kernel_size (Sequence[int]): kernel size of the residual blocks.
        n_group (int): the height of the folded input; it must be a key of
            ``dilations_dict``, which gives the blocks' height dilations.
    """
    dilations_dict = {
        8: [1, 1, 1, 1, 1, 1, 1, 1],
        16: [1, 1, 1, 1, 1, 1, 1, 1],
        32: [1, 2, 4, 1, 2, 4, 1, 2],
        64: [1, 2, 4, 8, 16, 1, 2, 4],
        128: [1, 2, 4, 8, 16, 32, 64, 1]
    }

    def __init__(self, n_layers, channels, mel_bands, kernel_size, n_group):
        super(Flow, self).__init__()
        # input projection: 1 -> channels with a 1x1 conv
        self.first_conv = nn.utils.weight_norm(
            nn.Conv2D(1, channels, (1, 1),
                      weight_attr=I.Uniform(-1., 1.),
                      bias_attr=I.Uniform(-1., 1.)))

        # residual net
        height_dilations = self.dilations_dict[n_group]
        self.resnet = ResidualNet(n_layers, channels, mel_bands, kernel_size,
                                  height_dilations)

        # output projection: channels -> (logs, b); zero-initialized
        self.last_conv = nn.utils.weight_norm(
            nn.Conv2D(channels, 2, (1, 1),
                      weight_attr=I.Constant(0.),
                      bias_attr=I.Constant(0.)))

        # specs
        self.n_group = n_group

    def _predict_parameters(self, x, condition):
        # (log-scale, bias) for every row at once
        hidden = self.resnet(self.first_conv(x), condition)
        logs, b = paddle.chunk(self.last_conv(hidden), 2, axis=1)
        return logs, b

    def _transform(self, x, logs, b):
        head = x[:, :, :1, :]  # the first row passes through unchanged
        tail = x[:, :, 1:, :] * paddle.exp(logs) + b
        return paddle.concat([head, tail], axis=2)

    def forward(self, x, condition):
        """Map x to z; also return the (logs, b) used, each (B, 1, H-1, W)."""
        # rows 0..H-2 of x paired with rows 1..H-1 of the condition
        logs, b = self._predict_parameters(x[:, :, :-1, :], condition[:, :, 1:, :])
        return self._transform(x, logs, b), (logs, b)

    def _predict_row_parameters(self, x_row, condition_row):
        hidden_row = self.resnet.add_input(self.first_conv(x_row), condition_row)
        logs, b = paddle.chunk(self.last_conv(hidden_row), 2, axis=1)
        return logs, b

    def _inverse_transform_row(self, z_row, logs, b):
        return (z_row - b) / paddle.exp(logs)

    def _inverse_row(self, z_row, x_row, condition_row):
        logs, b = self._predict_row_parameters(x_row, condition_row)
        return self._inverse_transform_row(z_row, logs, b), (logs, b)

    def start_sequence(self):
        """Reset the incremental state of the residual net."""
        self.resnet.start_sequence()

    def inverse(self, z, condition):
        """Invert ``forward`` row by row (autoregressively along the height).

        Returns:
            x and the (logs, b) predicted for rows 1..n_group-1.
        """
        recovered = [z[:, :, :1, :]]  # row 0 was copied, so x_0 = z_0
        logs_rows = []
        b_rows = []

        self.start_sequence()
        for row in range(1, self.n_group):
            # the previous recovered row (row - 1) feeds the incremental net
            x_next_row, (logs, b) = self._inverse_row(
                z[:, :, row:row + 1, :],
                recovered[-1],
                condition[:, :, row:row + 1, :])
            recovered.append(x_next_row)
            logs_rows.append(logs)
            b_rows.append(b)

        x = paddle.concat(recovered, 2)
        logs = paddle.concat(logs_rows, 2)
        b = paddle.concat(b_rows, 2)
        return x, (logs, b)
|
|
|
|
|
|
|
|
|
2020-10-10 15:51:54 +08:00
|
|
|
class WaveFlow(nn.LayerList):
    """A stack of ``Flow`` layers with a fixed permutation of the group (height)
    axis applied to both the data and the condition after each flow.

    Args:
        n_flows (int): number of flows; must be even.
        n_layers (int): number of residual blocks in each flow.
        n_group (int): group size, i.e. the height the waveform is folded
            into; must be even (and a key of ``Flow.dilations_dict``).
        channels (int): residual channels of each flow.
        mel_bands (int): channels of the (upsampled) condition.
        kernel_size (Sequence[int]): kernel size of the residual blocks.
    """

    def __init__(self, n_flows, n_layers, n_group, channels, mel_bands, kernel_size):
        if n_group % 2 or n_flows % 2:
            raise ValueError("number of flows and number of group must be even "
                             "since a permutation along group among flows is used.")
        super(WaveFlow, self).__init__()
        for _ in range(n_flows):
            self.append(Flow(n_layers, channels, mel_bands, kernel_size, n_group))

        # permutations in h
        self.perms = self._create_perm(n_group, n_flows)

        # specs
        self.n_group = n_group
        self.n_flows = n_flows

    def _create_perm(self, n_group, n_flows):
        # First half of the flows: reverse the whole group axis.
        # Second half: reverse each half of the group axis separately.
        # Both permutations are their own inverse.
        indices = list(range(n_group))
        half = n_group // 2
        perms = []
        for i in range(n_flows):
            if i < n_flows // 2:
                perms.append(indices[::-1])
            else:
                perm = list(reversed(indices[:half])) + list(reversed(indices[half:]))
                perms.append(perm)
        return perms

    def _trim(self, x, condition):
        # Cut both to the largest multiple of n_group not exceeding x's length.
        assert condition.shape[-1] >= x.shape[-1], \
            "condition must be at least as long as the input"
        pruned_len = int(x.shape[-1] // self.n_group * self.n_group)

        if x.shape[-1] > pruned_len:
            x = x[:, :pruned_len]
        if condition.shape[-1] > pruned_len:
            condition = condition[:, :, :pruned_len]
        return x, condition

    def forward(self, x, condition):
        """
        Args:
            x (Tensor): shape(B, T), the waveform.
            condition (Tensor): shape(B, C, T), the upsampled condition.

        Returns:
            (z, logs_list): z has shape(B, n_group, T // n_group) in the folded
                and permuted layout; logs_list holds each flow's log-scale.
        """
        x, condition = self._trim(x, condition)

        # to (B, C, h, T//h) layout
        x = paddle.unsqueeze(paddle.transpose(fold(x, self.n_group), [0, 2, 1]), 1)
        condition = paddle.transpose(fold(condition, self.n_group), [0, 1, 3, 2])

        # flows
        logs_list = []
        for i, layer in enumerate(self):
            x, (logs, b) = layer(x, condition)
            logs_list.append(logs)
            # permute along the group axis (paddle has no built-in shuffle dim)
            x = geo.shuffle_dim(x, 2, perm=self.perms[i])
            condition = geo.shuffle_dim(condition, 2, perm=self.perms[i])

        z = paddle.squeeze(x, 1)
        return z, logs_list

    def start_sequence(self):
        """Reset the incremental-inference state of every flow."""
        for layer in self:
            layer.start_sequence()

    def inverse(self, z, condition):
        """Map z back to x flow by flow, in reverse order.

        Args:
            z (Tensor): shape(B, T).
            condition (Tensor): shape(B, C, T), the upsampled condition.

        Returns:
            Tensor: shape(B, n_group, T // n_group), in the folded layout.
        """
        self.start_sequence()

        z, condition = self._trim(z, condition)
        # to (B, C, h, T//h) layout
        z = paddle.unsqueeze(paddle.transpose(fold(z, self.n_group), [0, 2, 1]), 1)
        condition = paddle.transpose(fold(condition, self.n_group), [0, 1, 3, 2])

        # In `forward`, flow i sees the condition shuffled by perms[0..i-1], and
        # after the last flow the condition has been shuffled by all perms.
        # Start from that fully shuffled condition so that undoing perms[i]
        # below recovers exactly what flow i saw. When the composite
        # permutation is the identity (e.g. n_flows % 4 == 0) this is a no-op.
        # shuffle_dim(x, perm)[j] == x[perm[j]], so applying p then q is
        # a single shuffle by [p[q[j]] for j].
        composite = list(range(self.n_group))
        for perm in self.perms:
            composite = [composite[j] for j in perm]
        if composite != list(range(self.n_group)):
            condition = geo.shuffle_dim(condition, 2, perm=composite)

        # reverse it flow by flow; each perm is its own inverse
        for i in reversed(range(self.n_flows)):
            z = geo.shuffle_dim(z, 2, perm=self.perms[i])
            condition = geo.shuffle_dim(condition, 2, perm=self.perms[i])
            z, _ = self[i].inverse(z, condition)

        x = paddle.squeeze(z, 1)
        return x
|
|
|
|
|
2020-10-10 15:51:54 +08:00
|
|
|
|
2020-11-04 19:31:36 +08:00
|
|
|
# TODO(chenfeiyu): WaveFlowLoss
|