waveflow: explicitly call forward hook before calling a method other than forward when needed.
commit a9177cd6c2 (parent af4da7dd9e)
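Paddle's `nn.utils.weight_norm` works by registering a forward pre-hook on the wrapped layer: each time `Layer.__call__` runs, the hook first recomputes `layer.weight` from the reparametrized `weight_g`/`weight_v` pair. The incremental-inference path changed below (`ResidualBlock.add_input`) bypasses `__call__` and passes `self.conv.weight` directly to `F.conv2d`, so the hook never fires and the weight would go stale. The fix is to invoke the registered pre-hooks by hand before reading the weight. A minimal sketch of the pattern, with an illustrative layer and shapes rather than the project's code:

    import paddle
    from paddle import nn
    import paddle.nn.functional as F

    # illustrative weight-normalized conv, not the one from the model
    conv = nn.utils.weight_norm(nn.Conv2D(4, 8, (3, 3), padding=1))
    x = paddle.randn([1, 4, 16, 16])

    # weight_norm's forward pre-hook recomputes conv.weight from
    # weight_g / weight_v, but it only fires inside Layer.__call__;
    # when bypassing __call__, run the hooks by hand first.
    for hook in conv._forward_pre_hooks.values():
        hook(conv, x)

    # conv.weight is now up to date and safe to use with the functional API
    y = F.conv2d(x, conv.weight, conv.bias, padding=1)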
@@ -1,18 +1,13 @@
 import math
+import numpy as np
 import paddle
 from paddle import nn
 from paddle.nn import functional as F
 from paddle.nn import initializer as I
 
-from typing import Sequence
 from parakeet.modules import geometry as geo
 
-import itertools
-import numpy as np
-import paddle.fluid.dygraph as dg
-from paddle import fluid
+__all__ = ["UpsampleNet", "WaveFlow", "ConditionalWaveFlow", "WaveFlowLoss"]
 
-__all__ = ["WaveFlow"]
 
 def fold(x, n_group):
     """Fold audio or spectrogram's temporal dimension into groups.
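The `fold` helper above is only shown as context; its body is unchanged by this commit. From the way `WaveFlow.inverse` calls it below (`fold(z, self.n_group)` followed by a transpose and unsqueeze to reach the (B, 1, h, T//h) layout), its assumed behavior is a plain reshape of the trailing time axis. A hypothetical stand-in, not the project's implementation:

    import paddle

    def fold_sketch(x, n_group):
        # assumed behavior: split the trailing time axis of (..., T)
        # into groups, giving (..., T // n_group, n_group); inverse()
        # trims T to a multiple of n_group beforehand via _trim()
        *leading, time_steps = x.shape
        return paddle.reshape(x, leading + [time_steps // n_group, n_group])

    z = paddle.randn([4, 48])   # (B, T), T a multiple of n_group = 8
    folded = fold_sketch(z, 8)  # (B, 6, 8)
    # the layout step from WaveFlow.inverse:
    z4d = paddle.unsqueeze(paddle.transpose(folded, [0, 2, 1]), 1)
    print(z4d.shape)            # [4, 1, 8, 6]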
@@ -34,7 +29,7 @@ class UpsampleNet(nn.LayerList):
     the corresponding waveform. It consists of several Conv2DTranspose layers
     which perform deconvolution along the mel and time dimensions.
     """
-    def __init__(self, upsample_factors: Sequence[int]):
+    def __init__(self, upsample_factors):
         super(UpsampleNet, self).__init__()
         for factor in upsample_factors:
             std = math.sqrt(1 / (3 * 2 * factor))
@@ -60,18 +55,18 @@ class UpsampleNet(nn.LayerList):
                 each layer. Defaults to False.
 
         Returns:
-            Tensor: shape(batch_size, input_channels, time_steps * upsample_factors).
+            Tensor: shape(batch_size, input_channels, time_steps * upsample_factor).
                 If trim_conv_artifact is True, the number of output time steps
                 is less than time_steps * upsample_factors.
         """
-        x = paddle.unsqueeze(x, 1)
+        x = paddle.unsqueeze(x, 1)  # (B, C, T) -> (B, 1, C, T)
         for layer in self:
             x = layer(x)
             if trim_conv_artifact:
                 time_cutoff = layer._kernel_size[1] - layer._stride[1]
-                x = x[:, :, :, -time_cutoff]
+                x = x[:, :, :, :-time_cutoff]
             x = F.leaky_relu(x, 0.4)
-        x = paddle.squeeze(x, 1)
+        x = paddle.squeeze(x, 1)  # back to (B, C, T)
         return x
 
 
@@ -125,6 +120,7 @@ class ResidualBlock(nn.Layer):
             res (Tensor): shape(batch_size, channel, height, width), the residual output.
             skip (Tensor): shape(batch_size, channel, height, width), the skip output.
         """
+        x_in = x
         x = self.conv(x)
         x += self.condition_proj(condition)
 
@@ -133,7 +129,7 @@ class ResidualBlock(nn.Layer):
 
         x = self.out_proj(x)
         res, skip = paddle.chunk(x, 2, axis=1)
-        return res, skip
+        return x_in + res, skip
 
     def start_sequence(self):
         """Prepare the layer for incremental computation of causal convolution. Reset the buffer for causal convolution.
@@ -156,11 +152,16 @@ class ResidualBlock(nn.Layer):
             res (Tensor): shape(batch_size, channel, 1, width), the residual output.
             skip (Tensor): shape(batch_size, channel, 1, width), the skip output.
         """
+        x_row_in = x_row
         if self._conv_buffer is None:
             self._init_buffer(x_row)
         self._update_buffer(x_row)
 
         rw = self.rw
+        # call self.conv's weight norm hook explicitly since its __call__
+        # method is not called here
+        for hook in self.conv._forward_pre_hooks.values():
+            hook(self.conv, self._conv_buffer)
         x_row = F.conv2d(
             self._conv_buffer,
             self.conv.weight,
@@ -174,7 +175,7 @@ class ResidualBlock(nn.Layer):
 
         x_row = self.out_proj(x_row)
         res, skip = paddle.chunk(x_row, 2, axis=1)
-        return res, skip
+        return x_row_in + res, skip
 
     def _init_buffer(self, input):
         batch_size, channels, _, width = input.shape
@@ -262,7 +263,7 @@ class Flow(nn.Layer):
     def __init__(self, n_layers, channels, mel_bands, kernel_size, n_group):
         super(Flow, self).__init__()
         # input projection
-        self.first_conv = nn.utils.weight_norm(
+        self.input_proj = nn.utils.weight_norm(
             nn.Conv2D(1, channels, (1, 1),
                       weight_attr=I.Uniform(-1., 1.),
                       bias_attr=I.Uniform(-1., 1.)))
@@ -272,18 +273,17 @@ class Flow(nn.Layer):
             self.dilations_dict[n_group])
 
         # output projection
-        self.last_conv = nn.utils.weight_norm(
-            nn.Conv2D(channels, 2, (1, 1),
-                      weight_attr=I.Constant(0.),
-                      bias_attr=I.Constant(0.)))
+        self.output_proj = nn.Conv2D(channels, 2, (1, 1),
+                                     weight_attr=I.Constant(0.),
+                                     bias_attr=I.Constant(0.))
 
         # specs
         self.n_group = n_group
 
     def _predict_parameters(self, x, condition):
-        x = self.first_conv(x)
+        x = self.input_proj(x)
         x = self.resnet(x, condition)
-        bijection_params = self.last_conv(x)
+        bijection_params = self.output_proj(x)
         logs, b = paddle.chunk(bijection_params, 2, axis=1)
         return logs, b
 
@@ -314,14 +314,14 @@ class Flow(nn.Layer):
         return z, (logs, b)
 
     def _predict_row_parameters(self, x_row, condition_row):
-        x_row = self.first_conv(x_row)
+        x_row = self.input_proj(x_row)
         x_row = self.resnet.add_input(x_row, condition_row)
-        bijection_params = self.last_conv(x_row)
+        bijection_params = self.output_proj(x_row)
         logs, b = paddle.chunk(bijection_params, 2, axis=1)
         return logs, b
 
     def _inverse_transform_row(self, z_row, logs, b):
-        x_row = (z_row - b) / paddle.exp(logs)
+        x_row = (z_row - b) * paddle.exp(-logs)
         return x_row
 
     def _inverse_row(self, z_row, x_row, condition_row):
@@ -329,7 +329,7 @@ class Flow(nn.Layer):
         x_next_row = self._inverse_transform_row(z_row, logs, b)
         return x_next_row, (logs, b)
 
-    def start_sequence(self):
+    def _start_sequence(self):
         self.resnet.start_sequence()
 
     def inverse(self, z, condition):
@@ -352,9 +352,9 @@ class Flow(nn.Layer):
         b_list = []
         x.append(z_0)
 
-        self.start_sequence()
+        self._start_sequence()
         for i in range(1, self.n_group):
-            x_row = x[-1]  # actually i-1
+            x_row = x[-1]  # actually i-1:i
             z_row = z[:, :, i:i+1, :]
             condition_row = condition[:, :, i:i+1, :]
 
@@ -368,7 +368,7 @@ class Flow(nn.Layer):
         b = paddle.concat(b_list, 2)
         return x, (logs, b)
 
 
 class WaveFlow(nn.LayerList):
     """A deep reversible layer composed of a stack of autoregressive flows."""
     def __init__(self, n_flows, n_layers, n_group, channels, mel_bands, kernel_size):
@@ -443,30 +443,24 @@ class WaveFlow(nn.LayerList):
         log_det_jacobian = paddle.sum(paddle.stack(logs_list))
         return z, log_det_jacobian
 
-    def start_sequence(self):
-        for layer in self:
-            layer.start_sequence()
-
     def inverse(self, z, condition):
         """Sampling from the distribution p(X). It is done by sampling from p(Z)
         and transforming the sample. It is an autoregressive transformation.
 
         Args:
-            z (Tensor): shape(batch, 1, height, width), an input sample of the distribution p(Z).
-            condition (Tensor): shape(batch, condition_channel, height, width), the local condition.
+            z (Tensor): shape(batch, 1, time_steps), an input sample of the distribution p(Z).
+            condition (Tensor): shape(batch, condition_channel, time_steps), the local condition.
 
         Returns:
             x (Tensor): shape(batch_size, time_steps), the transformed sample.
         """
-        self.start_sequence()
-
         z, condition = self._trim(z, condition)
         # to (B, C, h, T//h) layout
         z = paddle.unsqueeze(paddle.transpose(fold(z, self.n_group), [0, 2, 1]), 1)
         condition = paddle.transpose(fold(condition, self.n_group), [0, 1, 3, 2])
 
         # reverse it flow by flow
-        self.n_flows
         for i in reversed(range(self.n_flows)):
             z = geo.shuffle_dim(z, 2, perm=self.perms[i])
             condition = geo.shuffle_dim(condition, 2, perm=self.perms[i])
@@ -478,9 +472,29 @@ class WaveFlow(nn.LayerList):
         return x
 
 
+class ConditionalWaveFlow(nn.LayerList):
+    def __init__(self, encoder, decoder):
+        super(ConditionalWaveFlow, self).__init__()
+        self.encoder = encoder
+        self.decoder = decoder
+
+    def forward(self, audio, mel):
+        condition = self.encoder(mel)
+        z, log_det_jacobian = self.decoder(audio, condition)
+        return z, log_det_jacobian
+
+    @paddle.no_grad()
+    def synthesize(self, mel):
+        condition = self.encoder(mel, trim_conv_artifact=True)  # (B, C, T)
+        batch_size, _, time_steps = condition.shape
+        z = paddle.randn([batch_size, time_steps], dtype=mel.dtype)
+        x = self.decoder.inverse(z, condition)
+        return x
+
+
 class WaveFlowLoss(nn.Layer):
     def __init__(self, sigma=1.0):
-        super().__init__()
+        super(WaveFlowLoss, self).__init__()
         self.sigma = sigma
         self.const = 0.5 * np.log(2 * np.pi) + np.log(self.sigma)
 
@@ -489,4 +503,4 @@ class WaveFlowLoss(nn.Layer):
 
         loss = paddle.sum(z * z) / (2 * self.sigma * self.sigma) - log_det_jacobian
         loss = loss / np.prod(z.shape)
         return loss + self.const
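For reference, the objective `WaveFlowLoss` computes above is the per-element negative log-likelihood of z under N(0, sigma^2) minus the averaged log-determinant of the flow's Jacobian; in the notation of the code:

    \mathcal{L} = \frac{1}{D}\left(\frac{\lVert z \rVert^2}{2\sigma^2} - \log\lvert\det J\rvert\right) + \tfrac{1}{2}\log(2\pi) + \log\sigma, \qquad D = \mathrm{numel}(z)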
@@ -83,7 +83,7 @@ class TestFlow(unittest.TestCase):
     def test_inverse_row(self):
         net = waveflow.Flow(8, 16, 7, (3, 3), 8)
         net.eval()
-        net.start_sequence()
+        net._start_sequence()
 
         x_row = paddle.randn([4, 1, 1, 32])  # last row
         condition_row = paddle.randn([4, 7, 1, 32])
@@ -97,7 +97,6 @@ class TestFlow(unittest.TestCase):
     def test_inverse(self):
         net = waveflow.Flow(8, 16, 7, (3, 3), 8)
         net.eval()
-        net.start_sequence()
 
         z = paddle.randn([4, 1, 8, 32])
         condition = paddle.randn([4, 7, 8, 32])
@@ -129,6 +128,3 @@ class TestWaveFlow(unittest.TestCase):
         with paddle.no_grad():
             x = net.inverse(z, condition)
         self.assertTupleEqual(x.numpy().shape, (4, 32 * 8))
-
-
-
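Taken together, the new ConditionalWaveFlow wrapper pairs an UpsampleNet encoder with a WaveFlow decoder: forward() produces (z, log|det J|) for WaveFlowLoss during training, and synthesize() samples z ~ N(0, I) and runs the flows in reverse. A hedged usage sketch; the hyperparameters, the `parakeet.models.waveflow` import path, and the `WaveFlowLoss.forward(z, log_det_jacobian)` signature are assumptions, only the call pattern comes from this commit:

    import paddle
    from parakeet.models import waveflow  # import path assumed

    # hypothetical hyperparameters, chosen only for illustration
    encoder = waveflow.UpsampleNet(upsample_factors=[16, 16])  # 256x upsampling
    decoder = waveflow.WaveFlow(n_flows=8, n_layers=8, n_group=16,
                                channels=64, mel_bands=80, kernel_size=(3, 3))
    model = waveflow.ConditionalWaveFlow(encoder, decoder)
    criterion = waveflow.WaveFlowLoss(sigma=1.0)

    mel = paddle.randn([2, 80, 40])      # (B, mel_bands, frames)
    audio = paddle.randn([2, 40 * 256])  # (B, time_steps) = frames * 256

    # training step: forward() returns z and log|det J| for the loss
    z, log_det_jacobian = model(audio, mel)
    loss = criterion(z, log_det_jacobian)

    # inference: synthesize() is already decorated with @paddle.no_grad();
    # it samples z ~ N(0, I) and inverts the flows conditioned on mel
    wav = model.synthesize(mel)          # (B, time_steps)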