WIP: add sample code for parallel wavegan

This commit is contained in:
chenfeiyu 2021-06-11 02:43:10 +08:00
parent 258083aea9
commit 66062d29e5
1 changed files with 300 additions and 60 deletions

View File

@ -13,33 +13,88 @@
# limitations under the License. # limitations under the License.
import math import math
from typing import List, Dict, Any, Union, Optional
import numpy as np import numpy as np
import paddle import paddle
from paddle import Tensor
from paddle import nn from paddle import nn
from paddle.nn import functional as F from paddle.nn import functional as F
class Stretch2D(nn.Layer): class Stretch2D(nn.Layer):
def __init__(self, x_scale, y_scale, mode="nearest"): def __init__(self, w_scale: int, h_scale: int, mode: str="nearest"):
"""Strech an image (or image-like object) with some interpolation.
Parameters
----------
w_scale : int
Scalar of width.
h_scale : int
Scalar of the height.
mode : str, optional
Interpolation mode, modes suppored are "nearest", "bilinear",
"trilinear", "bicubic", "linear" and "area",by default "nearest"
For more details about interpolation, see
`paddle.nn.functional.interpolate <https://www.paddlepaddle.org.cn/documentation/docs/en/api/paddle/nn/functional/interpolate_en.html>`_.
"""
super().__init__() super().__init__()
self.x_scale = x_scale self.w_scale = w_scale
self.y_scale = y_scale self.h_scale = h_scale
self.mode = mode self.mode = mode
def forward(self, x): def forward(self, x: Tensor) -> Tensor:
"""
Parameters
----------
x : Tensor
Shape (N, C, H, W)
Returns
-------
Tensor
Shape (N, C, H', W'), where ``H'=h_scale * H``, ``W'=w_scale * W``.
The stretched image.
"""
out = F.interpolate( out = F.interpolate(
x, scale_factor=(self.y_scale, self.x_scale), mode=self.mode) x, scale_factor=(self.h_scale, self.w_scale), mode=self.mode)
return out return out
class UpsampleNet(nn.Layer): class UpsampleNet(nn.Layer):
"""A Layer to upsample spectrogram by applying consecutive stretch and
convolutions.
Parameters
----------
upsample_scales : List[int]
Upsampling factors for each strech.
nonlinear_activation : Optional[str], optional
Activation after each convolution, by default None
nonlinear_activation_params : Dict[str, Any], optional
Parameters passed to construct the activation, by default {}
interpolate_mode : str, optional
Interpolation mode of the strech, by default "nearest"
freq_axis_kernel_size : int, optional
Convolution kernel size along the frequency axis, by default 1
use_causal_conv : bool, optional
Whether to use causal padding before convolution, by default False
If True, Causal padding is used along the time axis, i.e. padding
amount is ``receptive field - 1`` and 0 for before and after,
respectively.
If False, "same" padding is used along the time axis.
"""
def __init__(self, def __init__(self,
upsample_scales, upsample_scales: List[int],
nonlinear_activation=None, nonlinear_activation: Optional[str]=None,
nonlinear_activation_params={}, nonlinear_activation_params: Dict[str, Any]={},
interpolate_mode="nearest", interpolate_mode: str="nearest",
freq_axis_kernel_size=1, freq_axis_kernel_size: int=1,
use_causal_conv=False): use_causal_conv: bool=False):
super().__init__() super().__init__()
self.use_causal_conv = use_causal_conv self.use_causal_conv = use_causal_conv
self.up_layers = nn.LayerList() self.up_layers = nn.LayerList()
@ -59,7 +114,19 @@ class UpsampleNet(nn.Layer):
nn, nonlinear_activation)(**nonlinear_activation_params) nn, nonlinear_activation)(**nonlinear_activation_params)
self.up_layers.extend([stretch, conv, nonlinear]) self.up_layers.extend([stretch, conv, nonlinear])
def forward(self, c): def forward(self, c: Tensor) -> Tensor:
"""
Parameters
----------
c : Tensor
Shape (N, F, T), spectrogram
Returns
-------
Tensor
Shape (N, F, T'), where ``T' = upsample_factor * T``, upsampled
spectrogram
"""
c = c.unsqueeze(1) c = c.unsqueeze(1)
for f in self.up_layers: for f in self.up_layers:
if self.use_causal_conv and isinstance(f, nn.Conv2D): if self.use_causal_conv and isinstance(f, nn.Conv2D):
@ -70,15 +137,48 @@ class UpsampleNet(nn.Layer):
class ConvInUpsampleNet(nn.Layer): class ConvInUpsampleNet(nn.Layer):
"""A Layer to upsample spectrogram composed of a convolution and an
UpsampleNet.
Parameters
----------
upsample_scales : List[int]
Upsampling factors for each strech.
nonlinear_activation : Optional[str], optional
Activation after each convolution, by default None
nonlinear_activation_params : Dict[str, Any], optional
Parameters passed to construct the activation, by default {}
interpolate_mode : str, optional
Interpolation mode of the strech, by default "nearest"
freq_axis_kernel_size : int, optional
Convolution kernel size along the frequency axis, by default 1
aux_channels : int, optional
Feature size of the input, by default 80
aux_context_window : int, optional
Context window of the first 1D convolution applied to the input. It
related to the kernel size of the convolution, by default 0
If use causal convolution, the kernel size is ``window + 1``, else
the kernel size is ``2 * window + 1``.
use_causal_conv : bool, optional
Whether to use causal padding before convolution, by default False
If True, Causal padding is used along the time axis, i.e. padding
amount is ``receptive field - 1`` and 0 for before and after,
respectively.
If False, "same" padding is used along the time axis.
"""
def __init__(self, def __init__(self,
upsample_scales, upsample_scales: List[int],
nonlinear_activation=None, nonlinear_activation: Optional[str]=None,
nonlinear_activation_params={}, nonlinear_activation_params: Dict[str, Any]={},
interpolate_mode="nearest", interpolate_mode: str="nearest",
freq_axis_kernel_size=1, freq_axis_kernel_size: int=1,
aux_channels=80, aux_channels: int=80,
aux_context_window=0, aux_context_window: int=0,
use_causal_conv=False): use_causal_conv: bool=False):
super().__init__() super().__init__()
self.aux_context_window = aux_context_window self.aux_context_window = aux_context_window
self.use_causal_conv = use_causal_conv and aux_context_window > 0 self.use_causal_conv = use_causal_conv and aux_context_window > 0
@ -96,23 +196,61 @@ class ConvInUpsampleNet(nn.Layer):
freq_axis_kernel_size=freq_axis_kernel_size, freq_axis_kernel_size=freq_axis_kernel_size,
use_causal_conv=use_causal_conv) use_causal_conv=use_causal_conv)
def forward(self, c): def forward(self, c: Tensor) -> Tensor:
"""
Parameters
----------
c : Tensor
Shape (N, F, T), spectrogram
Returns
-------
Tensors
Shape (N, F, T'), where ``T' = upsample_factor * T``, upsampled
spectrogram
"""
c_ = self.conv_in(c) c_ = self.conv_in(c)
c = c_[:, :, :-self.aux_context_window] if self.use_causal_conv else c_ c = c_[:, :, :-self.aux_context_window] if self.use_causal_conv else c_
return self.upsample(c) return self.upsample(c)
class ResidualBlock(nn.Layer): class ResidualBlock(nn.Layer):
"""A gated activation unit composed of an 1D convolution, a gated tanh
unit and parametric redidual and skip connections. For more details,
refer to `WaveNet: A Generative Model for Raw Audio <https://arxiv.org/abs/1609.03499>`_.
Parameters
----------
kernel_size : int, optional
Kernel size of the 1D convolution, by default 3
residual_channels : int, optional
Feature size of the resiaudl output(and also the input), by default 64
gate_channels : int, optional
Output feature size of the 1D convolution, by default 128
skip_channels : int, optional
Feature size of the skip output, by default 64
aux_channels : int, optional
Feature size of the auxiliary input (e.g. spectrogram), by default 80
dropout : float, optional
Probability of the dropout before the 1D convolution, by default 0.
dilation : int, optional
Dilation of the 1D convolution, by default 1
bias : bool, optional
Whether to use bias in the 1D convolution, by default True
use_causal_conv : bool, optional
Whether to use causal padding for the 1D convolution, by default False
"""
def __init__(self, def __init__(self,
kernel_size=3, kernel_size: int=3,
residual_channels=64, residual_channels: int=64,
gate_channels=128, gate_channels: int=128,
skip_channels=64, skip_channels: int=64,
aux_channels=80, aux_channels: int=80,
dropout=0., dropout: float=0.,
dilation=1, dilation: int=1,
bias=True, bias: bool=True,
use_causal_conv=False): use_causal_conv: bool=False):
super().__init__() super().__init__()
self.dropout = dropout self.dropout = dropout
if use_causal_conv: if use_causal_conv:
@ -144,7 +282,24 @@ class ResidualBlock(nn.Layer):
self.conv1x1_skip = nn.Conv1D( self.conv1x1_skip = nn.Conv1D(
gate_out_channels, skip_channels, kernel_size=1, bias_attr=bias) gate_out_channels, skip_channels, kernel_size=1, bias_attr=bias)
def forward(self, x, c): def forward(self, x: Tensor, c: Tensor) -> Tuple[Tensor, Tensor]:
"""
Parameters
----------
x : Tensor
Shape (N, C_res, T), the input features.
c : Tensor
Shape (N, C_aux, T), he auxiliary input.
Returns
-------
res : Tensor
Shape (N, C_res, T), the residual output, which is used as the
input of the next ResidualBlock in a stack of ResidualBlocks.
skip : Tensor
Shape (N, C_skip, T), the skip output, which is collected among
each layer in a stack of ResidualBlocks.
"""
x_input = x x_input = x
x = F.dropout(x, self.dropout, training=self.training) x = F.dropout(x, self.dropout, training=self.training)
x = self.conv(x) x = self.conv(x)
@ -162,26 +317,76 @@ class ResidualBlock(nn.Layer):
class PWGGenerator(nn.Layer): class PWGGenerator(nn.Layer):
"""Wave Generator for Parallel WaveGAN
Parameters
----------
in_channels : int, optional
Number of channels of the input waveform, by default 1
out_channels : int, optional
Number of channels of the output waveform, by default 1
kernel_size : int, optional
Kernel size of the residual blocks inside, by default 3
layers : int, optional
Number of residual blocks inside, by default 30
stacks : int, optional
The number of groups to split the residual blocks into, by default 3
Within each group, the dilation of the residual block grows
exponentially.
residual_channels : int, optional
Residual channel of the residual blocks, by default 64
gate_channels : int, optional
Gate channel of the residual blocks, by default 128
skip_channels : int, optional
Skip channel of the residual blocks, by default 64
aux_channels : int, optional
Auxiliary channel of the residual blocks, by default 80
aux_context_window : int, optional
The context window size of the first convolution applied to the
auxiliary input, by default 2
dropout : float, optional
Dropout of the residual blocks, by default 0.
bias : bool, optional
Whether to use bias in residual blocks, by default True
use_weight_norm : bool, optional
Whether to use weight norm in all convolutions, by default True
use_causal_conv : bool, optional
Whether to use causal padding in the upsample network and residual
blocks, by default False
upsample_scales : List[int], optional
Upsample scales of the upsample network, by default [4, 4, 4, 4]
nonlinear_activation : Optional[str], optional
Non linear activation in upsample network, by default None
nonlinear_activation_params : Dict[str, Any], optional
Parameters passed to the linear activation in the upsample network,
by default {}
interpolate_mode : str, optional
Interpolation mode of the upsample network, by default "nearest"
freq_axis_kernel_size : int, optional
Kernel size along the frequency axis of the upsample network, by default 1
"""
def __init__(self, def __init__(self,
in_channels=1, in_channels: int=1,
out_channels=1, out_channels: int=1,
kernel_size=3, kernel_size: int=3,
layers=30, layers: int=30,
stacks=3, stacks: int=3,
residual_channels=64, residual_channels: int=64,
gate_channels=128, gate_channels: int=128,
skip_channels=64, skip_channels: int=64,
aux_channels=80, aux_channels: int=80,
aux_context_window=2, aux_context_window: int=2,
dropout=0., dropout: float=0.,
bias=True, bias: bool=True,
use_weight_norm=True, use_weight_norm: bool=True,
use_causal_conv=False, use_causal_conv: bool=False,
upsample_scales=[4, 4, 4, 4], upsample_scales: List[int]=[4, 4, 4, 4],
nonlinear_activation=None, nonlinear_activation: Optional[str]=None,
nonlinear_activation_params={}, nonlinear_activation_params: Dict[str, Any]={},
interpolate_mode="nearest", interpolate_mode: str="nearest",
freq_axis_kernel_size=1): freq_axis_kernel_size: int=1):
super().__init__() super().__init__()
self.in_channels = in_channels self.in_channels = in_channels
self.out_channels = out_channels self.out_channels = out_channels
@ -233,10 +438,24 @@ class PWGGenerator(nn.Layer):
if use_weight_norm: if use_weight_norm:
self.apply_weight_norm() self.apply_weight_norm()
def forward(self, x, c): def forward(self, x: Tensor, c: Tensor) -> Tensor:
if c is not None: """Generate waveform.
c = self.upsample_net(c)
assert c.shape[-1] == x.shape[-1] Parameters
----------
x : Tensor
Shape (N, C_in, T), The input waveform.
c : Tensor
Shape (N, C_aux, T'). The auxiliary input (e.g. spectrogram). It
is upsampled to match the time resolution of the input.
Returns
-------
Tensor
Shape (N, C_out, T), the generated waveform.
"""
c = self.upsample_net(c)
assert c.shape[-1] == x.shape[-1]
x = self.first_conv(x) x = self.first_conv(x)
skips = 0 skips = 0
@ -249,6 +468,10 @@ class PWGGenerator(nn.Layer):
return x return x
def apply_weight_norm(self): def apply_weight_norm(self):
"""Recursively apply weight normalization to all the Convolution layers
in the sublayers.
"""
def _apply_weight_norm(layer): def _apply_weight_norm(layer):
if isinstance(layer, (nn.Conv1D, nn.Conv2D)): if isinstance(layer, (nn.Conv1D, nn.Conv2D)):
nn.utils.weight_norm(layer) nn.utils.weight_norm(layer)
@ -256,6 +479,10 @@ class PWGGenerator(nn.Layer):
self.apply(_apply_weight_norm) self.apply(_apply_weight_norm)
def remove_weight_norm(self): def remove_weight_norm(self):
"""Recursively remove weight normalization from all the Convolution
layers in the sublayers.
"""
def _remove_weight_norm(layer): def _remove_weight_norm(layer):
try: try:
nn.utils.remove_weight_norm(layer) nn.utils.remove_weight_norm(layer)
@ -264,17 +491,30 @@ class PWGGenerator(nn.Layer):
self.apply(_remove_weight_norm) self.apply(_remove_weight_norm)
def inference(self, c=None, x=None): def inference(self, c: Optional[Tensor]=None,
""" x: Optional[Tensor]=None) -> Tensor:
single instance inference """Waveform generation. This function is used for single instance
c: [T', C] condition inference.
x: [T, 1] noise
Parameters
----------
c : Tensor, optional
Shape (T', C_aux), the auxiliary input, by default None
x : Tensor, optional
Shape (T, C_in), the noise waveform, by default None
If not provided, a sample is drawn from a gaussian distribution.
Returns
-------
Tensor
Shape (T, C_out), the generated waveform
""" """
if x is not None: if x is not None:
x = paddle.transpose(x, [1, 0]).unsqueeze(0) # pseudo batch x = paddle.transpose(x, [1, 0]).unsqueeze(0) # pseudo batch
else: else:
assert c is not None assert c is not None
x = paddle.randn([1, 1, c.shape[0] * self.upsample_factor]) x = paddle.randn(
[1, self.in_channels, c.shape[0] * self.upsample_factor])
if c is not None: if c is not None:
c = paddle.transpose(c, [1, 0]).unsqueeze(0) # pseudo batch c = paddle.transpose(c, [1, 0]).unsqueeze(0) # pseudo batch