253 lines
9.2 KiB
Python
253 lines
9.2 KiB
Python
import math
|
|
import paddle
|
|
from paddle import nn
|
|
from paddle.nn import functional as F
|
|
from paddle.nn import initializer as I
|
|
|
|
from typing import Sequence
|
|
from parakeet.modules import geometry as geo
|
|
|
|
import itertools
|
|
import numpy as np
|
|
import paddle.fluid.dygraph as dg
|
|
from paddle import fluid
|
|
|
|
# Public API of this module for ``from ... import *``: only the top-level
# model is exported; the building blocks stay importable by explicit name.
__all__ = [
    "WaveFlow",
]
|
|
|
|
def fold(x, n_group):
    """Fold the temporal (last) dimension of audio or a spectrogram into groups.

    Args:
        x (Tensor): shape(*, time_steps), the input tensor.
        n_group (int): the size of a group.

    Returns:
        Tensor: shape(*, time_steps // n_group, n_group), the folded tensor.

    Raises:
        ValueError: if ``time_steps`` is not divisible by ``n_group``.
    """
    *spatial_shape, time_steps = x.shape
    # Fail early with a readable message instead of an opaque reshape error.
    # A negative size (an unknown dimension) is left for paddle to resolve.
    if time_steps >= 0 and time_steps % n_group != 0:
        raise ValueError(
            "time_steps ({}) must be divisible by n_group ({})".format(
                time_steps, n_group))
    # star-unpacking always yields a list, so list concatenation is safe here
    new_shape = spatial_shape + [time_steps // n_group, n_group]
    return paddle.reshape(x, new_shape)
|
|
|
|
class UpsampleNet(nn.LayerList):
    """
    Layer to upsample a mel spectrogram to the same temporal resolution as
    the corresponding waveform. It consists of several weight-normalized
    Conv2DTranspose layers (one per upsample factor) which perform
    deconvolution over the mel and time dimensions.
    """

    def __init__(self, upsample_factors: Sequence[int]):
        super(UpsampleNet, self).__init__()
        for factor in upsample_factors:
            # Uniform(-1/sqrt(fan_in), 1/sqrt(fan_in)) with
            # fan_in = 1 (in channel) * 3 * (2 * factor) (kernel size).
            std = math.sqrt(1 / (3 * 2 * factor))
            init = I.Uniform(-std, std)
            self.append(
                nn.utils.weight_norm(
                    nn.Conv2DTranspose(1, 1, (3, 2 * factor),
                                       padding=(1, factor // 2),
                                       stride=(1, factor),
                                       weight_attr=init,
                                       bias_attr=init)))

        # upsample factors
        self.upsample_factor = np.prod(upsample_factors)
        self.upsample_factors = upsample_factors

    def forward(self, x, trim_conv_artifact=False):
        """
        Args:
            x (Tensor): shape(batch_size, input_channels, time_steps), the input
                spectrogram.
            trim_conv_artifact (bool, optional): trim deconvolution artifact at
                each layer. Defaults to False.

        Returns:
            Tensor: shape(batch_size, input_channels, time_steps * upsample_factor).
                If trim_conv_artifact is True, the output time steps is less
                than time_steps * upsample_factor.
        """
        # add a singleton channel axis so the spectrogram is a 1-channel image
        x = paddle.unsqueeze(x, 1)
        for layer in self:
            x = layer(x)
            if trim_conv_artifact:
                # Drop the last (kernel width - stride) frames along time.
                # This must be a slice: a bare index would remove the time
                # axis altogether.
                time_cutoff = layer._kernel_size[1] - layer._stride[1]
                if time_cutoff > 0:
                    x = x[:, :, :, :-time_cutoff]
            x = F.leaky_relu(x, 0.4)
        x = paddle.squeeze(x, 1)
        return x
|
|
|
|
|
|
class ResidualBlock(nn.Layer):
    """
    ResidualBlock that merges information from the condition and outputs residual
    and skip. It has a conv2d layer, which has causal padding in height dimension
    and same padding in width dimension.

    The residual output is the block input plus a learned projection, so each
    block refines rather than replaces its input; the skip output is returned
    separately for the caller to aggregate.
    """

    def __init__(self, channels, cond_channels, kernel_size, dilations):
        super(ResidualBlock, self).__init__()
        # input conv
        # Uniform(-1/sqrt(fan_in), 1/sqrt(fan_in)) with
        # fan_in = channels * kh * kw, like the other layers in this file.
        # (Without the parentheses, `1 / channels * prod` means prod / channels.)
        std = math.sqrt(1 / (channels * np.prod(kernel_size)))
        init = I.Uniform(-std, std)
        receptive_field = [1 + (k - 1) * d for (k, d) in zip(kernel_size, dilations)]
        rh, rw = receptive_field
        # Assumes paddle's 4-element padding order [top, bottom, left, right]:
        # all height padding goes on top (causal), width padding totals rw - 1
        # so the output width equals the input width (same).
        paddings = [rh - 1, 0, rw // 2, (rw - 1) // 2]
        conv = nn.Conv2D(channels, 2 * channels, kernel_size,
                         padding=paddings,
                         dilation=dilations,
                         weight_attr=init,
                         bias_attr=init)
        self.conv = nn.utils.weight_norm(conv)

        # condition projection
        std = math.sqrt(1 / cond_channels)
        init = I.Uniform(-std, std)
        condition_proj = nn.Conv2D(cond_channels, 2 * channels, (1, 1),
                                   weight_attr=init, bias_attr=init)
        self.condition_proj = nn.utils.weight_norm(condition_proj)

        # parametric residual & skip connection
        std = math.sqrt(1 / channels)
        init = I.Uniform(-std, std)
        out_proj = nn.Conv2D(channels, 2 * channels, (1, 1),
                             weight_attr=init, bias_attr=init)
        self.out_proj = nn.utils.weight_norm(out_proj)

    def forward(self, x, condition):
        """
        Args:
            x (Tensor): shape(batch_size, channels, height, width), the input.
            condition (Tensor): shape(batch_size, cond_channels, height, width),
                the condition, spatially aligned with ``x``.

        Returns:
            (Tensor, Tensor): the residual output (same shape as ``x``) and the
            skip output (``channels`` channels).
        """
        x_in = x
        x = self.conv(x)
        x = x + self.condition_proj(condition)

        # gated activation: tanh(content) * sigmoid(gate)
        content, gate = paddle.chunk(x, 2, axis=1)
        x = paddle.tanh(content) * F.sigmoid(gate)

        x = self.out_proj(x)
        res, skip = paddle.chunk(x, 2, axis=1)
        # residual connection: add the block input back
        res = x_in + res
        return res, skip
|
|
|
|
|
|
class ResidualNet(nn.LayerList):
    """
    Stack of ResidualBlocks that all see the same condition. The skip outputs
    of every block are summed to form the network output.

    Block ``i`` is built with dilation ``(dilations_h[i], 2 ** i)``: a
    caller-supplied height dilation and a width dilation doubling per layer.
    """

    def __init__(self, n_layer, residual_channels, condition_channels, kernel_size, dilations_h):
        if len(dilations_h) != n_layer:
            raise ValueError("number of dilations_h should equals num of layers")
        super(ResidualNet, self).__init__()
        for depth, dilation_h in enumerate(dilations_h):
            block_dilation = (dilation_h, 2 ** depth)
            self.append(
                ResidualBlock(residual_channels, condition_channels,
                              kernel_size, block_dilation))

    def forward(self, x, condition):
        """Thread ``x`` through every block and return the sum of all skips.

        The residual output of the final block is not returned.
        """
        hidden = x
        skips = []
        for block in self:
            hidden, skip_out = block(hidden, condition)
            skips.append(skip_out)
        stacked = paddle.stack(skips, 0)
        return paddle.sum(stacked, 0)
|
|
|
|
|
|
class Flow(nn.Layer):
    """
    A Layer that merges the condition and predicts scale and bias given a random
    variable X. Its output has 2 channels, which the caller splits into the
    log-scale and the bias.
    """
    # Height dilations for the residual layers, keyed by n_group. Every entry
    # has 8 values and ResidualNet requires len(dilations_h) == n_layers, so
    # only n_layers == 8 is accepted; an unlisted n_group raises KeyError.
    dilations_dict = {
        8: [1, 1, 1, 1, 1, 1, 1, 1],
        16: [1, 1, 1, 1, 1, 1, 1, 1],
        32: [1, 2, 4, 1, 2, 4, 1, 2],
        64: [1, 2, 4, 8, 16, 1, 2, 4],
        128: [1, 2, 4, 8, 16, 32, 64, 1]
    }

    def __init__(self, n_layers, channels, mel_bands, kernel_size, n_group):
        super(Flow, self).__init__()
        # input projection
        self.first_conv = nn.utils.weight_norm(
            nn.Conv2D(1, channels, (1, 1),
                      weight_attr=I.Uniform(-1., 1.),
                      bias_attr=I.Uniform(-1., 1.)))

        # residual net
        self.resnet = ResidualNet(n_layers, channels, mel_bands, kernel_size,
                                  self.dilations_dict[n_group])

        # output projection
        # Zero-initialized so both predicted channels (log-scale, bias) start
        # at 0 and the affine step x * exp(logs) + b begins as the identity.
        # Deliberately NOT wrapped in weight norm: weight norm computes
        # w = g * v / ||v||, whose direction is undefined at v == 0, which
        # degenerates this zero init (WaveGlow likewise keeps its
        # zero-initialized final conv plain).
        # NOTE(review): this changes last_conv's parameter names in saved state
        # dicts (plain weight instead of the weight-norm pair) — older
        # checkpoints need conversion.
        self.last_conv = nn.Conv2D(channels, 2, (1, 1),
                                   weight_attr=I.Constant(0.),
                                   bias_attr=I.Constant(0.))

    def forward(self, x, condition):
        """
        Args:
            x (Tensor): the input with 1 channel (4D, as consumed by Conv2D).
            condition (Tensor): the condition with mel_bands channels, spatially
                aligned with ``x``.

        Returns:
            Tensor: the 2-channel prediction (log-scale and bias).
        """
        # TODO(chenfeiyu): it is better to implement the transformation here
        return self.last_conv(self.resnet(self.first_conv(x), condition))
|
|
|
|
|
|
class WaveFlow(nn.LayerList):
    """A stack of ``n_flows`` Flow layers with a height permutation after each.

    Args:
        n_flows (int): number of flows; must be even.
        n_layers (int): number of residual layers in each flow.
        n_group (int): group size h used to fold the waveform; must be even.
        channels (int): residual channels of each flow.
        mel_bands (int): channels of the (upsampled) condition.
        kernel_size: kernel size of the dilated convs in each residual block.

    Raises:
        ValueError: if ``n_group`` or ``n_flows`` is odd.
    """

    def __init__(self, n_flows, n_layers, n_group, channels, mel_bands, kernel_size):
        if n_group % 2 or n_flows % 2:
            raise ValueError("number of flows and number of group must be even "
                             "since a permutation along group among flows is used.")
        super(WaveFlow, self).__init__()
        for _ in range(n_flows):
            self.append(Flow(n_layers, channels, mel_bands, kernel_size, n_group))

        # permutations in h
        self.perms = self._create_perm(n_group, n_flows)

        # specs
        self.n_group = n_group
        self.n_flows = n_flows

    def _create_perm(self, n_group, n_flows):
        """Build one permutation of the n_group rows per flow.

        The first half of the flows reverse the whole group; the second half
        reverse each half of the group independently.
        """
        indices = list(range(n_group))
        half = n_group // 2
        perms = []
        for i in range(n_flows):
            if i < n_flows // 2:
                perms.append(indices[::-1])
            else:
                perm = list(reversed(indices[:half])) + list(reversed(indices[half:]))
                perms.append(perm)
        return perms

    def _trim(self, x, condition):
        """Trim x (B, T) and condition (B, C, T') to the largest multiple of n_group.

        Raises:
            ValueError: if the condition is shorter than x.
        """
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if condition.shape[-1] < x.shape[-1]:
            raise ValueError(
                "condition ({} time steps) must be at least as long as x "
                "({} time steps)".format(condition.shape[-1], x.shape[-1]))
        pruned_len = int(x.shape[-1] // self.n_group * self.n_group)

        if x.shape[-1] > pruned_len:
            x = x[:, :pruned_len]
        if condition.shape[-1] > pruned_len:
            condition = condition[:, :, :pruned_len]
        return x, condition

    def forward(self, x, condition):
        """Transform a waveform into the latent variable.

        Args:
            x (Tensor): shape(B, T), the waveform.
            condition (Tensor): shape(B, C, T'), the upsampled condition, T' >= T.

        Returns:
            z (Tensor): shape(B, n_group, T_trimmed // n_group), where T_trimmed
                is T rounded down to a multiple of n_group.
            logs_list (list[Tensor]): per-flow log-scales, each of shape
                (B, 1, n_group - 1, T_trimmed // n_group).
        """
        x, condition = self._trim(x, condition)

        # to (B, C, h, T //h) layout
        x = paddle.unsqueeze(paddle.transpose(fold(x, self.n_group), [0, 2, 1]), 1)
        condition = paddle.transpose(fold(condition, self.n_group), [0, 1, 3, 2])

        # flows
        logs_list = []
        for layer, perm in zip(self, self.perms):
            # shifting: the params for row i are predicted from rows < i of x
            # (rows 0..h-2 feed rows 1..h-1), with the condition of row i
            x_prev = x[:, :, :-1, :]
            cond = condition[:, :, 1:, :]
            output = layer(x_prev, cond)
            logs, b = paddle.chunk(output, 2, axis=1)
            logs_list.append(logs)
            # the affine transformation
            x_0 = x[:, :, :1, :]  # the first row, just copy it
            x_out = x[:, :, 1:, :] * paddle.exp(logs) + b
            x = paddle.concat([x_0, x_out], axis=2)

            # permute along h (paddle has no built-in shuffle along a dim)
            x = geo.shuffle_dim(x, 2, perm=perm)
            condition = geo.shuffle_dim(condition, 2, perm=perm)

        z = paddle.squeeze(x, 1)
        return z, logs_list
|
|
|
|
|
|
# TODO(chenfeiyu): WaveFlowLoss