diff --git a/parakeet/models/waveflow.py b/parakeet/models/waveflow.py
index c921502..1274c47 100644
--- a/parakeet/models/waveflow.py
+++ b/parakeet/models/waveflow.py
@@ -1,18 +1,13 @@
 import math
+import numpy as np
 import paddle
 from paddle import nn
 from paddle.nn import functional as F
 from paddle.nn import initializer as I
-from typing import Sequence
 
 from parakeet.modules import geometry as geo
-import itertools
-import numpy as np
-import paddle.fluid.dygraph as dg
-from paddle import fluid
-
-__all__ = ["WaveFlow"]
+__all__ = ["UpsampleNet", "WaveFlow", "ConditionalWaveFlow", "WaveFlowLoss"]
 
 
 def fold(x, n_group):
     """Fold audio or spectrogram's temporal dimension into groups.
@@ -34,7 +29,7 @@ class UpsampleNet(nn.LayerList):
     the corresponding waveform. It consists of several conv2dtranspose layers
     which perform deconvolution over the mel and time dimensions.
     """
-    def __init__(self, upsample_factors: Sequence[int]):
+    def __init__(self, upsample_factors):
         super(UpsampleNet, self).__init__()
         for factor in upsample_factors:
             std = math.sqrt(1 / (3 * 2 * factor))
@@ -60,18 +55,18 @@ class UpsampleNet(nn.LayerList):
 
             each layer. Defaults to False.
 
         Returns:
-            Tensor: shape(batch_size, input_channels, time_steps * upsample_factors).
+            Tensor: shape(batch_size, input_channels, time_steps * upsample_factor).
             If trim_conv_artifact is True, the output time steps are less than
             time_steps * upsample_factors.
         """
-        x = paddle.unsqueeze(x, 1)
+        x = paddle.unsqueeze(x, 1)  # (B, C, T) -> (B, 1, C, T)
 
         for layer in self:
             x = layer(x)
             if trim_conv_artifact:
                 time_cutoff = layer._kernel_size[1] - layer._stride[1]
-                x = x[:, :, :, -time_cutoff]
+                x = x[:, :, :, :-time_cutoff]
             x = F.leaky_relu(x, 0.4)
-        x = paddle.squeeze(x, 1)
+        x = paddle.squeeze(x, 1)  # back to (B, C, T)
         return x
@@ -125,6 +120,7 @@ class ResidualBlock(nn.Layer):
 
             res (Tensor): shape(batch_size, channel, height, width), the residual output.
             skip (Tensor): shape(batch_size, channel, height, width), the skip output.
         """
+        x_in = x
         x = self.conv(x)
         x += self.condition_proj(condition)
@@ -133,7 +129,7 @@ class ResidualBlock(nn.Layer):
         x = self.out_proj(x)
         res, skip = paddle.chunk(x, 2, axis=1)
-        return res, skip
+        return x_in + res, skip
 
     def start_sequence(self):
         """Prepare the layer for incremental computation of causal convolution.
         Reset the buffer for causal convolution.
@@ -156,11 +152,16 @@ class ResidualBlock(nn.Layer):
             res (Tensor): shape(batch_size, channel, 1, width), the residual output.
             skip (Tensor): shape(batch_size, channel, 1, width), the skip output.
""" + x_row_in = x_row if self._conv_buffer is None: self._init_buffer(x_row) self._update_buffer(x_row) rw = self.rw + # call self.conv's weight norm hook expliccitly since its __call__ + # method is not called here + for hook in self.conv._forward_pre_hooks.values(): + hook(self.conv, self._conv_buffer) x_row = F.conv2d( self._conv_buffer, self.conv.weight, @@ -174,7 +175,7 @@ class ResidualBlock(nn.Layer): x_row = self.out_proj(x_row) res, skip = paddle.chunk(x_row, 2, axis=1) - return res, skip + return x_row_in + res, skip def _init_buffer(self, input): batch_size, channels, _, width = input.shape @@ -262,7 +263,7 @@ class Flow(nn.Layer): def __init__(self, n_layers, channels, mel_bands, kernel_size, n_group): super(Flow, self).__init__() # input projection - self.first_conv = nn.utils.weight_norm( + self.input_proj = nn.utils.weight_norm( nn.Conv2D(1, channels, (1, 1), weight_attr=I.Uniform(-1., 1.), bias_attr=I.Uniform(-1., 1.))) @@ -272,18 +273,17 @@ class Flow(nn.Layer): self.dilations_dict[n_group]) # output projection - self.last_conv = nn.utils.weight_norm( - nn.Conv2D(channels, 2, (1, 1), - weight_attr=I.Constant(0.), - bias_attr=I.Constant(0.))) + self.output_proj = nn.Conv2D(channels, 2, (1, 1), + weight_attr=I.Constant(0.), + bias_attr=I.Constant(0.)) # specs self.n_group = n_group def _predict_parameters(self, x, condition): - x = self.first_conv(x) + x = self.input_proj(x) x = self.resnet(x, condition) - bijection_params = self.last_conv(x) + bijection_params = self.output_proj(x) logs, b = paddle.chunk(bijection_params, 2, axis=1) return logs, b @@ -314,14 +314,14 @@ class Flow(nn.Layer): return z, (logs, b) def _predict_row_parameters(self, x_row, condition_row): - x_row = self.first_conv(x_row) + x_row = self.input_proj(x_row) x_row = self.resnet.add_input(x_row, condition_row) - bijection_params = self.last_conv(x_row) + bijection_params = self.output_proj(x_row) logs, b = paddle.chunk(bijection_params, 2, axis=1) return logs, b def _inverse_transform_row(self, z_row, logs, b): - x_row = (z_row - b) / paddle.exp(logs) + x_row = (z_row - b) * paddle.exp(-logs) return x_row def _inverse_row(self, z_row, x_row, condition_row): @@ -329,7 +329,7 @@ class Flow(nn.Layer): x_next_row = self._inverse_transform_row(z_row, logs, b) return x_next_row, (logs, b) - def start_sequence(self): + def _start_sequence(self): self.resnet.start_sequence() def inverse(self, z, condition): @@ -352,9 +352,9 @@ class Flow(nn.Layer): b_list = [] x.append(z_0) - self.start_sequence() + self._start_sequence() for i in range(1, self.n_group): - x_row = x[-1] # actuallt i-1 + x_row = x[-1] # actuallt i-1:i z_row = z[:, :, i:i+1, :] condition_row = condition[:, :, i:i+1, :] @@ -368,7 +368,7 @@ class Flow(nn.Layer): b = paddle.concat(b_list, 2) return x, (logs, b) - + class WaveFlow(nn.LayerList): """An Deep Reversible layer that is composed of a stack of auto regressive flows.s""" def __init__(self, n_flows, n_layers, n_group, channels, mel_bands, kernel_size): @@ -443,30 +443,24 @@ class WaveFlow(nn.LayerList): log_det_jacobian = paddle.sum(paddle.stack(logs_list)) return z, log_det_jacobian - def start_sequence(self): - for layer in self: - layer.start_sequence() - def inverse(self, z, condition): """Sampling from the the distrition p(X). It is done by sample form p(Z) and transform the sample. It is a auto regressive transformation. Args: - z (Tensor): shape(batch, 1, height, width), a input sample of the distribution p(Z). 
-            condition (Tensor): shape(batch, condition_channel, height, width), the local condition.
+            z (Tensor): shape(batch, 1, time_steps), an input sample of the distribution p(Z).
+            condition (Tensor): shape(batch, condition_channel, time_steps), the local condition.
 
         Returns:
             x (Tensor): shape(batch_size, time_steps), the transformed sample.
         """
-        self.start_sequence()
         z, condition = self._trim(z, condition)
 
         # to (B, C, h, T//h) layout
         z = paddle.unsqueeze(paddle.transpose(fold(z, self.n_group), [0, 2, 1]), 1)
         condition = paddle.transpose(fold(condition, self.n_group), [0, 1, 3, 2])
-        
+
         # reverse it flow by flow
-        self.n_flows
         for i in reversed(range(self.n_flows)):
             z = geo.shuffle_dim(z, 2, perm=self.perms[i])
             condition = geo.shuffle_dim(condition, 2, perm=self.perms[i])
@@ -478,9 +472,29 @@ class WaveFlow(nn.LayerList):
         return x
 
 
+class ConditionalWaveFlow(nn.LayerList):
+    def __init__(self, encoder, decoder):
+        super(ConditionalWaveFlow, self).__init__()
+        self.encoder = encoder
+        self.decoder = decoder
+
+    def forward(self, audio, mel):
+        condition = self.encoder(mel)
+        z, log_det_jacobian = self.decoder(audio, condition)
+        return z, log_det_jacobian
+
+    @paddle.no_grad()
+    def synthesize(self, mel):
+        condition = self.encoder(mel, trim_conv_artifact=True)  # (B, C, T)
+        batch_size, _, time_steps = condition.shape
+        z = paddle.randn([batch_size, time_steps], dtype=mel.dtype)
+        x = self.decoder.inverse(z, condition)
+        return x
+
+
 class WaveFlowLoss(nn.Layer):
     def __init__(self, sigma=1.0):
-        super().__init__()
+        super(WaveFlowLoss, self).__init__()
         self.sigma = sigma
         self.const = 0.5 * np.log(2 * np.pi) + np.log(self.sigma)
 
@@ -489,4 +503,4 @@ class WaveFlowLoss(nn.Layer):
 
         loss = paddle.sum(z * z) / (2 * self.sigma * self.sigma) - log_det_jacobian
         loss = loss / np.prod(z.shape)
-        return loss + self.const
\ No newline at end of file
+        return loss + self.const
diff --git a/tests/test_waveflow.py b/tests/test_waveflow.py
index 849c7d8..15bbc44 100644
--- a/tests/test_waveflow.py
+++ b/tests/test_waveflow.py
@@ -83,7 +83,7 @@ class TestFlow(unittest.TestCase):
     def test_inverse_row(self):
         net = waveflow.Flow(8, 16, 7, (3, 3), 8)
         net.eval()
-        net.start_sequence()
+        net._start_sequence()
 
         x_row = paddle.randn([4, 1, 1, 32])  # last row
         condition_row = paddle.randn([4, 7, 1, 32])
@@ -97,7 +97,6 @@ class TestFlow(unittest.TestCase):
     def test_inverse(self):
         net = waveflow.Flow(8, 16, 7, (3, 3), 8)
         net.eval()
-        net.start_sequence()
 
         z = paddle.randn([4, 1, 8, 32])
         condition = paddle.randn([4, 7, 8, 32])
@@ -129,6 +128,3 @@ class TestWaveFlow(unittest.TestCase):
         with paddle.no_grad():
             x = net.inverse(z, condition)
         self.assertTupleEqual(x.numpy().shape, (4, 32 * 8))
-
-
-
\ No newline at end of file
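
---
Reviewer note (not part of the patch): a minimal sketch of how the refactored
pieces fit together, inferred only from the signatures visible above. The layer
sizes and upsample factors below are illustrative assumptions, not the
project's defaults, and WaveFlowLoss is assumed to be called as
criterion(z, log_det_jacobian), matching the body shown in the final hunk of
waveflow.py.

    import paddle
    from parakeet.models import waveflow

    # illustrative hyperparameters (assumptions, not repo defaults)
    encoder = waveflow.UpsampleNet(upsample_factors=[16, 16])   # 256x upsampling
    decoder = waveflow.WaveFlow(n_flows=8, n_layers=8, n_group=16,
                                channels=64, mel_bands=80, kernel_size=(3, 3))
    model = waveflow.ConditionalWaveFlow(encoder, decoder)
    criterion = waveflow.WaveFlowLoss(sigma=1.0)

    # training step: the forward pass maps (audio, mel) to latent z and log|det J|
    mel = paddle.randn([4, 80, 16])       # (B, mel_bands, frames)
    audio = paddle.randn([4, 16 * 256])   # (B, T), aligned with the upsampled mel
    z, log_det_jacobian = model(audio, mel)
    loss = criterion(z, log_det_jacobian)

    # synthesis: draws z ~ N(0, I) and runs the inverse flows; synthesize is
    # already wrapped in paddle.no_grad() by its decorator
    x = model.synthesize(mel)             # (B, time_steps)

One design consequence worth noting: since Flow.inverse now calls
_start_sequence() itself, callers no longer reset the causal-convolution
buffers by hand, which is why the explicit start_sequence() calls disappear
from WaveFlow.inverse and from test_inverse in the test diff.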