From 66062d29e585426aeadbc208f0b65f9017af4539 Mon Sep 17 00:00:00 2001 From: chenfeiyu Date: Fri, 11 Jun 2021 02:43:10 +0800 Subject: [PATCH] WIP: add sample code for parallel wavegan --- parakeet/models/parallel_wavegan.py | 360 +++++++++++++++++++++++----- 1 file changed, 300 insertions(+), 60 deletions(-) diff --git a/parakeet/models/parallel_wavegan.py b/parakeet/models/parallel_wavegan.py index 3ed6058..aded823 100644 --- a/parakeet/models/parallel_wavegan.py +++ b/parakeet/models/parallel_wavegan.py @@ -13,33 +13,88 @@ # limitations under the License. import math +from typing import List, Dict, Any, Union, Optional + import numpy as np import paddle +from paddle import Tensor from paddle import nn from paddle.nn import functional as F class Stretch2D(nn.Layer): - def __init__(self, x_scale, y_scale, mode="nearest"): + def __init__(self, w_scale: int, h_scale: int, mode: str="nearest"): + """Strech an image (or image-like object) with some interpolation. + + Parameters + ---------- + w_scale : int + Scalar of width. + h_scale : int + Scalar of the height. + mode : str, optional + Interpolation mode, modes suppored are "nearest", "bilinear", + "trilinear", "bicubic", "linear" and "area",by default "nearest" + + For more details about interpolation, see + `paddle.nn.functional.interpolate `_. + """ super().__init__() - self.x_scale = x_scale - self.y_scale = y_scale + self.w_scale = w_scale + self.h_scale = h_scale self.mode = mode - def forward(self, x): + def forward(self, x: Tensor) -> Tensor: + """ + Parameters + ---------- + x : Tensor + Shape (N, C, H, W) + + Returns + ------- + Tensor + Shape (N, C, H', W'), where ``H'=h_scale * H``, ``W'=w_scale * W``. + The stretched image. + """ out = F.interpolate( - x, scale_factor=(self.y_scale, self.x_scale), mode=self.mode) + x, scale_factor=(self.h_scale, self.w_scale), mode=self.mode) return out class UpsampleNet(nn.Layer): + """A Layer to upsample spectrogram by applying consecutive stretch and + convolutions. + + Parameters + ---------- + upsample_scales : List[int] + Upsampling factors for each strech. + nonlinear_activation : Optional[str], optional + Activation after each convolution, by default None + nonlinear_activation_params : Dict[str, Any], optional + Parameters passed to construct the activation, by default {} + interpolate_mode : str, optional + Interpolation mode of the strech, by default "nearest" + freq_axis_kernel_size : int, optional + Convolution kernel size along the frequency axis, by default 1 + use_causal_conv : bool, optional + Whether to use causal padding before convolution, by default False + + If True, Causal padding is used along the time axis, i.e. padding + amount is ``receptive field - 1`` and 0 for before and after, + respectively. + + If False, "same" padding is used along the time axis. + """ + def __init__(self, - upsample_scales, - nonlinear_activation=None, - nonlinear_activation_params={}, - interpolate_mode="nearest", - freq_axis_kernel_size=1, - use_causal_conv=False): + upsample_scales: List[int], + nonlinear_activation: Optional[str]=None, + nonlinear_activation_params: Dict[str, Any]={}, + interpolate_mode: str="nearest", + freq_axis_kernel_size: int=1, + use_causal_conv: bool=False): super().__init__() self.use_causal_conv = use_causal_conv self.up_layers = nn.LayerList() @@ -59,7 +114,19 @@ class UpsampleNet(nn.Layer): nn, nonlinear_activation)(**nonlinear_activation_params) self.up_layers.extend([stretch, conv, nonlinear]) - def forward(self, c): + def forward(self, c: Tensor) -> Tensor: + """ + Parameters + ---------- + c : Tensor + Shape (N, F, T), spectrogram + + Returns + ------- + Tensor + Shape (N, F, T'), where ``T' = upsample_factor * T``, upsampled + spectrogram + """ c = c.unsqueeze(1) for f in self.up_layers: if self.use_causal_conv and isinstance(f, nn.Conv2D): @@ -70,15 +137,48 @@ class UpsampleNet(nn.Layer): class ConvInUpsampleNet(nn.Layer): + """A Layer to upsample spectrogram composed of a convolution and an + UpsampleNet. + + Parameters + ---------- + upsample_scales : List[int] + Upsampling factors for each strech. + nonlinear_activation : Optional[str], optional + Activation after each convolution, by default None + nonlinear_activation_params : Dict[str, Any], optional + Parameters passed to construct the activation, by default {} + interpolate_mode : str, optional + Interpolation mode of the strech, by default "nearest" + freq_axis_kernel_size : int, optional + Convolution kernel size along the frequency axis, by default 1 + aux_channels : int, optional + Feature size of the input, by default 80 + aux_context_window : int, optional + Context window of the first 1D convolution applied to the input. It + related to the kernel size of the convolution, by default 0 + + If use causal convolution, the kernel size is ``window + 1``, else + the kernel size is ``2 * window + 1``. + use_causal_conv : bool, optional + Whether to use causal padding before convolution, by default False + + If True, Causal padding is used along the time axis, i.e. padding + amount is ``receptive field - 1`` and 0 for before and after, + respectively. + + If False, "same" padding is used along the time axis. + """ + def __init__(self, - upsample_scales, - nonlinear_activation=None, - nonlinear_activation_params={}, - interpolate_mode="nearest", - freq_axis_kernel_size=1, - aux_channels=80, - aux_context_window=0, - use_causal_conv=False): + upsample_scales: List[int], + nonlinear_activation: Optional[str]=None, + nonlinear_activation_params: Dict[str, Any]={}, + interpolate_mode: str="nearest", + freq_axis_kernel_size: int=1, + aux_channels: int=80, + aux_context_window: int=0, + use_causal_conv: bool=False): super().__init__() self.aux_context_window = aux_context_window self.use_causal_conv = use_causal_conv and aux_context_window > 0 @@ -96,23 +196,61 @@ class ConvInUpsampleNet(nn.Layer): freq_axis_kernel_size=freq_axis_kernel_size, use_causal_conv=use_causal_conv) - def forward(self, c): + def forward(self, c: Tensor) -> Tensor: + """ + Parameters + ---------- + c : Tensor + Shape (N, F, T), spectrogram + + Returns + ------- + Tensors + Shape (N, F, T'), where ``T' = upsample_factor * T``, upsampled + spectrogram + """ c_ = self.conv_in(c) c = c_[:, :, :-self.aux_context_window] if self.use_causal_conv else c_ return self.upsample(c) class ResidualBlock(nn.Layer): + """A gated activation unit composed of an 1D convolution, a gated tanh + unit and parametric redidual and skip connections. For more details, + refer to `WaveNet: A Generative Model for Raw Audio `_. + + Parameters + ---------- + kernel_size : int, optional + Kernel size of the 1D convolution, by default 3 + residual_channels : int, optional + Feature size of the resiaudl output(and also the input), by default 64 + gate_channels : int, optional + Output feature size of the 1D convolution, by default 128 + skip_channels : int, optional + Feature size of the skip output, by default 64 + aux_channels : int, optional + Feature size of the auxiliary input (e.g. spectrogram), by default 80 + dropout : float, optional + Probability of the dropout before the 1D convolution, by default 0. + dilation : int, optional + Dilation of the 1D convolution, by default 1 + bias : bool, optional + Whether to use bias in the 1D convolution, by default True + use_causal_conv : bool, optional + Whether to use causal padding for the 1D convolution, by default False + """ + def __init__(self, - kernel_size=3, - residual_channels=64, - gate_channels=128, - skip_channels=64, - aux_channels=80, - dropout=0., - dilation=1, - bias=True, - use_causal_conv=False): + kernel_size: int=3, + residual_channels: int=64, + gate_channels: int=128, + skip_channels: int=64, + aux_channels: int=80, + dropout: float=0., + dilation: int=1, + bias: bool=True, + use_causal_conv: bool=False): super().__init__() self.dropout = dropout if use_causal_conv: @@ -144,7 +282,24 @@ class ResidualBlock(nn.Layer): self.conv1x1_skip = nn.Conv1D( gate_out_channels, skip_channels, kernel_size=1, bias_attr=bias) - def forward(self, x, c): + def forward(self, x: Tensor, c: Tensor) -> Tuple[Tensor, Tensor]: + """ + Parameters + ---------- + x : Tensor + Shape (N, C_res, T), the input features. + c : Tensor + Shape (N, C_aux, T), he auxiliary input. + + Returns + ------- + res : Tensor + Shape (N, C_res, T), the residual output, which is used as the + input of the next ResidualBlock in a stack of ResidualBlocks. + skip : Tensor + Shape (N, C_skip, T), the skip output, which is collected among + each layer in a stack of ResidualBlocks. + """ x_input = x x = F.dropout(x, self.dropout, training=self.training) x = self.conv(x) @@ -162,26 +317,76 @@ class ResidualBlock(nn.Layer): class PWGGenerator(nn.Layer): + """Wave Generator for Parallel WaveGAN + + Parameters + ---------- + in_channels : int, optional + Number of channels of the input waveform, by default 1 + out_channels : int, optional + Number of channels of the output waveform, by default 1 + kernel_size : int, optional + Kernel size of the residual blocks inside, by default 3 + layers : int, optional + Number of residual blocks inside, by default 30 + stacks : int, optional + The number of groups to split the residual blocks into, by default 3 + + Within each group, the dilation of the residual block grows + exponentially. + residual_channels : int, optional + Residual channel of the residual blocks, by default 64 + gate_channels : int, optional + Gate channel of the residual blocks, by default 128 + skip_channels : int, optional + Skip channel of the residual blocks, by default 64 + aux_channels : int, optional + Auxiliary channel of the residual blocks, by default 80 + aux_context_window : int, optional + The context window size of the first convolution applied to the + auxiliary input, by default 2 + dropout : float, optional + Dropout of the residual blocks, by default 0. + bias : bool, optional + Whether to use bias in residual blocks, by default True + use_weight_norm : bool, optional + Whether to use weight norm in all convolutions, by default True + use_causal_conv : bool, optional + Whether to use causal padding in the upsample network and residual + blocks, by default False + upsample_scales : List[int], optional + Upsample scales of the upsample network, by default [4, 4, 4, 4] + nonlinear_activation : Optional[str], optional + Non linear activation in upsample network, by default None + nonlinear_activation_params : Dict[str, Any], optional + Parameters passed to the linear activation in the upsample network, + by default {} + interpolate_mode : str, optional + Interpolation mode of the upsample network, by default "nearest" + freq_axis_kernel_size : int, optional + Kernel size along the frequency axis of the upsample network, by default 1 + """ + def __init__(self, - in_channels=1, - out_channels=1, - kernel_size=3, - layers=30, - stacks=3, - residual_channels=64, - gate_channels=128, - skip_channels=64, - aux_channels=80, - aux_context_window=2, - dropout=0., - bias=True, - use_weight_norm=True, - use_causal_conv=False, - upsample_scales=[4, 4, 4, 4], - nonlinear_activation=None, - nonlinear_activation_params={}, - interpolate_mode="nearest", - freq_axis_kernel_size=1): + in_channels: int=1, + out_channels: int=1, + kernel_size: int=3, + layers: int=30, + stacks: int=3, + residual_channels: int=64, + gate_channels: int=128, + skip_channels: int=64, + aux_channels: int=80, + aux_context_window: int=2, + dropout: float=0., + bias: bool=True, + use_weight_norm: bool=True, + use_causal_conv: bool=False, + upsample_scales: List[int]=[4, 4, 4, 4], + nonlinear_activation: Optional[str]=None, + nonlinear_activation_params: Dict[str, Any]={}, + interpolate_mode: str="nearest", + freq_axis_kernel_size: int=1): super().__init__() self.in_channels = in_channels self.out_channels = out_channels @@ -233,10 +438,24 @@ class PWGGenerator(nn.Layer): if use_weight_norm: self.apply_weight_norm() - def forward(self, x, c): - if c is not None: - c = self.upsample_net(c) - assert c.shape[-1] == x.shape[-1] + def forward(self, x: Tensor, c: Tensor) -> Tensor: + """Generate waveform. + + Parameters + ---------- + x : Tensor + Shape (N, C_in, T), The input waveform. + c : Tensor + Shape (N, C_aux, T'). The auxiliary input (e.g. spectrogram). It + is upsampled to match the time resolution of the input. + + Returns + ------- + Tensor + Shape (N, C_out, T), the generated waveform. + """ + c = self.upsample_net(c) + assert c.shape[-1] == x.shape[-1] x = self.first_conv(x) skips = 0 @@ -249,6 +468,10 @@ class PWGGenerator(nn.Layer): return x def apply_weight_norm(self): + """Recursively apply weight normalization to all the Convolution layers + in the sublayers. + """ + def _apply_weight_norm(layer): if isinstance(layer, (nn.Conv1D, nn.Conv2D)): nn.utils.weight_norm(layer) @@ -256,6 +479,10 @@ class PWGGenerator(nn.Layer): self.apply(_apply_weight_norm) def remove_weight_norm(self): + """Recursively remove weight normalization from all the Convolution + layers in the sublayers. + """ + def _remove_weight_norm(layer): try: nn.utils.remove_weight_norm(layer) @@ -264,17 +491,30 @@ class PWGGenerator(nn.Layer): self.apply(_remove_weight_norm) - def inference(self, c=None, x=None): - """ - single instance inference - c: [T', C] condition - x: [T, 1] noise + def inference(self, c: Optional[Tensor]=None, + x: Optional[Tensor]=None) -> Tensor: + """Waveform generation. This function is used for single instance + inference. + + Parameters + ---------- + c : Tensor, optional + Shape (T', C_aux), the auxiliary input, by default None + x : Tensor, optional + Shape (T, C_in), the noise waveform, by default None + If not provided, a sample is drawn from a gaussian distribution. + + Returns + ------- + Tensor + Shape (T, C_out), the generated waveform """ if x is not None: x = paddle.transpose(x, [1, 0]).unsqueeze(0) # pseudo batch else: assert c is not None - x = paddle.randn([1, 1, c.shape[0] * self.upsample_factor]) + x = paddle.randn( + [1, self.in_channels, c.shape[0] * self.upsample_factor]) if c is not None: c = paddle.transpose(c, [1, 0]).unsqueeze(0) # pseudo batch