waveflow: explicitly call forward pre-hooks before calling a method other than forward when needed.

chenfeiyu 2020-11-09 15:46:27 +08:00
parent af4da7dd9e
commit a9177cd6c2
2 changed files with 54 additions and 44 deletions
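
Background for this change: Paddle's nn.utils.weight_norm registers a forward pre-hook that recomputes a layer's weight from its weight_g / weight_v parameters. That hook only fires inside Layer.__call__, so the incremental-decoding path below, which reads self.conv.weight and calls F.conv2d directly, would otherwise compute with a stale weight. A minimal standalone sketch of the workaround (the layer and shapes here are hypothetical, not code from this repo):

    import paddle
    from paddle import nn
    from paddle.nn import functional as F

    # A weight-normalized conv whose weight is recomputed by a forward pre-hook.
    conv = nn.utils.weight_norm(nn.Conv2D(1, 4, (3, 3), padding=1))
    x = paddle.randn([2, 1, 8, 8])

    # Bypassing conv(x) skips the pre-hook, so run the registered hooks by
    # hand before using conv.weight directly.
    for hook in conv._forward_pre_hooks.values():
        hook(conv, x)
    y = F.conv2d(x, conv.weight, conv.bias, padding=1)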

View File

@@ -1,18 +1,13 @@
 import math
+import numpy as np
 import paddle
 from paddle import nn
 from paddle.nn import functional as F
 from paddle.nn import initializer as I
-from typing import Sequence
 from parakeet.modules import geometry as geo
-import itertools
-import numpy as np
-import paddle.fluid.dygraph as dg
-from paddle import fluid
-__all__ = ["WaveFlow"]
+__all__ = ["UpsampleNet", "WaveFlow", "ConditionalWaveFlow", "WaveFlowLoss"]
 
 def fold(x, n_group):
     """Fold audio or spectrogram's temporal dimension into groups.
@@ -34,7 +29,7 @@ class UpsampleNet(nn.LayerList):
     the corresponding waveform. It consists of several Conv2DTranspose layers
     which perform deconvolution on the mel and time dimensions.
     """
-    def __init__(self, upsample_factors: Sequence[int]):
+    def __init__(self, upsample_factors):
         super(UpsampleNet, self).__init__()
         for factor in upsample_factors:
             std = math.sqrt(1 / (3 * 2 * factor))
@@ -60,18 +55,18 @@
             each layer. Defaults to False.
 
         Returns:
-            Tensor: shape(batch_size, input_channels, time_steps * upsample_factors).
+            Tensor: shape(batch_size, input_channels, time_steps * upsample_factor).
             If trim_conv_artifact is True, the output has fewer time steps
             than time_steps * upsample_factors.
         """
-        x = paddle.unsqueeze(x, 1)
+        x = paddle.unsqueeze(x, 1)  # (B, C, T) -> (B, 1, C, T)
         for layer in self:
             x = layer(x)
             if trim_conv_artifact:
                 time_cutoff = layer._kernel_size[1] - layer._stride[1]
-                x = x[:, :, :, -time_cutoff]
+                x = x[:, :, :, :-time_cutoff]
             x = F.leaky_relu(x, 0.4)
-        x = paddle.squeeze(x, 1)
+        x = paddle.squeeze(x, 1)  # back to (B, C, T)
         return x
@@ -125,6 +120,7 @@ class ResidualBlock(nn.Layer):
             res (Tensor): shape(batch_size, channel, height, width), the residual output.
             skip (Tensor): shape(batch_size, channel, height, width), the skip output.
         """
+        x_in = x
         x = self.conv(x)
         x += self.condition_proj(condition)
@@ -133,7 +129,7 @@
         x = self.out_proj(x)
         res, skip = paddle.chunk(x, 2, axis=1)
-        return res, skip
+        return x_in + res, skip
 
     def start_sequence(self):
         """Prepare the layer for incremental computation of causal convolution by resetting its buffer.
@@ -156,11 +152,16 @@
             res (Tensor): shape(batch_size, channel, 1, width), the residual output.
             skip (Tensor): shape(batch_size, channel, 1, width), the skip output.
         """
+        x_row_in = x_row
         if self._conv_buffer is None:
             self._init_buffer(x_row)
         self._update_buffer(x_row)
         rw = self.rw
 
+        # Call self.conv's weight norm hook explicitly, since its __call__
+        # method is not called here.
+        for hook in self.conv._forward_pre_hooks.values():
+            hook(self.conv, self._conv_buffer)
         x_row = F.conv2d(
             self._conv_buffer,
             self.conv.weight,
@@ -174,7 +175,7 @@
         x_row = self.out_proj(x_row)
         res, skip = paddle.chunk(x_row, 2, axis=1)
-        return res, skip
+        return x_row_in + res, skip
 
     def _init_buffer(self, input):
         batch_size, channels, _, width = input.shape
@@ -262,7 +263,7 @@ class Flow(nn.Layer):
     def __init__(self, n_layers, channels, mel_bands, kernel_size, n_group):
         super(Flow, self).__init__()
         # input projection
-        self.first_conv = nn.utils.weight_norm(
+        self.input_proj = nn.utils.weight_norm(
             nn.Conv2D(1, channels, (1, 1),
                       weight_attr=I.Uniform(-1., 1.),
                       bias_attr=I.Uniform(-1., 1.)))
@@ -272,18 +273,17 @@
                                   self.dilations_dict[n_group])
 
         # output projection
-        self.last_conv = nn.utils.weight_norm(
-            nn.Conv2D(channels, 2, (1, 1),
-                      weight_attr=I.Constant(0.),
-                      bias_attr=I.Constant(0.)))
+        self.output_proj = nn.Conv2D(channels, 2, (1, 1),
+                                     weight_attr=I.Constant(0.),
+                                     bias_attr=I.Constant(0.))
 
         # specs
         self.n_group = n_group
 
     def _predict_parameters(self, x, condition):
-        x = self.first_conv(x)
+        x = self.input_proj(x)
         x = self.resnet(x, condition)
-        bijection_params = self.last_conv(x)
+        bijection_params = self.output_proj(x)
         logs, b = paddle.chunk(bijection_params, 2, axis=1)
         return logs, b
@@ -314,14 +314,14 @@
         return z, (logs, b)
 
     def _predict_row_parameters(self, x_row, condition_row):
-        x_row = self.first_conv(x_row)
+        x_row = self.input_proj(x_row)
         x_row = self.resnet.add_input(x_row, condition_row)
-        bijection_params = self.last_conv(x_row)
+        bijection_params = self.output_proj(x_row)
         logs, b = paddle.chunk(bijection_params, 2, axis=1)
         return logs, b
 
     def _inverse_transform_row(self, z_row, logs, b):
-        x_row = (z_row - b) / paddle.exp(logs)
+        x_row = (z_row - b) * paddle.exp(-logs)
         return x_row
 
     def _inverse_row(self, z_row, x_row, condition_row):
@@ -329,7 +329,7 @@
         x_next_row = self._inverse_transform_row(z_row, logs, b)
         return x_next_row, (logs, b)
 
-    def start_sequence(self):
+    def _start_sequence(self):
         self.resnet.start_sequence()
 
     def inverse(self, z, condition):
@@ -352,9 +352,9 @@
         b_list = []
         x.append(z_0)
 
-        self.start_sequence()
+        self._start_sequence()
         for i in range(1, self.n_group):
-            x_row = x[-1]  # actually row i-1
+            x_row = x[-1]  # actually rows i-1:i
             z_row = z[:, :, i:i+1, :]
             condition_row = condition[:, :, i:i+1, :]
@@ -368,7 +368,7 @@
         b = paddle.concat(b_list, 2)
         return x, (logs, b)
 
 class WaveFlow(nn.LayerList):
     """A deep reversible layer composed of a stack of autoregressive flows."""
     def __init__(self, n_flows, n_layers, n_group, channels, mel_bands, kernel_size):
@@ -443,30 +443,24 @@ class WaveFlow(nn.LayerList):
         log_det_jacobian = paddle.sum(paddle.stack(logs_list))
         return z, log_det_jacobian
 
-    def start_sequence(self):
-        for layer in self:
-            layer.start_sequence()
-
     def inverse(self, z, condition):
         """Sampling from the distribution p(X). It is done by drawing a sample
         from p(Z) and transforming it. This is an autoregressive transformation.
 
         Args:
-            z (Tensor): shape(batch, 1, height, width), an input sample of the distribution p(Z).
-            condition (Tensor): shape(batch, condition_channel, height, width), the local condition.
+            z (Tensor): shape(batch, 1, time_steps), an input sample of the distribution p(Z).
+            condition (Tensor): shape(batch, condition_channel, time_steps), the local condition.
 
         Returns:
             x (Tensor): shape(batch_size, time_steps), the transformed sample.
         """
-        self.start_sequence()
         z, condition = self._trim(z, condition)
 
         # to (B, C, h, T//h) layout
         z = paddle.unsqueeze(paddle.transpose(fold(z, self.n_group), [0, 2, 1]), 1)
         condition = paddle.transpose(fold(condition, self.n_group), [0, 1, 3, 2])
 
         # reverse it flow by flow
-        self.n_flows
         for i in reversed(range(self.n_flows)):
             z = geo.shuffle_dim(z, 2, perm=self.perms[i])
             condition = geo.shuffle_dim(condition, 2, perm=self.perms[i])
@@ -478,9 +472,29 @@
         return x
 
+class ConditionalWaveFlow(nn.LayerList):
+    def __init__(self, encoder, decoder):
+        super(ConditionalWaveFlow, self).__init__()
+        self.encoder = encoder
+        self.decoder = decoder
+
+    def forward(self, audio, mel):
+        condition = self.encoder(mel)
+        z, log_det_jacobian = self.decoder(audio, condition)
+        return z, log_det_jacobian
+
+    @paddle.no_grad()
+    def synthesize(self, mel):
+        condition = self.encoder(mel, trim_conv_artifact=True)  # (B, C, T)
+        batch_size, _, time_steps = condition.shape
+        z = paddle.randn([batch_size, time_steps], dtype=mel.dtype)
+        x = self.decoder.inverse(z, condition)
+        return x
+
 class WaveFlowLoss(nn.Layer):
     def __init__(self, sigma=1.0):
-        super().__init__()
+        super(WaveFlowLoss, self).__init__()
         self.sigma = sigma
         self.const = 0.5 * np.log(2 * np.pi) + np.log(self.sigma)
@@ -489,4 +503,4 @@ class WaveFlowLoss(nn.Layer):
         loss = paddle.sum(z * z) / (2 * self.sigma * self.sigma) - log_det_jacobian
         loss = loss / np.prod(z.shape)
         return loss + self.const
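
For orientation, a hedged sketch of how the pieces introduced in this file fit together. The import path, the hyperparameter values, and the assumption that WaveFlowLoss.forward takes (z, log_det_jacobian) are illustrative guesses, not values from this commit:

    import paddle
    from parakeet.models import waveflow  # assumed import path

    # Hypothetical wiring; real values live in the repo's configs.
    encoder = waveflow.UpsampleNet(upsample_factors=[16, 16])
    decoder = waveflow.WaveFlow(n_flows=8, n_layers=8, n_group=16,
                                channels=64, mel_bands=80, kernel_size=(3, 3))
    model = waveflow.ConditionalWaveFlow(encoder, decoder)
    criterion = waveflow.WaveFlowLoss(sigma=1.0)

    mel = paddle.randn([2, 80, 100])          # (B, mel_bands, frames)
    audio = paddle.randn([2, 100 * 16 * 16])  # (B, T), T = frames * prod(upsample_factors)
    z, log_det_jacobian = model(audio, mel)   # training direction: x -> z
    loss = criterion(z, log_det_jacobian)     # assumed forward signature

    # Inference: sample z ~ N(0, I) and run the flows in reverse.
    wav = model.synthesize(mel)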

View File

@@ -83,7 +83,7 @@ class TestFlow(unittest.TestCase):
     def test_inverse_row(self):
         net = waveflow.Flow(8, 16, 7, (3, 3), 8)
         net.eval()
-        net.start_sequence()
+        net._start_sequence()
 
         x_row = paddle.randn([4, 1, 1, 32])  # last row
         condition_row = paddle.randn([4, 7, 1, 32])
@@ -97,7 +97,6 @@ class TestFlow(unittest.TestCase):
     def test_inverse(self):
         net = waveflow.Flow(8, 16, 7, (3, 3), 8)
         net.eval()
-        net.start_sequence()
 
         z = paddle.randn([4, 1, 8, 32])
         condition = paddle.randn([4, 7, 8, 32])
@@ -129,6 +128,3 @@ class TestWaveFlow(unittest.TestCase):
         with paddle.no_grad():
             x = net.inverse(z, condition)
         self.assertTupleEqual(x.numpy().shape, (4, 32 * 8))
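
A side note on the _inverse_transform_row change in the first file: the forward transform is z = x * exp(logs) + b, so its inverse is x = (z - b) * exp(-logs), which is mathematically identical to dividing by exp(logs) but swaps the division for a multiplication. A quick sanity check (shapes chosen to match the tests above):

    import paddle

    logs = paddle.randn([4, 1, 1, 32])
    b = paddle.randn([4, 1, 1, 32])
    z = paddle.randn([4, 1, 1, 32])

    x_div = (z - b) / paddle.exp(logs)   # old form
    x_mul = (z - b) * paddle.exp(-logs)  # new form
    print(paddle.allclose(x_div, x_mul, atol=1e-6).item())  # True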