waveflow: explicitly call forward hook before calling a method other than forward when needed.
parent af4da7dd9e
commit a9177cd6c2
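This change cleans up the rewritten WaveFlow model and its tests: it drops leftover fluid-era imports, renames Flow.first_conv/last_conv to input_proj/output_proj, makes start_sequence private and calls it inside Flow.inverse, fixes an off-by-one slice when trimming transposed-convolution artifacts in UpsampleNet, adds a ConditionalWaveFlow wrapper with a synthesize method, and, centrally, makes ResidualBlock.add_input fire self.conv's weight-norm forward pre-hook by hand.

The core issue: nn.utils.weight_norm works by registering a forward pre-hook that recomputes weight from the reparameterized weight_g/weight_v every time the layer's __call__ runs. add_input bypasses __call__ and passes self.conv.weight directly to F.conv2d, so without an explicit hook invocation the convolution would use a stale weight. A minimal standalone sketch of the pattern (channel sizes are arbitrary; _forward_pre_hooks is the internal registry that paddle.nn.Layer keeps, which the patch itself relies on):

    import paddle
    from paddle import nn
    from paddle.nn import functional as F

    conv = nn.utils.weight_norm(nn.Conv2D(4, 8, (2, 3)))
    x = paddle.randn([1, 4, 16, 16])

    # conv(x) would run the weight-norm pre-hook automatically; since we bypass
    # __call__ below, fire the registered pre-hooks by hand (mirroring the patch)
    # so conv.weight is recomputed from (weight_g, weight_v) before it is read.
    for hook in conv._forward_pre_hooks.values():
        hook(conv, x)
    y = F.conv2d(x, conv.weight, conv.bias)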
@@ -1,18 +1,13 @@
 import math
 import numpy as np
 import paddle
 from paddle import nn
 from paddle.nn import functional as F
 from paddle.nn import initializer as I
 
-from typing import Sequence
 from parakeet.modules import geometry as geo
 
-import itertools
-import numpy as np
-import paddle.fluid.dygraph as dg
-from paddle import fluid
-
-__all__ = ["WaveFlow"]
+__all__ = ["UpsampleNet", "WaveFlow", "ConditionalWaveFlow", "WaveFlowLoss"]
 
 def fold(x, n_group):
     """Fold audio or spectrogram's temporal dimension into groups.
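fold appears above only as trailing context. For orientation: it groups the trailing time axis so a length-T signal becomes (T // n_group, n_group). A sketch with that contract, assuming T is a multiple of n_group (my reconstruction for illustration, not necessarily the repository's exact body):

    import paddle

    def fold(x, n_group):
        """Reshape the last (time) axis from (..., T) to (..., T // n_group, n_group)."""
        spatial_shape = list(x.shape[:-1])
        time_steps = x.shape[-1]
        return paddle.reshape(x, spatial_shape + [time_steps // n_group, n_group])

    mel = paddle.randn([4, 80, 64])   # (batch, mel_bands, time)
    print(fold(mel, 8).shape)         # [4, 80, 8, 8]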
@@ -34,7 +29,7 @@ class UpsampleNet(nn.LayerList):
     the corresponding waveform. It consists of several conv2dtranspose layers
     which perform deconvolution on the mel and time dimensions.
     """
-    def __init__(self, upsample_factors: Sequence[int]):
+    def __init__(self, upsample_factors):
         super(UpsampleNet, self).__init__()
         for factor in upsample_factors:
             std = math.sqrt(1 / (3 * 2 * factor))
@@ -60,18 +55,18 @@ class UpsampleNet(nn.LayerList):
                 each layer. Defaults to False.
 
         Returns:
-            Tensor: shape(batch_size, input_channels, time_steps * upsample_factors).
+            Tensor: shape(batch_size, input_channels, time_steps * upsample_factor).
             If trim_conv_artifact is True, the number of output time steps is
             less than time_steps * upsample_factors.
         """
-        x = paddle.unsqueeze(x, 1)
+        x = paddle.unsqueeze(x, 1)  # (B, C, T) -> (B, 1, C, T)
         for layer in self:
             x = layer(x)
             if trim_conv_artifact:
                 time_cutoff = layer._kernel_size[1] - layer._stride[1]
-                x = x[:, :, :, -time_cutoff]
+                x = x[:, :, :, :-time_cutoff]
             x = F.leaky_relu(x, 0.4)
-        x = paddle.squeeze(x, 1)
+        x = paddle.squeeze(x, 1)  # back to (B, C, T)
         return x
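The slicing change above is a genuine bug fix, not a cleanup: indexing with a bare negative integer selects a single time step and drops the axis, while the slice form trims the trailing time_cutoff steps, which is what removing transposed-convolution edge artifacts requires. A quick demonstration:

    import paddle

    x = paddle.arange(10, dtype="float32").reshape([1, 1, 1, 10])
    time_cutoff = 2
    print(x[:, :, :, -time_cutoff].shape)    # [1, 1, 1]    -- one step, axis dropped
    print(x[:, :, :, :-time_cutoff].shape)   # [1, 1, 1, 8] -- last two steps trimmed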
@@ -125,6 +120,7 @@ class ResidualBlock(nn.Layer):
-            res (Tensor): shape(batch_size, channel, height, width), the residual output.
+            skip (Tensor): shape(batch_size, channel, height, width), the skip output.
         """
+        x_in = x
         x = self.conv(x)
         x += self.condition_proj(condition)
@@ -133,7 +129,7 @@ class ResidualBlock(nn.Layer):
 
         x = self.out_proj(x)
         res, skip = paddle.chunk(x, 2, axis=1)
-        return res, skip
+        return x_in + res, skip
 
     def start_sequence(self):
         """Prepare the layer for incremental computation of causal convolution. Reset the buffer for causal convolution.
@@ -156,11 +152,16 @@ class ResidualBlock(nn.Layer):
-            res (Tensor): shape(batch_size, channel, 1, width), the residual output.
+            skip (Tensor): shape(batch_size, channel, 1, width), the skip output.
         """
+        x_row_in = x_row
         if self._conv_buffer is None:
             self._init_buffer(x_row)
         self._update_buffer(x_row)
 
         rw = self.rw
+        # call self.conv's weight norm hook explicitly since its __call__
+        # method is not called here
+        for hook in self.conv._forward_pre_hooks.values():
+            hook(self.conv, self._conv_buffer)
         x_row = F.conv2d(
             self._conv_buffer,
             self.conv.weight,
@@ -174,7 +175,7 @@ class ResidualBlock(nn.Layer):
 
         x_row = self.out_proj(x_row)
         res, skip = paddle.chunk(x_row, 2, axis=1)
-        return res, skip
+        return x_row_in + res, skip
 
     def _init_buffer(self, input):
         batch_size, channels, _, width = input.shape
@@ -262,7 +263,7 @@ class Flow(nn.Layer):
     def __init__(self, n_layers, channels, mel_bands, kernel_size, n_group):
         super(Flow, self).__init__()
         # input projection
-        self.first_conv = nn.utils.weight_norm(
+        self.input_proj = nn.utils.weight_norm(
             nn.Conv2D(1, channels, (1, 1),
                       weight_attr=I.Uniform(-1., 1.),
                       bias_attr=I.Uniform(-1., 1.)))
@@ -272,18 +273,17 @@ class Flow(nn.Layer):
                                   self.dilations_dict[n_group])
 
         # output projection
-        self.last_conv = nn.utils.weight_norm(
-            nn.Conv2D(channels, 2, (1, 1),
-                      weight_attr=I.Constant(0.),
-                      bias_attr=I.Constant(0.)))
+        self.output_proj = nn.Conv2D(channels, 2, (1, 1),
+                                     weight_attr=I.Constant(0.),
+                                     bias_attr=I.Constant(0.))
 
         # specs
         self.n_group = n_group
 
     def _predict_parameters(self, x, condition):
-        x = self.first_conv(x)
+        x = self.input_proj(x)
         x = self.resnet(x, condition)
-        bijection_params = self.last_conv(x)
+        bijection_params = self.output_proj(x)
         logs, b = paddle.chunk(bijection_params, 2, axis=1)
         return logs, b
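Note that the output projection also stops being wrapped in weight_norm here. With I.Constant(0.) initialization the weight-norm reparameterization is degenerate (the direction v/||v|| is undefined for an all-zero weight), so a plain zero-initialized Conv2D, the usual choice for the last layer of a flow's parameter predictor, is presumably the safer construction; the commit itself does not state the motivation.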
@@ -314,14 +314,14 @@ class Flow(nn.Layer):
         return z, (logs, b)
 
     def _predict_row_parameters(self, x_row, condition_row):
-        x_row = self.first_conv(x_row)
+        x_row = self.input_proj(x_row)
         x_row = self.resnet.add_input(x_row, condition_row)
-        bijection_params = self.last_conv(x_row)
+        bijection_params = self.output_proj(x_row)
         logs, b = paddle.chunk(bijection_params, 2, axis=1)
         return logs, b
 
     def _inverse_transform_row(self, z_row, logs, b):
-        x_row = (z_row - b) / paddle.exp(logs)
+        x_row = (z_row - b) * paddle.exp(-logs)
         return x_row
 
     def _inverse_row(self, z_row, x_row, condition_row):
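The _inverse_transform_row change is algebraically neutral: the forward transform is z = x * exp(logs) + b, so x = (z - b) * exp(-logs) = (z - b) / exp(logs). The new form replaces an elementwise division with a multiplication and mirrors how logs enters the forward pass.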
@@ -329,7 +329,7 @@ class Flow(nn.Layer):
         x_next_row = self._inverse_transform_row(z_row, logs, b)
         return x_next_row, (logs, b)
 
-    def start_sequence(self):
+    def _start_sequence(self):
         self.resnet.start_sequence()
 
     def inverse(self, z, condition):
@@ -352,9 +352,9 @@ class Flow(nn.Layer):
         b_list = []
         x.append(z_0)
 
-        self.start_sequence()
+        self._start_sequence()
         for i in range(1, self.n_group):
-            x_row = x[-1]  # actually i-1
+            x_row = x[-1]  # actually i-1:i
             z_row = z[:, :, i:i+1, :]
             condition_row = condition[:, :, i:i+1, :]
@@ -368,7 +368,7 @@ class Flow(nn.Layer):
         b = paddle.concat(b_list, 2)
         return x, (logs, b)
 
 
 class WaveFlow(nn.LayerList):
     """A deep reversible layer composed of a stack of autoregressive flows."""
     def __init__(self, n_flows, n_layers, n_group, channels, mel_bands, kernel_size):
@@ -443,30 +443,24 @@ class WaveFlow(nn.LayerList):
         log_det_jacobian = paddle.sum(paddle.stack(logs_list))
         return z, log_det_jacobian
 
-    def start_sequence(self):
-        for layer in self:
-            layer.start_sequence()
-
     def inverse(self, z, condition):
         """Sampling from the distribution p(X). It is done by sampling from p(Z)
         and transforming the sample. It is an autoregressive transformation.
 
         Args:
-            z (Tensor): shape(batch, 1, height, width), an input sample of the distribution p(Z).
-            condition (Tensor): shape(batch, condition_channel, height, width), the local condition.
+            z (Tensor): shape(batch, 1, time_steps), an input sample of the distribution p(Z).
+            condition (Tensor): shape(batch, condition_channel, time_steps), the local condition.
 
         Returns:
             x (Tensor): shape(batch_size, time_steps), the transformed sample.
         """
-        self.start_sequence()
-
         z, condition = self._trim(z, condition)
+        # to (B, C, h, T//h) layout
         z = paddle.unsqueeze(paddle.transpose(fold(z, self.n_group), [0, 2, 1]), 1)
         condition = paddle.transpose(fold(condition, self.n_group), [0, 1, 3, 2])
 
         # reverse it flow by flow
-        self.n_flows
         for i in reversed(range(self.n_flows)):
             z = geo.shuffle_dim(z, 2, perm=self.perms[i])
             condition = geo.shuffle_dim(condition, 2, perm=self.perms[i])
@@ -478,9 +472,29 @@ class WaveFlow(nn.LayerList):
         return x
 
 
+class ConditionalWaveFlow(nn.LayerList):
+    def __init__(self, encoder, decoder):
+        super(ConditionalWaveFlow, self).__init__()
+        self.encoder = encoder
+        self.decoder = decoder
+
+    def forward(self, audio, mel):
+        condition = self.encoder(mel)
+        z, log_det_jacobian = self.decoder(audio, condition)
+        return z, log_det_jacobian
+
+    @paddle.no_grad()
+    def synthesize(self, mel):
+        condition = self.encoder(mel, trim_conv_artifact=True)  # (B, C, T)
+        batch_size, _, time_steps = condition.shape
+        z = paddle.randn([batch_size, time_steps], dtype=mel.dtype)
+        x = self.decoder.inverse(z, condition)
+        return x
+
+
 class WaveFlowLoss(nn.Layer):
     def __init__(self, sigma=1.0):
-        super().__init__()
+        super(WaveFlowLoss, self).__init__()
         self.sigma = sigma
         self.const = 0.5 * np.log(2 * np.pi) + np.log(self.sigma)
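The new ConditionalWaveFlow simply pairs a condition encoder with a WaveFlow decoder: forward runs the training direction (audio to latent plus log-determinant), synthesize the inference direction. Wiring it up might look like the sketch below; the import path and all hyperparameters are illustrative, not taken from a repository config:

    import paddle
    from parakeet.models import waveflow  # assumed module path

    encoder = waveflow.UpsampleNet(upsample_factors=[16, 16])   # 256x total upsampling
    decoder = waveflow.WaveFlow(n_flows=8, n_layers=8, n_group=8,
                                channels=64, mel_bands=80, kernel_size=(3, 3))
    model = waveflow.ConditionalWaveFlow(encoder, decoder)

    mel = paddle.randn([2, 80, 25])           # (B, mel_bands, frames)
    audio = paddle.randn([2, 25 * 256])       # (B, time_steps), 256 samples per frame
    z, log_det_jacobian = model(audio, mel)   # training direction
    wav = model.synthesize(mel)               # autoregressive inverse, no_grad inside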
@@ -489,4 +503,4 @@ class WaveFlowLoss(nn.Layer):
 
         loss = paddle.sum(z * z) / (2 * self.sigma * self.sigma) - log_det_jacobian
         loss = loss / np.prod(z.shape)
-        return loss + self.const
\ No newline at end of file
+        return loss + self.const
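For reference, WaveFlowLoss is the mean negative log-likelihood per element: with z = f(x), loss = sum(z^2) / (2 * sigma^2) - log|det J|, normalized by the number of elements, plus the per-element constant 0.5 * log(2 * pi) + log(sigma) precomputed as self.const. The remaining hunks update the accompanying unit tests for the start_sequence -> _start_sequence rename.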
@@ -83,7 +83,7 @@ class TestFlow(unittest.TestCase):
     def test_inverse_row(self):
         net = waveflow.Flow(8, 16, 7, (3, 3), 8)
         net.eval()
-        net.start_sequence()
+        net._start_sequence()
 
         x_row = paddle.randn([4, 1, 1, 32])  # last row
         condition_row = paddle.randn([4, 7, 1, 32])
@@ -97,7 +97,6 @@ class TestFlow(unittest.TestCase):
     def test_inverse(self):
         net = waveflow.Flow(8, 16, 7, (3, 3), 8)
         net.eval()
-        net.start_sequence()
 
         z = paddle.randn([4, 1, 8, 32])
         condition = paddle.randn([4, 7, 8, 32])
@@ -129,6 +128,3 @@ class TestWaveFlow(unittest.TestCase):
         with paddle.no_grad():
             x = net.inverse(z, condition)
         self.assertTupleEqual(x.numpy().shape, (4, 32 * 8))
-
-
-