WIP: add sample code for parallel wavegan
This commit is contained in:
parent
258083aea9
commit
66062d29e5
|
@ -13,33 +13,88 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import math
|
import math
|
||||||
|
from typing import List, Dict, Any, Union, Optional
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import paddle
|
import paddle
|
||||||
|
from paddle import Tensor
|
||||||
from paddle import nn
|
from paddle import nn
|
||||||
from paddle.nn import functional as F
|
from paddle.nn import functional as F
|
||||||
|
|
||||||
|
|
||||||
class Stretch2D(nn.Layer):
|
class Stretch2D(nn.Layer):
|
||||||
def __init__(self, x_scale, y_scale, mode="nearest"):
|
def __init__(self, w_scale: int, h_scale: int, mode: str="nearest"):
|
||||||
|
"""Strech an image (or image-like object) with some interpolation.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
w_scale : int
|
||||||
|
Scalar of width.
|
||||||
|
h_scale : int
|
||||||
|
Scalar of the height.
|
||||||
|
mode : str, optional
|
||||||
|
Interpolation mode, modes suppored are "nearest", "bilinear",
|
||||||
|
"trilinear", "bicubic", "linear" and "area",by default "nearest"
|
||||||
|
|
||||||
|
For more details about interpolation, see
|
||||||
|
`paddle.nn.functional.interpolate <https://www.paddlepaddle.org.cn/documentation/docs/en/api/paddle/nn/functional/interpolate_en.html>`_.
|
||||||
|
"""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.x_scale = x_scale
|
self.w_scale = w_scale
|
||||||
self.y_scale = y_scale
|
self.h_scale = h_scale
|
||||||
self.mode = mode
|
self.mode = mode
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x: Tensor) -> Tensor:
|
||||||
|
"""
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
x : Tensor
|
||||||
|
Shape (N, C, H, W)
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
Tensor
|
||||||
|
Shape (N, C, H', W'), where ``H'=h_scale * H``, ``W'=w_scale * W``.
|
||||||
|
The stretched image.
|
||||||
|
"""
|
||||||
out = F.interpolate(
|
out = F.interpolate(
|
||||||
x, scale_factor=(self.y_scale, self.x_scale), mode=self.mode)
|
x, scale_factor=(self.h_scale, self.w_scale), mode=self.mode)
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
class UpsampleNet(nn.Layer):
|
class UpsampleNet(nn.Layer):
|
||||||
|
"""A Layer to upsample spectrogram by applying consecutive stretch and
|
||||||
|
convolutions.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
upsample_scales : List[int]
|
||||||
|
Upsampling factors for each strech.
|
||||||
|
nonlinear_activation : Optional[str], optional
|
||||||
|
Activation after each convolution, by default None
|
||||||
|
nonlinear_activation_params : Dict[str, Any], optional
|
||||||
|
Parameters passed to construct the activation, by default {}
|
||||||
|
interpolate_mode : str, optional
|
||||||
|
Interpolation mode of the strech, by default "nearest"
|
||||||
|
freq_axis_kernel_size : int, optional
|
||||||
|
Convolution kernel size along the frequency axis, by default 1
|
||||||
|
use_causal_conv : bool, optional
|
||||||
|
Whether to use causal padding before convolution, by default False
|
||||||
|
|
||||||
|
If True, Causal padding is used along the time axis, i.e. padding
|
||||||
|
amount is ``receptive field - 1`` and 0 for before and after,
|
||||||
|
respectively.
|
||||||
|
|
||||||
|
If False, "same" padding is used along the time axis.
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
upsample_scales,
|
upsample_scales: List[int],
|
||||||
nonlinear_activation=None,
|
nonlinear_activation: Optional[str]=None,
|
||||||
nonlinear_activation_params={},
|
nonlinear_activation_params: Dict[str, Any]={},
|
||||||
interpolate_mode="nearest",
|
interpolate_mode: str="nearest",
|
||||||
freq_axis_kernel_size=1,
|
freq_axis_kernel_size: int=1,
|
||||||
use_causal_conv=False):
|
use_causal_conv: bool=False):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.use_causal_conv = use_causal_conv
|
self.use_causal_conv = use_causal_conv
|
||||||
self.up_layers = nn.LayerList()
|
self.up_layers = nn.LayerList()
|
||||||
|
@ -59,7 +114,19 @@ class UpsampleNet(nn.Layer):
|
||||||
nn, nonlinear_activation)(**nonlinear_activation_params)
|
nn, nonlinear_activation)(**nonlinear_activation_params)
|
||||||
self.up_layers.extend([stretch, conv, nonlinear])
|
self.up_layers.extend([stretch, conv, nonlinear])
|
||||||
|
|
||||||
def forward(self, c):
|
def forward(self, c: Tensor) -> Tensor:
|
||||||
|
"""
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
c : Tensor
|
||||||
|
Shape (N, F, T), spectrogram
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
Tensor
|
||||||
|
Shape (N, F, T'), where ``T' = upsample_factor * T``, upsampled
|
||||||
|
spectrogram
|
||||||
|
"""
|
||||||
c = c.unsqueeze(1)
|
c = c.unsqueeze(1)
|
||||||
for f in self.up_layers:
|
for f in self.up_layers:
|
||||||
if self.use_causal_conv and isinstance(f, nn.Conv2D):
|
if self.use_causal_conv and isinstance(f, nn.Conv2D):
|
||||||
|
@ -70,15 +137,48 @@ class UpsampleNet(nn.Layer):
|
||||||
|
|
||||||
|
|
||||||
class ConvInUpsampleNet(nn.Layer):
|
class ConvInUpsampleNet(nn.Layer):
|
||||||
|
"""A Layer to upsample spectrogram composed of a convolution and an
|
||||||
|
UpsampleNet.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
upsample_scales : List[int]
|
||||||
|
Upsampling factors for each strech.
|
||||||
|
nonlinear_activation : Optional[str], optional
|
||||||
|
Activation after each convolution, by default None
|
||||||
|
nonlinear_activation_params : Dict[str, Any], optional
|
||||||
|
Parameters passed to construct the activation, by default {}
|
||||||
|
interpolate_mode : str, optional
|
||||||
|
Interpolation mode of the strech, by default "nearest"
|
||||||
|
freq_axis_kernel_size : int, optional
|
||||||
|
Convolution kernel size along the frequency axis, by default 1
|
||||||
|
aux_channels : int, optional
|
||||||
|
Feature size of the input, by default 80
|
||||||
|
aux_context_window : int, optional
|
||||||
|
Context window of the first 1D convolution applied to the input. It
|
||||||
|
related to the kernel size of the convolution, by default 0
|
||||||
|
|
||||||
|
If use causal convolution, the kernel size is ``window + 1``, else
|
||||||
|
the kernel size is ``2 * window + 1``.
|
||||||
|
use_causal_conv : bool, optional
|
||||||
|
Whether to use causal padding before convolution, by default False
|
||||||
|
|
||||||
|
If True, Causal padding is used along the time axis, i.e. padding
|
||||||
|
amount is ``receptive field - 1`` and 0 for before and after,
|
||||||
|
respectively.
|
||||||
|
|
||||||
|
If False, "same" padding is used along the time axis.
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
upsample_scales,
|
upsample_scales: List[int],
|
||||||
nonlinear_activation=None,
|
nonlinear_activation: Optional[str]=None,
|
||||||
nonlinear_activation_params={},
|
nonlinear_activation_params: Dict[str, Any]={},
|
||||||
interpolate_mode="nearest",
|
interpolate_mode: str="nearest",
|
||||||
freq_axis_kernel_size=1,
|
freq_axis_kernel_size: int=1,
|
||||||
aux_channels=80,
|
aux_channels: int=80,
|
||||||
aux_context_window=0,
|
aux_context_window: int=0,
|
||||||
use_causal_conv=False):
|
use_causal_conv: bool=False):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.aux_context_window = aux_context_window
|
self.aux_context_window = aux_context_window
|
||||||
self.use_causal_conv = use_causal_conv and aux_context_window > 0
|
self.use_causal_conv = use_causal_conv and aux_context_window > 0
|
||||||
|
@ -96,23 +196,61 @@ class ConvInUpsampleNet(nn.Layer):
|
||||||
freq_axis_kernel_size=freq_axis_kernel_size,
|
freq_axis_kernel_size=freq_axis_kernel_size,
|
||||||
use_causal_conv=use_causal_conv)
|
use_causal_conv=use_causal_conv)
|
||||||
|
|
||||||
def forward(self, c):
|
def forward(self, c: Tensor) -> Tensor:
|
||||||
|
"""
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
c : Tensor
|
||||||
|
Shape (N, F, T), spectrogram
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
Tensors
|
||||||
|
Shape (N, F, T'), where ``T' = upsample_factor * T``, upsampled
|
||||||
|
spectrogram
|
||||||
|
"""
|
||||||
c_ = self.conv_in(c)
|
c_ = self.conv_in(c)
|
||||||
c = c_[:, :, :-self.aux_context_window] if self.use_causal_conv else c_
|
c = c_[:, :, :-self.aux_context_window] if self.use_causal_conv else c_
|
||||||
return self.upsample(c)
|
return self.upsample(c)
|
||||||
|
|
||||||
|
|
||||||
class ResidualBlock(nn.Layer):
|
class ResidualBlock(nn.Layer):
|
||||||
|
"""A gated activation unit composed of an 1D convolution, a gated tanh
|
||||||
|
unit and parametric redidual and skip connections. For more details,
|
||||||
|
refer to `WaveNet: A Generative Model for Raw Audio <https://arxiv.org/abs/1609.03499>`_.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
kernel_size : int, optional
|
||||||
|
Kernel size of the 1D convolution, by default 3
|
||||||
|
residual_channels : int, optional
|
||||||
|
Feature size of the resiaudl output(and also the input), by default 64
|
||||||
|
gate_channels : int, optional
|
||||||
|
Output feature size of the 1D convolution, by default 128
|
||||||
|
skip_channels : int, optional
|
||||||
|
Feature size of the skip output, by default 64
|
||||||
|
aux_channels : int, optional
|
||||||
|
Feature size of the auxiliary input (e.g. spectrogram), by default 80
|
||||||
|
dropout : float, optional
|
||||||
|
Probability of the dropout before the 1D convolution, by default 0.
|
||||||
|
dilation : int, optional
|
||||||
|
Dilation of the 1D convolution, by default 1
|
||||||
|
bias : bool, optional
|
||||||
|
Whether to use bias in the 1D convolution, by default True
|
||||||
|
use_causal_conv : bool, optional
|
||||||
|
Whether to use causal padding for the 1D convolution, by default False
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
kernel_size=3,
|
kernel_size: int=3,
|
||||||
residual_channels=64,
|
residual_channels: int=64,
|
||||||
gate_channels=128,
|
gate_channels: int=128,
|
||||||
skip_channels=64,
|
skip_channels: int=64,
|
||||||
aux_channels=80,
|
aux_channels: int=80,
|
||||||
dropout=0.,
|
dropout: float=0.,
|
||||||
dilation=1,
|
dilation: int=1,
|
||||||
bias=True,
|
bias: bool=True,
|
||||||
use_causal_conv=False):
|
use_causal_conv: bool=False):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.dropout = dropout
|
self.dropout = dropout
|
||||||
if use_causal_conv:
|
if use_causal_conv:
|
||||||
|
@ -144,7 +282,24 @@ class ResidualBlock(nn.Layer):
|
||||||
self.conv1x1_skip = nn.Conv1D(
|
self.conv1x1_skip = nn.Conv1D(
|
||||||
gate_out_channels, skip_channels, kernel_size=1, bias_attr=bias)
|
gate_out_channels, skip_channels, kernel_size=1, bias_attr=bias)
|
||||||
|
|
||||||
def forward(self, x, c):
|
def forward(self, x: Tensor, c: Tensor) -> Tuple[Tensor, Tensor]:
|
||||||
|
"""
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
x : Tensor
|
||||||
|
Shape (N, C_res, T), the input features.
|
||||||
|
c : Tensor
|
||||||
|
Shape (N, C_aux, T), he auxiliary input.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
res : Tensor
|
||||||
|
Shape (N, C_res, T), the residual output, which is used as the
|
||||||
|
input of the next ResidualBlock in a stack of ResidualBlocks.
|
||||||
|
skip : Tensor
|
||||||
|
Shape (N, C_skip, T), the skip output, which is collected among
|
||||||
|
each layer in a stack of ResidualBlocks.
|
||||||
|
"""
|
||||||
x_input = x
|
x_input = x
|
||||||
x = F.dropout(x, self.dropout, training=self.training)
|
x = F.dropout(x, self.dropout, training=self.training)
|
||||||
x = self.conv(x)
|
x = self.conv(x)
|
||||||
|
@ -162,26 +317,76 @@ class ResidualBlock(nn.Layer):
|
||||||
|
|
||||||
|
|
||||||
class PWGGenerator(nn.Layer):
|
class PWGGenerator(nn.Layer):
|
||||||
|
"""Wave Generator for Parallel WaveGAN
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
in_channels : int, optional
|
||||||
|
Number of channels of the input waveform, by default 1
|
||||||
|
out_channels : int, optional
|
||||||
|
Number of channels of the output waveform, by default 1
|
||||||
|
kernel_size : int, optional
|
||||||
|
Kernel size of the residual blocks inside, by default 3
|
||||||
|
layers : int, optional
|
||||||
|
Number of residual blocks inside, by default 30
|
||||||
|
stacks : int, optional
|
||||||
|
The number of groups to split the residual blocks into, by default 3
|
||||||
|
|
||||||
|
Within each group, the dilation of the residual block grows
|
||||||
|
exponentially.
|
||||||
|
residual_channels : int, optional
|
||||||
|
Residual channel of the residual blocks, by default 64
|
||||||
|
gate_channels : int, optional
|
||||||
|
Gate channel of the residual blocks, by default 128
|
||||||
|
skip_channels : int, optional
|
||||||
|
Skip channel of the residual blocks, by default 64
|
||||||
|
aux_channels : int, optional
|
||||||
|
Auxiliary channel of the residual blocks, by default 80
|
||||||
|
aux_context_window : int, optional
|
||||||
|
The context window size of the first convolution applied to the
|
||||||
|
auxiliary input, by default 2
|
||||||
|
dropout : float, optional
|
||||||
|
Dropout of the residual blocks, by default 0.
|
||||||
|
bias : bool, optional
|
||||||
|
Whether to use bias in residual blocks, by default True
|
||||||
|
use_weight_norm : bool, optional
|
||||||
|
Whether to use weight norm in all convolutions, by default True
|
||||||
|
use_causal_conv : bool, optional
|
||||||
|
Whether to use causal padding in the upsample network and residual
|
||||||
|
blocks, by default False
|
||||||
|
upsample_scales : List[int], optional
|
||||||
|
Upsample scales of the upsample network, by default [4, 4, 4, 4]
|
||||||
|
nonlinear_activation : Optional[str], optional
|
||||||
|
Non linear activation in upsample network, by default None
|
||||||
|
nonlinear_activation_params : Dict[str, Any], optional
|
||||||
|
Parameters passed to the linear activation in the upsample network,
|
||||||
|
by default {}
|
||||||
|
interpolate_mode : str, optional
|
||||||
|
Interpolation mode of the upsample network, by default "nearest"
|
||||||
|
freq_axis_kernel_size : int, optional
|
||||||
|
Kernel size along the frequency axis of the upsample network, by default 1
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
in_channels=1,
|
in_channels: int=1,
|
||||||
out_channels=1,
|
out_channels: int=1,
|
||||||
kernel_size=3,
|
kernel_size: int=3,
|
||||||
layers=30,
|
layers: int=30,
|
||||||
stacks=3,
|
stacks: int=3,
|
||||||
residual_channels=64,
|
residual_channels: int=64,
|
||||||
gate_channels=128,
|
gate_channels: int=128,
|
||||||
skip_channels=64,
|
skip_channels: int=64,
|
||||||
aux_channels=80,
|
aux_channels: int=80,
|
||||||
aux_context_window=2,
|
aux_context_window: int=2,
|
||||||
dropout=0.,
|
dropout: float=0.,
|
||||||
bias=True,
|
bias: bool=True,
|
||||||
use_weight_norm=True,
|
use_weight_norm: bool=True,
|
||||||
use_causal_conv=False,
|
use_causal_conv: bool=False,
|
||||||
upsample_scales=[4, 4, 4, 4],
|
upsample_scales: List[int]=[4, 4, 4, 4],
|
||||||
nonlinear_activation=None,
|
nonlinear_activation: Optional[str]=None,
|
||||||
nonlinear_activation_params={},
|
nonlinear_activation_params: Dict[str, Any]={},
|
||||||
interpolate_mode="nearest",
|
interpolate_mode: str="nearest",
|
||||||
freq_axis_kernel_size=1):
|
freq_axis_kernel_size: int=1):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.in_channels = in_channels
|
self.in_channels = in_channels
|
||||||
self.out_channels = out_channels
|
self.out_channels = out_channels
|
||||||
|
@ -233,10 +438,24 @@ class PWGGenerator(nn.Layer):
|
||||||
if use_weight_norm:
|
if use_weight_norm:
|
||||||
self.apply_weight_norm()
|
self.apply_weight_norm()
|
||||||
|
|
||||||
def forward(self, x, c):
|
def forward(self, x: Tensor, c: Tensor) -> Tensor:
|
||||||
if c is not None:
|
"""Generate waveform.
|
||||||
c = self.upsample_net(c)
|
|
||||||
assert c.shape[-1] == x.shape[-1]
|
Parameters
|
||||||
|
----------
|
||||||
|
x : Tensor
|
||||||
|
Shape (N, C_in, T), The input waveform.
|
||||||
|
c : Tensor
|
||||||
|
Shape (N, C_aux, T'). The auxiliary input (e.g. spectrogram). It
|
||||||
|
is upsampled to match the time resolution of the input.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
Tensor
|
||||||
|
Shape (N, C_out, T), the generated waveform.
|
||||||
|
"""
|
||||||
|
c = self.upsample_net(c)
|
||||||
|
assert c.shape[-1] == x.shape[-1]
|
||||||
|
|
||||||
x = self.first_conv(x)
|
x = self.first_conv(x)
|
||||||
skips = 0
|
skips = 0
|
||||||
|
@ -249,6 +468,10 @@ class PWGGenerator(nn.Layer):
|
||||||
return x
|
return x
|
||||||
|
|
||||||
def apply_weight_norm(self):
|
def apply_weight_norm(self):
|
||||||
|
"""Recursively apply weight normalization to all the Convolution layers
|
||||||
|
in the sublayers.
|
||||||
|
"""
|
||||||
|
|
||||||
def _apply_weight_norm(layer):
|
def _apply_weight_norm(layer):
|
||||||
if isinstance(layer, (nn.Conv1D, nn.Conv2D)):
|
if isinstance(layer, (nn.Conv1D, nn.Conv2D)):
|
||||||
nn.utils.weight_norm(layer)
|
nn.utils.weight_norm(layer)
|
||||||
|
@ -256,6 +479,10 @@ class PWGGenerator(nn.Layer):
|
||||||
self.apply(_apply_weight_norm)
|
self.apply(_apply_weight_norm)
|
||||||
|
|
||||||
def remove_weight_norm(self):
|
def remove_weight_norm(self):
|
||||||
|
"""Recursively remove weight normalization from all the Convolution
|
||||||
|
layers in the sublayers.
|
||||||
|
"""
|
||||||
|
|
||||||
def _remove_weight_norm(layer):
|
def _remove_weight_norm(layer):
|
||||||
try:
|
try:
|
||||||
nn.utils.remove_weight_norm(layer)
|
nn.utils.remove_weight_norm(layer)
|
||||||
|
@ -264,17 +491,30 @@ class PWGGenerator(nn.Layer):
|
||||||
|
|
||||||
self.apply(_remove_weight_norm)
|
self.apply(_remove_weight_norm)
|
||||||
|
|
||||||
def inference(self, c=None, x=None):
|
def inference(self, c: Optional[Tensor]=None,
|
||||||
"""
|
x: Optional[Tensor]=None) -> Tensor:
|
||||||
single instance inference
|
"""Waveform generation. This function is used for single instance
|
||||||
c: [T', C] condition
|
inference.
|
||||||
x: [T, 1] noise
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
c : Tensor, optional
|
||||||
|
Shape (T', C_aux), the auxiliary input, by default None
|
||||||
|
x : Tensor, optional
|
||||||
|
Shape (T, C_in), the noise waveform, by default None
|
||||||
|
If not provided, a sample is drawn from a gaussian distribution.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
Tensor
|
||||||
|
Shape (T, C_out), the generated waveform
|
||||||
"""
|
"""
|
||||||
if x is not None:
|
if x is not None:
|
||||||
x = paddle.transpose(x, [1, 0]).unsqueeze(0) # pseudo batch
|
x = paddle.transpose(x, [1, 0]).unsqueeze(0) # pseudo batch
|
||||||
else:
|
else:
|
||||||
assert c is not None
|
assert c is not None
|
||||||
x = paddle.randn([1, 1, c.shape[0] * self.upsample_factor])
|
x = paddle.randn(
|
||||||
|
[1, self.in_channels, c.shape[0] * self.upsample_factor])
|
||||||
|
|
||||||
if c is not None:
|
if c is not None:
|
||||||
c = paddle.transpose(c, [1, 0]).unsqueeze(0) # pseudo batch
|
c = paddle.transpose(c, [1, 0]).unsqueeze(0) # pseudo batch
|
||||||
|
|
Loading…
Reference in New Issue