import itertools
import math
from typing import Sequence

import numpy as np
import paddle
import paddle.fluid.dygraph as dg
from paddle import fluid
from paddle import nn
from paddle.nn import functional as F
from paddle.nn import initializer as I

from parakeet.modules import geometry as geo

__all__ = ["WaveFlow"]


def fold(x, n_group):
    """Fold an audio or spectrogram's temporal dimension into groups.

    Args:
        x (Tensor): shape(*, time_steps), the input tensor.
        n_group (int): the size of a group.

    Returns:
        Tensor: shape(*, time_steps // n_group, n_group), the folded tensor.
    """
    *spatial_shape, time_steps = x.shape
    new_shape = spatial_shape + [time_steps // n_group, n_group]
    return paddle.reshape(x, new_shape)


class UpsampleNet(nn.LayerList):
    """Layer that upsamples a mel spectrogram to the same temporal resolution
    as the corresponding waveform.

    It consists of several Conv2DTranspose layers which perform deconvolution
    on the mel and time dimensions.
    """

    def __init__(self, upsample_factors: Sequence[int]):
        super(UpsampleNet, self).__init__()
        for factor in upsample_factors:
            std = math.sqrt(1 / (3 * 2 * factor))
            init = I.Uniform(-std, std)
            self.append(
                nn.utils.weight_norm(
                    nn.Conv2DTranspose(1, 1, (3, 2 * factor),
                                       padding=(1, factor // 2),
                                       stride=(1, factor),
                                       weight_attr=init,
                                       bias_attr=init)))

        # upsample factors
        self.upsample_factor = np.prod(upsample_factors)
        self.upsample_factors = upsample_factors

    def forward(self, x, trim_conv_artifact=False):
        """
        Args:
            x (Tensor): shape(batch_size, input_channels, time_steps), the
                input spectrogram.
            trim_conv_artifact (bool, optional): trim the deconvolution
                artifact at each layer. Defaults to False.

        Returns:
            Tensor: shape(batch_size, input_channels, time_steps *
                upsample_factor). If trim_conv_artifact is True, the output
                has fewer than time_steps * upsample_factor time steps.
        """
        x = paddle.unsqueeze(x, 1)
        for layer in self:
            x = layer(x)
            if trim_conv_artifact:
                # drop the trailing frames introduced by the transposed conv
                time_cutoff = layer._kernel_size[1] - layer._stride[1]
                x = x[:, :, :, :-time_cutoff]
            x = F.leaky_relu(x, 0.4)
        x = paddle.squeeze(x, 1)
        return x
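
# A minimal shape sketch (the factors and sizes below are illustrative
# assumptions, not the model's fixed configuration): with
# upsample_factors=(16, 16) the total upsampling is 256x, so each spectrogram
# frame is stretched to 256 waveform samples; `fold` then regroups the
# matching waveform before it enters the flow layers.
#
#     net = UpsampleNet(upsample_factors=(16, 16))
#     mel = paddle.randn([2, 80, 100])        # (batch, mel_bands, frames)
#     cond = net(mel)                         # -> [2, 80, 25600]
#     wav = paddle.randn([2, 25600])          # (batch, time_steps)
#     folded = fold(wav, 8)                   # -> [2, 3200, 8]
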
""" def __init__(self, channels, cond_channels, kernel_size, dilations): super(ResidualBlock, self).__init__() # input conv std = math.sqrt(1 / channels * np.prod(kernel_size)) init = I.Uniform(-std, std) receptive_field = [1 + (k - 1) * d for (k, d) in zip(kernel_size, dilations)] rh, rw = receptive_field paddings = [rh - 1, 0, rw // 2, (rw - 1) // 2] # causal & same conv = nn.Conv2D(channels, 2 * channels, kernel_size, padding=paddings, dilation=dilations, weight_attr=init, bias_attr=init) self.conv = nn.utils.weight_norm(conv) # condition projection std = math.sqrt(1 / cond_channels) init = I.Uniform(-std, std) condition_proj = nn.Conv2D(cond_channels, 2 * channels, (1, 1), weight_attr=init, bias_attr=init) self.condition_proj = nn.utils.weight_norm(condition_proj) # parametric residual & skip connection std = math.sqrt(1 / channels) init = I.Uniform(-std, std) out_proj = nn.Conv2D(channels, 2 * channels, (1, 1), weight_attr=init, bias_attr=init) self.out_proj = nn.utils.weight_norm(out_proj) def forward(self, x, condition): x = self.conv(x) x += self.condition_proj(condition) content, gate = paddle.chunk(x, 2, axis=1) x = paddle.tanh(content) * F.sigmoid(gate) x = self.out_proj(x) res, skip = paddle.chunk(x, 2, axis=1) return res, skip class ResidualNet(nn.LayerList): """ A stack of several ResidualBlocks. It merges condition at each layer. All skip outputs are collected. """ def __init__(self, n_layer, residual_channels, condition_channels, kernel_size, dilations_h): if len(dilations_h) != n_layer: raise ValueError("number of dilations_h should equals num of layers") super(ResidualNet, self).__init__() for i in range(n_layer): dilation = (dilations_h[i], 2 ** i) layer = ResidualBlock(residual_channels, condition_channels, kernel_size, dilation) self.append(layer) def forward(self, x, condition): skip_connections = [] for layer in self: x, skip = layer(x, condition) skip_connections.append(skip) out = paddle.sum(paddle.stack(skip_connections, 0), 0) return out class Flow(nn.Layer): """ A Layer that merges the condition and predict scale and bias given a random variable X. 
""" dilations_dict = { 8: [1, 1, 1, 1, 1, 1, 1, 1], 16: [1, 1, 1, 1, 1, 1, 1, 1], 32: [1, 2, 4, 1, 2, 4, 1, 2], 64: [1, 2, 4, 8, 16, 1, 2, 4], 128: [1, 2, 4, 8, 16, 32, 64, 1] } def __init__(self, n_layers, channels, mel_bands, kernel_size, n_group): super(Flow, self).__init__() # input projection self.first_conv = nn.utils.weight_norm( nn.Conv2D(1, channels, (1, 1), weight_attr=I.Uniform(-1., 1.), bias_attr=I.Uniform(-1., 1.))) # residual net self.resnet = ResidualNet(n_layers, channels, mel_bands, kernel_size, self.dilations_dict[n_group]) # output projection self.last_conv = nn.utils.weight_norm( nn.Conv2D(channels, 2, (1, 1), weight_attr=I.Constant(0.), bias_attr=I.Constant(0.))) def forward(self, x, condition): # TODO(chenfeiyu): it is better to implement the transformation here return self.last_conv(self.resnet(self.first_conv(x), condition)) class WaveFlow(nn.LayerList): def __init__(self, n_flows, n_layers, n_group, channels, mel_bands, kernel_size): if n_group % 2 or n_flows % 2: raise ValueError("number of flows and number of group must be even " "since a permutation along group among flows is used.") super(WaveFlow, self).__init__() for _ in range(n_flows): self.append(Flow(n_layers, channels, mel_bands, kernel_size, n_group)) # permutations in h self.perms = self._create_perm(n_group, n_flows) # specs self.n_group = n_group self.n_flows = n_flows def _create_perm(self, n_group, n_flows): indices = list(range(n_group)) half = n_group // 2 perms = [] for i in range(n_flows): if i < n_flows // 2: perms.append(indices[::-1]) else: perm = list(reversed(indices[:half])) + list(reversed(indices[half:])) perms.append(perm) return perms def _trim(self, x, condition): assert condition.shape[-1] >= x.shape[-1] pruned_len = int(x.shape[-1] // self.n_group * self.n_group) if x.shape[-1] > pruned_len: x = x[:, :pruned_len] if condition.shape[-1] > pruned_len: condition = condition[:, :, :pruned_len] return x, condition def forward(self, x, condition): # x: (B, T) # condition: (B, C, T) upsampled condition x, condition = self._trim(x, condition) # to (B, C, h, T //h) layout x = paddle.unsqueeze(paddle.transpose(fold(x, self.n_group), [0, 2, 1]), 1) condition = paddle.transpose(fold(condition, self.n_group), [0, 1, 3, 2]) # flows logs_list = [] for i, layer in enumerate(self): # shiting: z[i, j] depends only on x[