merge upstream develop
This commit is contained in:
commit
e30d7ad48f
|
@ -141,6 +141,15 @@ class ResidualBlock(nn.Layer):
|
||||||
raise ValueError("Only use start sequence at evaluation mode.")
|
raise ValueError("Only use start sequence at evaluation mode.")
|
||||||
self._conv_buffer = None
|
self._conv_buffer = None
|
||||||
|
|
||||||
|
# NOTE: call self.conv's weight norm hook expliccitly since
|
||||||
|
# its weight will be visited directly in `add_input` without
|
||||||
|
# calling its `__call__` method. If we do not trigger the weight
|
||||||
|
# norm hook, the weight may be outdated. e.g. after loading from
|
||||||
|
# a saved checkpoint
|
||||||
|
# see also: https://github.com/pytorch/pytorch/issues/47588
|
||||||
|
for hook in self.conv._forward_pre_hooks.values():
|
||||||
|
hook(self.conv, None)
|
||||||
|
|
||||||
def add_input(self, x_row, condition_row):
|
def add_input(self, x_row, condition_row):
|
||||||
"""Compute the output for a row and update the buffer.
|
"""Compute the output for a row and update the buffer.
|
||||||
|
|
||||||
|
@ -158,10 +167,6 @@ class ResidualBlock(nn.Layer):
|
||||||
self._update_buffer(x_row)
|
self._update_buffer(x_row)
|
||||||
|
|
||||||
rw = self.rw
|
rw = self.rw
|
||||||
# call self.conv's weight norm hook expliccitly since its __call__
|
|
||||||
# method is not called here
|
|
||||||
for hook in self.conv._forward_pre_hooks.values():
|
|
||||||
hook(self.conv, self._conv_buffer)
|
|
||||||
x_row = F.conv2d(
|
x_row = F.conv2d(
|
||||||
self._conv_buffer,
|
self._conv_buffer,
|
||||||
self.conv.weight,
|
self.conv.weight,
|
||||||
|
|
|
@ -12,9 +12,9 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from __future__ import division
|
|
||||||
import math
|
import math
|
||||||
import time
|
import time
|
||||||
|
from typing import Union, Sequence
|
||||||
from tqdm import trange
|
from tqdm import trange
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
@ -25,36 +25,7 @@ import paddle.fluid.initializer as I
|
||||||
import paddle.fluid.layers.distributions as D
|
import paddle.fluid.layers.distributions as D
|
||||||
|
|
||||||
from parakeet.modules.conv import Conv1dCell
|
from parakeet.modules.conv import Conv1dCell
|
||||||
|
from parakeet.modules.audio import quantize, dequantize, STFT
|
||||||
__all__ = ["ConditionalWavenet"]
|
|
||||||
|
|
||||||
def quantize(values, n_bands):
|
|
||||||
"""Linearlly quantize a float Tensor in [-1, 1) to an interger Tensor in [0, n_bands).
|
|
||||||
|
|
||||||
Args:
|
|
||||||
values (Tensor): dtype: flaot32 or float64. the floating point value.
|
|
||||||
n_bands (int): the number of bands. The output integer Tensor's value is in the range [0, n_bans).
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Tensor: the quantized tensor, dtype: int64.
|
|
||||||
"""
|
|
||||||
quantized = paddle.cast((values + 1.0) / 2.0 * n_bands, "int64")
|
|
||||||
return quantized
|
|
||||||
|
|
||||||
|
|
||||||
def dequantize(quantized, n_bands, dtype=None):
|
|
||||||
"""Linearlly dequantize an integer Tensor into a float Tensor in the range [-1, 1).
|
|
||||||
|
|
||||||
Args:
|
|
||||||
quantized (Tensor): dtype: int64. The quantized value in the range [0, n_bands).
|
|
||||||
n_bands (int): number of bands. The input integer Tensor's value is in the range [0, n_bans).
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Tensor: the dequantized tensor, dtype is specified by dtype.
|
|
||||||
"""
|
|
||||||
dtype = dtype or paddle.get_default_dtype()
|
|
||||||
value = (paddle.cast(quantized, dtype) + 0.5) * (2.0 / n_bands) - 1.0
|
|
||||||
return value
|
|
||||||
|
|
||||||
|
|
||||||
def crop(x, audio_start, audio_length):
|
def crop(x, audio_start, audio_length):
|
||||||
|
@ -81,9 +52,52 @@ def crop(x, audio_start, audio_length):
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
class UpsampleNet(nn.LayerList):
|
||||||
|
def __init__(self, upscale_factors=[16, 16]):
|
||||||
|
"""UpsamplingNet.
|
||||||
|
It consists of several layers of Conv2DTranspose. Each Conv2DTranspose layer upsamples the time dimension by its `stride` times. And each Conv2DTranspose's filter_size at frequency dimension is 3.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
upscale_factors (list[int], optional): time upsampling factors for each Conv2DTranspose Layer. The `UpsampleNet` contains len(upscale_factor) Conv2DTranspose Layers. Each upscale_factor is used as the `stride` for the corresponding Conv2DTranspose. Defaults to [16, 16].
|
||||||
|
Note:
|
||||||
|
np.prod(upscale_factors) should equals the `hop_length` of the stft transformation used to extract spectrogram features from audios. For example, 16 * 16 = 256, then the spectram extracted using a stft transformation whose `hop_length` is 256. See `librosa.stft` for more details.
|
||||||
|
"""
|
||||||
|
super(UpsampleNet, self).__init__()
|
||||||
|
self.upscale_factors = list(upscale_factors)
|
||||||
|
self.upscale_factor = 1
|
||||||
|
for item in upscale_factors:
|
||||||
|
self.upscale_factor *= item
|
||||||
|
|
||||||
|
for factor in self.upscale_factors:
|
||||||
|
self.append(
|
||||||
|
nn.utils.weight_norm(
|
||||||
|
nn.Conv2DTranspose(1, 1,
|
||||||
|
kernel_size=(3, 2 * factor),
|
||||||
|
stride=(1, factor),
|
||||||
|
padding=(1, factor // 2))))
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
"""Compute the upsampled condition.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x (Tensor): shape(B, F, T), dtype float32, the condition (mel spectrogram here.) (F means the frequency bands). In the internal Conv2DTransposes, the frequency dimension is treated as `height` dimension instead of `in_channels`.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tensor: shape(B, F, T * upscale_factor), dtype float32, the upsampled condition.
|
||||||
|
"""
|
||||||
|
x = paddle.unsqueeze(x, 1)
|
||||||
|
for sublayer in self:
|
||||||
|
x = F.leaky_relu(sublayer(x), 0.4)
|
||||||
|
x = paddle.squeeze(x, 1)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
class ResidualBlock(nn.Layer):
|
class ResidualBlock(nn.Layer):
|
||||||
def __init__(self, residual_channels, condition_dim, filter_size,
|
def __init__(self,
|
||||||
dilation):
|
residual_channels: int,
|
||||||
|
condition_dim: int,
|
||||||
|
filter_size: Union[int, Sequence[int]],
|
||||||
|
dilation: int):
|
||||||
"""A Residual block in wavenet. It does not have parametric residual or skip connection. It consists of a Conv1DCell and an Conv1D(filter_size = 1) to integrate the condition.
|
"""A Residual block in wavenet. It does not have parametric residual or skip connection. It consists of a Conv1DCell and an Conv1D(filter_size = 1) to integrate the condition.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -121,20 +135,13 @@ class ResidualBlock(nn.Layer):
|
||||||
"""Conv1D gated-tanh Block.
|
"""Conv1D gated-tanh Block.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
x (Tensor): shape(B, C_res, T), the input. (B stands for batch_size,
|
x (Tensor): shape(B, C_res, T), the input. (B stands for batch_size, C_res stands for residual channels, T stands for time steps.) dtype float32.
|
||||||
C_res stands for residual channels, T stands for time steps.)
|
condition (Tensor, optional): shape(B, C_cond, T), the condition, it has been upsampled in time steps, so it has the same time steps as the input does.(C_cond stands for the condition's channels). Defaults to None.
|
||||||
dtype float32.
|
|
||||||
condition (Tensor, optional): shape(B, C_cond, T), the condition,
|
|
||||||
it has been upsampled in time steps, so it has the same time
|
|
||||||
steps as the input does.(C_cond stands for the condition's channels).
|
|
||||||
Defaults to None.
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
(residual, skip_connection)
|
(residual, skip_connection)
|
||||||
residual (Tensor): shape(B, C_res, T), the residual, which is used
|
residual (Tensor): shape(B, C_res, T), the residual, which is used as the input to the next layer of ResidualBlock.
|
||||||
as the input to the next layer of ResidualBlock.
|
skip_connection (Tensor): shape(B, C_res, T), the skip connection. This output is accumulated with that of other ResidualBlocks.
|
||||||
skip_connection (Tensor): shape(B, C_res, T), the skip connection.
|
|
||||||
This output is accumulated with that of other ResidualBlocks.
|
|
||||||
"""
|
"""
|
||||||
h = x
|
h = x
|
||||||
|
|
||||||
|
@ -155,30 +162,22 @@ class ResidualBlock(nn.Layer):
|
||||||
return residual, skip_connection
|
return residual, skip_connection
|
||||||
|
|
||||||
def start_sequence(self):
|
def start_sequence(self):
|
||||||
"""
|
"""Prepare the ResidualBlock to generate a new sequence. This method should be called before starting calling `add_input` multiple times.
|
||||||
Prepare the ResidualBlock to generate a new sequence. This method
|
|
||||||
should be called before starting calling `add_input` multiple times.
|
|
||||||
"""
|
"""
|
||||||
self.conv.start_sequence()
|
self.conv.start_sequence()
|
||||||
self.condition_proj.start_sequence()
|
self.condition_proj.start_sequence()
|
||||||
|
|
||||||
def add_input(self, x, condition=None):
|
def add_input(self, x, condition=None):
|
||||||
"""
|
"""Add a step input. This method works similarily with `forward` but in a `step-in-step-out` fashion.
|
||||||
Add a step input. This method works similarily with `forward` but
|
|
||||||
in a `step-in-step-out` fashion.
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
x (Variable): shape(B, C_res), input for a step, dtype float32.
|
x (Tensor): shape(B, C_res), input for a step, dtype float32.
|
||||||
condition (Variable, optional): shape(B, C_cond). condition for a
|
condition (Tensor, optional): shape(B, C_cond). condition for a step, dtype float32. Defaults to None.
|
||||||
step, dtype float32. Defaults to None.
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
(residual, skip_connection)
|
(residual, skip_connection)
|
||||||
residual (Variable): shape(B, C_res), the residual for a step,
|
residual (Tensor): shape(B, C_res), the residual for a step, which is used as the input to the next layer of ResidualBlock.
|
||||||
which is used as the input to the next layer of ResidualBlock.
|
skip_connection (Tensor): shape(B, C_res), the skip connection for a step. This output is accumulated with that of other ResidualBlocks.
|
||||||
skip_connection (Variable): shape(B, C_res), the skip connection
|
|
||||||
for a step. This output is accumulated with that of other
|
|
||||||
ResidualBlocks.
|
|
||||||
"""
|
"""
|
||||||
h = x
|
h = x
|
||||||
|
|
||||||
|
@ -200,22 +199,24 @@ class ResidualBlock(nn.Layer):
|
||||||
|
|
||||||
|
|
||||||
class ResidualNet(nn.LayerList):
|
class ResidualNet(nn.LayerList):
|
||||||
def __init__(self, n_loop, n_layer, residual_channels, condition_dim,
|
def __init__(self,
|
||||||
filter_size):
|
n_stack: int,
|
||||||
"""The residual network in wavenet. It consists of `n_layer` stacks,
|
n_loop: int,
|
||||||
each of which consists of `n_loop` ResidualBlocks.
|
residual_channels: int,
|
||||||
|
condition_dim: int,
|
||||||
|
filter_size: int):
|
||||||
|
"""The residual network in wavenet. It consists of `n_layer` stacks, each of which consists of `n_loop` ResidualBlocks.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
n_stack (int): number of stacks in the `ResidualNet`.
|
||||||
n_loop (int): number of ResidualBlocks in a stack.
|
n_loop (int): number of ResidualBlocks in a stack.
|
||||||
n_layer (int): number of stacks in the `ResidualNet`.
|
|
||||||
residual_channels (int): channels of each `ResidualBlock`'s input.
|
residual_channels (int): channels of each `ResidualBlock`'s input.
|
||||||
condition_dim (int): channels of the condition.
|
condition_dim (int): channels of the condition.
|
||||||
filter_size (int): filter size of the internal Conv1DCell of each
|
filter_size (int): filter size of the internal Conv1DCell of each `ResidualBlock`.
|
||||||
`ResidualBlock`.
|
|
||||||
"""
|
"""
|
||||||
super(ResidualNet, self).__init__()
|
super(ResidualNet, self).__init__()
|
||||||
# double the dilation at each layer in a loop(n_loop layers)
|
# double the dilation at each layer in a stack
|
||||||
dilations = [2**i for i in range(n_loop)] * n_layer
|
dilations = [2**i for i in range(n_loop)] * n_stack
|
||||||
self.context_size = 1 + sum(dilations)
|
self.context_size = 1 + sum(dilations)
|
||||||
for dilation in dilations:
|
for dilation in dilations:
|
||||||
self.append(ResidualBlock(residual_channels, condition_dim, filter_size, dilation))
|
self.append(ResidualBlock(residual_channels, condition_dim, filter_size, dilation))
|
||||||
|
@ -223,13 +224,8 @@ class ResidualNet(nn.LayerList):
|
||||||
def forward(self, x, condition=None):
|
def forward(self, x, condition=None):
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
x (Tensor): shape(B, C_res, T), dtype float32, the input.
|
x (Tensor): shape(B, C_res, T), dtype float32, the input. (B stands for batch_size, C_res stands for residual channels, T stands for time steps.)
|
||||||
(B stands for batch_size, C_res stands for residual channels,
|
condition (Tensor, optional): shape(B, C_cond, T), dtype float32, the condition, it has been upsampled in time steps, so it has the same time steps as the input does.(C_cond stands for the condition's channels) Defaults to None.
|
||||||
T stands for time steps.)
|
|
||||||
condition (Tensor, optional): shape(B, C_cond, T), dtype float32,
|
|
||||||
the condition, it has been upsampled in time steps, so it has
|
|
||||||
the same time steps as the input does.(C_cond stands for the
|
|
||||||
condition's channels) Defaults to None.
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
skip_connection (Tensor): shape(B, C_res, T), dtype float32, the output.
|
skip_connection (Tensor): shape(B, C_res, T), dtype float32, the output.
|
||||||
|
@ -244,24 +240,20 @@ class ResidualNet(nn.LayerList):
|
||||||
return skip_connections
|
return skip_connections
|
||||||
|
|
||||||
def start_sequence(self):
|
def start_sequence(self):
|
||||||
"""Prepare the ResidualNet to generate a new sequence. This method
|
"""Prepare the ResidualNet to generate a new sequence. This method should be called before starting calling `add_input` multiple times.
|
||||||
should be called before starting calling `add_input` multiple times.
|
|
||||||
"""
|
"""
|
||||||
for block in self:
|
for block in self:
|
||||||
block.start_sequence()
|
block.start_sequence()
|
||||||
|
|
||||||
def add_input(self, x, condition=None):
|
def add_input(self, x, condition=None):
|
||||||
"""Add a step input. This method works similarily with `forward` but
|
"""Add a step input. This method works similarily with `forward` but in a `step-in-step-out` fashion.
|
||||||
in a `step-in-step-out` fashion.
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
x (Tensor): shape(B, C_res), dtype float32, input for a step.
|
x (Tensor): shape(B, C_res), dtype float32, input for a step.
|
||||||
condition (Tensor, optional): shape(B, C_cond), dtype float32,
|
condition (Tensor, optional): shape(B, C_cond), dtype float32, condition for a step. Defaults to None.
|
||||||
condition for a step. Defaults to None.
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
skip_connection (Tensor): shape(B, C_res), dtype float32, the
|
skip_connection (Tensor): shape(B, C_res), dtype float32, the output for a step.
|
||||||
output for a step.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
for i, func in enumerate(self):
|
for i, func in enumerate(self):
|
||||||
|
@ -275,31 +267,19 @@ class ResidualNet(nn.LayerList):
|
||||||
|
|
||||||
|
|
||||||
class WaveNet(nn.Layer):
|
class WaveNet(nn.Layer):
|
||||||
def __init__(self, n_loop, n_layer, residual_channels, output_dim,
|
def __init__(self, n_stack, n_loop, residual_channels, output_dim,
|
||||||
condition_dim, filter_size, loss_type, log_scale_min):
|
condition_dim, filter_size, loss_type, log_scale_min):
|
||||||
"""Wavenet that transform upsampled mel spectrogram into waveform.
|
"""Wavenet that transform upsampled mel spectrogram into waveform.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
n_stack (int): n_stack for the internal ResidualNet.
|
||||||
n_loop (int): n_loop for the internal ResidualNet.
|
n_loop (int): n_loop for the internal ResidualNet.
|
||||||
n_layer (int): n_loop for the internal ResidualNet.
|
|
||||||
residual_channels (int): the channel of the input.
|
residual_channels (int): the channel of the input.
|
||||||
output_dim (int): the channel of the output distribution.
|
output_dim (int): the channel of the output distribution.
|
||||||
condition_dim (int): the channel of the condition.
|
condition_dim (int): the channel of the condition.
|
||||||
filter_size (int): the filter size of the internal ResidualNet.
|
filter_size (int): the filter size of the internal ResidualNet.
|
||||||
loss_type (str): loss type of the wavenet. Possible values are
|
loss_type (str): loss type of the wavenet. Possible values are 'softmax' and 'mog'. If `loss_type` is 'softmax', the output is the logits of the catrgotical(multinomial) distribution, `output_dim` means the number of classes of the categorical distribution. If `loss_type` is mog(mixture of gaussians), the output is the parameters of a mixture of gaussians, which consists of weight(in the form of logit) of each gaussian distribution and its mean and log standard deviaton. So when `loss_type` is 'mog', `output_dim` should be perfectly divided by 3.
|
||||||
'softmax' and 'mog'.
|
log_scale_min (int): the minimum value of log standard deviation of the output gaussian distributions. Note that this value is only used for computing loss if `loss_type` is 'mog', values less than `log_scale_min` is clipped when computing loss.
|
||||||
If `loss_type` is 'softmax', the output is the logits of the
|
|
||||||
catrgotical(multinomial) distribution, `output_dim` means the
|
|
||||||
number of classes of the categorical distribution.
|
|
||||||
If `loss_type` is mog(mixture of gaussians), the output is the
|
|
||||||
parameters of a mixture of gaussians, which consists of weight
|
|
||||||
(in the form of logit) of each gaussian distribution and its
|
|
||||||
mean and log standard deviaton. So when `loss_type` is 'mog',
|
|
||||||
`output_dim` should be perfectly divided by 3.
|
|
||||||
log_scale_min (int): the minimum value of log standard deviation
|
|
||||||
of the output gaussian distributions. Note that this value is
|
|
||||||
only used for computing loss if `loss_type` is 'mog', values
|
|
||||||
less than `log_scale_min` is clipped when computing loss.
|
|
||||||
"""
|
"""
|
||||||
super(WaveNet, self).__init__()
|
super(WaveNet, self).__init__()
|
||||||
if loss_type not in ["softmax", "mog"]:
|
if loss_type not in ["softmax", "mog"]:
|
||||||
|
@ -312,7 +292,7 @@ class WaveNet(nn.Layer):
|
||||||
"with Mixture of Gaussians(mog) output, the output dim must be divisible by 3, but get {}".format(output_dim))
|
"with Mixture of Gaussians(mog) output, the output dim must be divisible by 3, but get {}".format(output_dim))
|
||||||
self.embed = nn.utils.weight_norm(nn.Linear(1, residual_channels), dim=-1)
|
self.embed = nn.utils.weight_norm(nn.Linear(1, residual_channels), dim=-1)
|
||||||
|
|
||||||
self.resnet = ResidualNet(n_loop, n_layer, residual_channels,
|
self.resnet = ResidualNet(n_stack, n_loop, residual_channels,
|
||||||
condition_dim, filter_size)
|
condition_dim, filter_size)
|
||||||
self.context_size = self.resnet.context_size
|
self.context_size = self.resnet.context_size
|
||||||
|
|
||||||
|
@ -334,12 +314,10 @@ class WaveNet(nn.Layer):
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
x (Tensor): shape(B, T), dtype float32, the input waveform.
|
x (Tensor): shape(B, T), dtype float32, the input waveform.
|
||||||
condition (Tensor, optional): shape(B, C_cond, T), dtype float32,
|
condition (Tensor, optional): shape(B, C_cond, T), dtype float32, the upsampled condition. Defaults to None.
|
||||||
the upsampled condition. Defaults to None.
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Tensor: shape(B, T, C_output), dtype float32, the parameter of
|
Tensor: shape(B, T, C_output), dtype float32, the parameter of the output distributions.
|
||||||
the output distributions.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# Causal Conv
|
# Causal Conv
|
||||||
|
@ -362,24 +340,19 @@ class WaveNet(nn.Layer):
|
||||||
return y
|
return y
|
||||||
|
|
||||||
def start_sequence(self):
|
def start_sequence(self):
|
||||||
"""Prepare the WaveNet to generate a new sequence. This method should
|
"""Prepare the WaveNet to generate a new sequence. This method should be called before starting calling `add_input` multiple times.
|
||||||
be called before starting calling `add_input` multiple times.
|
|
||||||
"""
|
"""
|
||||||
self.resnet.start_sequence()
|
self.resnet.start_sequence()
|
||||||
|
|
||||||
def add_input(self, x, condition=None):
|
def add_input(self, x, condition=None):
|
||||||
"""compute the output distribution (represented by its parameters) for
|
"""compute the output distribution (represented by its parameters) for a step. It works similarily with the `forward` method but in a `step-in-step-out` fashion.
|
||||||
a step. It works similarily with the `forward` method but in a
|
|
||||||
`step-in-step-out` fashion.
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
x (Tensor): shape(B,), dtype float32, a step of the input waveform.
|
x (Tensor): shape(B,), dtype float32, a step of the input waveform.
|
||||||
condition (Tensor, optional): shape(B, C_cond, ), dtype float32, a
|
condition (Tensor, optional): shape(B, C_cond, ), dtype float32, a step of the upsampled condition. Defaults to None.
|
||||||
step of the upsampled condition. Defaults to None.
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Tensor: shape(B, C_output), dtype float32, the parameter of the
|
Tensor: shape(B, C_output), dtype float32, the parameter of the output distributions.
|
||||||
output distributions.
|
|
||||||
"""
|
"""
|
||||||
# Causal Conv
|
# Causal Conv
|
||||||
if self.loss_type == "softmax":
|
if self.loss_type == "softmax":
|
||||||
|
@ -402,12 +375,8 @@ class WaveNet(nn.Layer):
|
||||||
"""compute the loss where output distribution is a categorial distribution.
|
"""compute the loss where output distribution is a categorial distribution.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
y (Tensor): shape(B, T, C_output), dtype float32, the logits of the
|
y (Tensor): shape(B, T, C_output), dtype float32, the logits of the output distribution.
|
||||||
output distribution.
|
t (Tensor): shape(B, T), dtype float32, the target audio. Note that the target's corresponding time index is one step ahead of the output distribution. And output distribution whose input contains padding is neglected in loss computation.
|
||||||
t (Tensor): shape(B, T), dtype float32, the target audio. Note that
|
|
||||||
the target's corresponding time index is one step ahead of the
|
|
||||||
output distribution. And output distribution whose input contains
|
|
||||||
padding is neglected in loss computation.
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Tensor: shape(1, ), dtype float32, the loss.
|
Tensor: shape(1, ), dtype float32, the loss.
|
||||||
|
@ -420,15 +389,14 @@ class WaveNet(nn.Layer):
|
||||||
label = paddle.unsqueeze(quantized, -1)
|
label = paddle.unsqueeze(quantized, -1)
|
||||||
|
|
||||||
loss = F.softmax_with_cross_entropy(y, label)
|
loss = F.softmax_with_cross_entropy(y, label)
|
||||||
reduced_loss = paddle.reduce_mean(loss)
|
reduced_loss = paddle.mean(loss)
|
||||||
return reduced_loss
|
return reduced_loss
|
||||||
|
|
||||||
def sample_from_softmax(self, y):
|
def sample_from_softmax(self, y):
|
||||||
"""Sample from the output distribution where the output distribution is
|
"""Sample from the output distribution where the output distribution is a categorical distriobution.
|
||||||
a categorical distriobution.
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
y (Tensor): shape(B, T, C_output), the logits of the output distribution.
|
y (Tensor): shape(B, T, C_output), the logits of the output distribution
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Tensor: shape(B, T), waveform sampled from the output distribution.
|
Tensor: shape(B, T), waveform sampled from the output distribution.
|
||||||
|
@ -446,16 +414,8 @@ class WaveNet(nn.Layer):
|
||||||
"""compute the loss where output distribution is a mixture of Gaussians.
|
"""compute the loss where output distribution is a mixture of Gaussians.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
y (Tensor): shape(B, T, C_output), dtype float32, the parameterd of
|
y (Tensor): shape(B, T, C_output), dtype float32, the parameterd of the output distribution. It is the concatenation of 3 parts, the logits of every distribution, the mean of each distribution and the log standard deviation of each distribution. Each part's shape is (B, T, n_mixture), where `n_mixture` means the number of Gaussians in the mixture.
|
||||||
the output distribution. It is the concatenation of 3 parts,
|
t (Tensor): shape(B, T), dtype float32, the target audio. Note that the target's corresponding time index is one step ahead of the output distribution. And output distribution whose input contains padding is neglected in loss computation.
|
||||||
the logits of every distribution, the mean of each distribution
|
|
||||||
and the log standard deviation of each distribution. Each part's
|
|
||||||
shape is (B, T, n_mixture), where `n_mixture` means the number
|
|
||||||
of Gaussians in the mixture.
|
|
||||||
t (Tensor): shape(B, T), dtype float32, the target audio. Note that
|
|
||||||
the target's corresponding time index is one step ahead of the
|
|
||||||
output distribution. And output distribution whose input contains
|
|
||||||
padding is neglected in loss computation.
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Tensor: shape(1, ), dtype float32, the loss.
|
Tensor: shape(1, ), dtype float32, the loss.
|
||||||
|
@ -483,22 +443,16 @@ class WaveNet(nn.Layer):
|
||||||
|
|
||||||
pdf_x = p_mixture * pdf_x
|
pdf_x = p_mixture * pdf_x
|
||||||
# pdf_x: [bs, len]
|
# pdf_x: [bs, len]
|
||||||
pdf_x = paddle.reduce_sum(pdf_x, -1)
|
pdf_x = paddle.sum(pdf_x, -1)
|
||||||
per_sample_loss = -paddle.log(pdf_x + 1e-9)
|
per_sample_loss = -paddle.log(pdf_x + 1e-9)
|
||||||
|
|
||||||
loss = paddle.reduce_mean(per_sample_loss)
|
loss = paddle.mean(per_sample_loss)
|
||||||
return loss
|
return loss
|
||||||
|
|
||||||
def sample_from_mog(self, y):
|
def sample_from_mog(self, y):
|
||||||
"""Sample from the output distribution where the output distribution is
|
"""Sample from the output distribution where the output distribution is a mixture of Gaussians.
|
||||||
a mixture of Gaussians.
|
|
||||||
Args:
|
Args:
|
||||||
y (Tensor): shape(B, T, C_output), dtype float32, the parameterd of
|
y (Tensor): shape(B, T, C_output), dtype float32, the parameterd of the output distribution. It is the concatenation of 3 parts, the logits of every distribution, the mean of each distribution and the log standard deviation of each distribution. Each part's shape is (B, T, n_mixture), where `n_mixture` means the number of Gaussians in the mixture.
|
||||||
the output distribution. It is the concatenation of 3 parts, the
|
|
||||||
logits of every distribution, the mean of each distribution and the
|
|
||||||
log standard deviation of each distribution. Each part's shape is
|
|
||||||
(B, T, n_mixture), where `n_mixture` means the number of Gaussians
|
|
||||||
in the mixture.
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Tensor: shape(B, T), waveform sampled from the output distribution.
|
Tensor: shape(B, T), waveform sampled from the output distribution.
|
||||||
|
@ -529,8 +483,7 @@ class WaveNet(nn.Layer):
|
||||||
def sample(self, y):
|
def sample(self, y):
|
||||||
"""Sample from the output distribution.
|
"""Sample from the output distribution.
|
||||||
Args:
|
Args:
|
||||||
y (Tensor): shape(B, T, C_output), dtype float32, the parameterd of
|
y (Tensor): shape(B, T, C_output), dtype float32, the parameterd of the output distribution.
|
||||||
the output distribution.
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Tensor: shape(B, T), waveform sampled from the output distribution.
|
Tensor: shape(B, T), waveform sampled from the output distribution.
|
||||||
|
@ -544,12 +497,8 @@ class WaveNet(nn.Layer):
|
||||||
"""compute the loss where output distribution is a mixture of Gaussians.
|
"""compute the loss where output distribution is a mixture of Gaussians.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
y (Tensor): shape(B, T, C_output), dtype float32, the parameterd of
|
y (Tensor): shape(B, T, C_output), dtype float32, the parameterd of the output distribution.
|
||||||
the output distribution.
|
t (Tensor): shape(B, T), dtype float32, the target audio. Note that the target's corresponding time index is one step ahead of the output distribution. And output distribution whose input contains padding is neglected in loss computation.
|
||||||
t (Tensor): shape(B, T), dtype float32, the target audio. Note that
|
|
||||||
the target's corresponding time index is one step ahead of the
|
|
||||||
output distribution. And output distribution whose input contains
|
|
||||||
padding is neglected in loss computation.
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Tensor: shape(1, ), dtype float32, the loss.
|
Tensor: shape(1, ), dtype float32, the loss.
|
||||||
|
@ -560,64 +509,9 @@ class WaveNet(nn.Layer):
|
||||||
return self.compute_mog_loss(y, t)
|
return self.compute_mog_loss(y, t)
|
||||||
|
|
||||||
|
|
||||||
class UpsampleNet(nn.LayerList):
|
|
||||||
def __init__(self, upscale_factors=[16, 16]):
|
|
||||||
"""UpsamplingNet.
|
|
||||||
It consists of several layers of Conv2DTranspose. Each Conv2DTranspose
|
|
||||||
layer upsamples the time dimension by its `stride` times. And each
|
|
||||||
Conv2DTranspose's filter_size at frequency dimension is 3.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
upscale_factors (list[int], optional): time upsampling factors for
|
|
||||||
each Conv2DTranspose Layer. The `UpsampleNet` contains
|
|
||||||
len(upscale_factor) Conv2DTranspose Layers. Each upscale_factor
|
|
||||||
is used as the `stride` for the corresponding Conv2DTranspose.
|
|
||||||
Defaults to [16, 16].
|
|
||||||
Note:
|
|
||||||
np.prod(upscale_factors) should equals the `hop_length` of the stft
|
|
||||||
transformation used to extract spectrogram features from audios.
|
|
||||||
For example, 16 * 16 = 256, then the spectram extracted using a
|
|
||||||
stft transformation whose `hop_length` is 256. See `librosa.stft`
|
|
||||||
for more details.
|
|
||||||
"""
|
|
||||||
super(UpsampleNet, self).__init__()
|
|
||||||
self.upscale_factors = list(upscale_factors)
|
|
||||||
self.upscale_factor = 1
|
|
||||||
for item in upscale_factors:
|
|
||||||
self.upscale_factor *= item
|
|
||||||
|
|
||||||
for factor in self.upscale_factors:
|
|
||||||
self.append(
|
|
||||||
nn.utils.weight_norm(
|
|
||||||
nn.ConvTranspose2d(1, 1,
|
|
||||||
kernel_size=(3, 2 * factor),
|
|
||||||
stride=(1, factor),
|
|
||||||
padding=(1, factor // 2))))
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
"""Compute the upsampled condition.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
x (Tensor): shape(B, F, T), dtype float32, the condition
|
|
||||||
(mel spectrogram here.) (F means the frequency bands). In the
|
|
||||||
internal Conv2DTransposes, the frequency dimension is treated
|
|
||||||
as `height` dimension instead of `in_channels`.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Tensor: shape(B, F, T * upscale_factor), dtype float32, the
|
|
||||||
upsampled condition.
|
|
||||||
"""
|
|
||||||
x = paddle.unsqueeze(x, 1)
|
|
||||||
for sublayer in self:
|
|
||||||
x = F.leaky_relu(sublayer(x), 0.4)
|
|
||||||
x = paddle.squeeze(x, 1)
|
|
||||||
return x
|
|
||||||
|
|
||||||
|
|
||||||
class ConditionalWavenet(nn.Layer):
|
class ConditionalWavenet(nn.Layer):
|
||||||
def __init__(self, encoder, decoder):
|
def __init__(self, encoder, decoder):
|
||||||
"""Conditional Wavenet, which contains an UpsampleNet as the encoder
|
"""Conditional Wavenet, which contains an UpsampleNet as the encoder and a WaveNet as the decoder. It is an autoregressive model.
|
||||||
and a WaveNet as the decoder. It is an autoregressive model.
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
encoder (UpsampleNet): the UpsampleNet as the encoder.
|
encoder (UpsampleNet): the UpsampleNet as the encoder.
|
||||||
|
@ -628,20 +522,15 @@ class ConditionalWavenet(nn.Layer):
|
||||||
self.decoder = decoder
|
self.decoder = decoder
|
||||||
|
|
||||||
def forward(self, audio, mel, audio_start):
|
def forward(self, audio, mel, audio_start):
|
||||||
"""Compute the output distribution given the mel spectrogram and the
|
"""Compute the output distribution given the mel spectrogram and the input(for teacher force training).
|
||||||
input(for teacher force training).
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
audio (Tensor): shape(B, T_audio), dtype float32, ground truth
|
audio (Tensor): shape(B, T_audio), dtype float32, ground truth waveform, used for teacher force training.
|
||||||
waveform, used for teacher force training.
|
mel (Tensor): shape(B, F, T_mel), dtype float32, mel spectrogram. Note that it is the spectrogram for the whole utterance.
|
||||||
mel (Tensor): shape(B, F, T_mel), dtype float32, mel spectrogram.
|
audio_start (Tensor): shape(B, ), dtype: int, audio slices' start positions for each utterance.
|
||||||
Note that it is the spectrogram for the whole utterance.
|
|
||||||
audio_start (Tensor): shape(B, ), dtype: int, audio slices' start
|
|
||||||
positions for each utterance.
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Tensor: shape(B, T_audio - 1, C_putput), parameters for the output
|
Tensor: shape(B, T_audio - 1, C_putput), parameters for the output distribution.(C_output is the `output_dim` of the decoder.)
|
||||||
distribution.(C_output is the `output_dim` of the decoder.)
|
|
||||||
"""
|
"""
|
||||||
audio_length = audio.shape[1] # audio clip's length
|
audio_length = audio.shape[1] # audio clip's length
|
||||||
condition = self.encoder(mel)
|
condition = self.encoder(mel)
|
||||||
|
@ -655,12 +544,10 @@ class ConditionalWavenet(nn.Layer):
|
||||||
return y
|
return y
|
||||||
|
|
||||||
def loss(self, y, t):
|
def loss(self, y, t):
|
||||||
"""compute loss with respect to the output distribution and the targer
|
"""compute loss with respect to the output distribution and the targer audio.
|
||||||
audio.
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
y (Tensor): shape(B, T - 1, C_output), dtype float32, parameters of
|
y (Tensor): shape(B, T - 1, C_output), dtype float32, parameters of the output distribution.
|
||||||
the output distribution.
|
|
||||||
t (Tensor): shape(B, T), dtype float32, target waveform.
|
t (Tensor): shape(B, T), dtype float32, target waveform.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
|
@ -674,12 +561,10 @@ class ConditionalWavenet(nn.Layer):
|
||||||
"""Sample from the output distribution.
|
"""Sample from the output distribution.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
y (Tensor): shape(B, T, C_output), dtype float32, parameters of the
|
y (Tensor): shape(B, T, C_output), dtype float32, parameters of the output distribution.
|
||||||
output distribution.
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Tensor: shape(B, T), dtype float32, sampled waveform from the output
|
Tensor: shape(B, T), dtype float32, sampled waveform from the output distribution.
|
||||||
distribution.
|
|
||||||
"""
|
"""
|
||||||
samples = self.decoder.sample(y)
|
samples = self.decoder.sample(y)
|
||||||
return samples
|
return samples
|
||||||
|
@ -692,9 +577,7 @@ class ConditionalWavenet(nn.Layer):
|
||||||
mel (Tensor): shape(B, F, T), condition(mel spectrogram here).
|
mel (Tensor): shape(B, F, T), condition(mel spectrogram here).
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Tensor: shape(B, T * upsacle_factor), synthesized waveform.
|
Tensor: shape(B, T * upsacle_factor), synthesized waveform.(`upscale_factor` is the `upscale_factor` of the encoder `UpsampleNet`)
|
||||||
(`upscale_factor` is the `upscale_factor` of the encoder
|
|
||||||
`UpsampleNet`)
|
|
||||||
"""
|
"""
|
||||||
condition = self.encoder(mel)
|
condition = self.encoder(mel)
|
||||||
batch_size, _, time_steps = condition.shape
|
batch_size, _, time_steps = condition.shape
|
||||||
|
@ -712,6 +595,3 @@ class ConditionalWavenet(nn.Layer):
|
||||||
|
|
||||||
samples = paddle.concat(samples, -1)
|
samples = paddle.concat(samples, -1)
|
||||||
return samples
|
return samples
|
||||||
|
|
||||||
|
|
||||||
# TODO WaveNetLoss
|
|
|
@ -4,6 +4,38 @@ from paddle.nn import functional as F
|
||||||
from scipy import signal
|
from scipy import signal
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
__all__ = ["quantize", "dequantize", "STFT"]
|
||||||
|
|
||||||
|
|
||||||
|
def quantize(values, n_bands):
|
||||||
|
"""Linearlly quantize a float Tensor in [-1, 1) to an interger Tensor in [0, n_bands).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
values (Tensor): dtype: flaot32 or float64. the floating point value.
|
||||||
|
n_bands (int): the number of bands. The output integer Tensor's value is in the range [0, n_bans).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tensor: the quantized tensor, dtype: int64.
|
||||||
|
"""
|
||||||
|
quantized = paddle.cast((values + 1.0) / 2.0 * n_bands, "int64")
|
||||||
|
return quantized
|
||||||
|
|
||||||
|
|
||||||
|
def dequantize(quantized, n_bands, dtype=None):
|
||||||
|
"""Linearlly dequantize an integer Tensor into a float Tensor in the range [-1, 1).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
quantized (Tensor): dtype: int64. The quantized value in the range [0, n_bands).
|
||||||
|
n_bands (int): number of bands. The input integer Tensor's value is in the range [0, n_bans).
|
||||||
|
dtype (str, optional): data type of the output.
|
||||||
|
Returns:
|
||||||
|
Tensor: the dequantized tensor, dtype is specified by dtype.
|
||||||
|
"""
|
||||||
|
dtype = dtype or paddle.get_default_dtype()
|
||||||
|
value = (paddle.cast(quantized, dtype) + 0.5) * (2.0 / n_bands) - 1.0
|
||||||
|
return value
|
||||||
|
|
||||||
|
|
||||||
class STFT(nn.Layer):
|
class STFT(nn.Layer):
|
||||||
def __init__(self, n_fft, hop_length, win_length, window="hanning"):
|
def __init__(self, n_fft, hop_length, win_length, window="hanning"):
|
||||||
"""A module for computing differentiable stft transform. See `librosa.stft` for more details.
|
"""A module for computing differentiable stft transform. See `librosa.stft` for more details.
|
|
@ -60,6 +60,14 @@ class Conv1dCell(nn.Conv1D):
|
||||||
if self.training:
|
if self.training:
|
||||||
raise Exception("only use start_sequence in evaluation")
|
raise Exception("only use start_sequence in evaluation")
|
||||||
self._buffer = None
|
self._buffer = None
|
||||||
|
|
||||||
|
# NOTE: call self's weight norm hook expliccitly since self.weight
|
||||||
|
# is visited directly in this method without calling self.__call__
|
||||||
|
# method. If we do not trigger the weight norm hook, the weight
|
||||||
|
# may be outdated. e.g. after loading from a saved checkpoint
|
||||||
|
# see also: https://github.com/pytorch/pytorch/issues/47588
|
||||||
|
for hook in self._forward_pre_hooks.values():
|
||||||
|
hook(self, None)
|
||||||
self._reshaped_weight = paddle.reshape(self.weight,
|
self._reshaped_weight = paddle.reshape(self.weight,
|
||||||
(self._out_channels, -1))
|
(self._out_channels, -1))
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue