1. update code for waveflow's probability density estimation and sampling;
2. add WaveFlowLoss.
This commit is contained in:
parent
e07441c193
commit
af4da7dd9e
|
@ -77,9 +77,9 @@ class UpsampleNet(nn.LayerList):
|
|||
|
||||
class ResidualBlock(nn.Layer):
|
||||
"""
|
||||
ResidualBlock that merges infomation from the condition and outputs residual
|
||||
and skip. It has a conv2d layer, which has causal padding in height dimension
|
||||
and same paddign in width dimension.
|
||||
ResidualBlock, the basic unit of ResidualNet. It has a conv2d layer, which
|
||||
has causal padding in height dimension and same paddign in width dimension.
|
||||
It also has projection for the condition and output.
|
||||
"""
|
||||
def __init__(self, channels, cond_channels, kernel_size, dilations):
|
||||
super(ResidualBlock, self).__init__()
|
||||
|
@ -114,6 +114,17 @@ class ResidualBlock(nn.Layer):
|
|||
self.out_proj = nn.utils.weight_norm(out_proj)
|
||||
|
||||
def forward(self, x, condition):
|
||||
"""Compute output for a whole folded sequence.
|
||||
|
||||
Args:
|
||||
x (Tensor): shape(batch_size, channel, height, width), the input.
|
||||
condition (Tensor): shape(batch_size, condition_channel, height, width),
|
||||
the local condition.
|
||||
|
||||
Returns:
|
||||
res (Tensor): shape(batch_size, channel, height, width), the residual output.
|
||||
res (Tensor): shape(batch_size, channel, height, width), the skip output.
|
||||
"""
|
||||
x = self.conv(x)
|
||||
x += self.condition_proj(condition)
|
||||
|
||||
|
@ -125,11 +136,26 @@ class ResidualBlock(nn.Layer):
|
|||
return res, skip
|
||||
|
||||
def start_sequence(self):
|
||||
"""Prepare the layer for incremental computation of causal convolution. Reset the buffer for causal convolution.
|
||||
|
||||
Raises:
|
||||
ValueError: If not in evaluation mode.
|
||||
"""
|
||||
if self.training:
|
||||
raise ValueError("Only use start sequence at evaluation mode.")
|
||||
self._conv_buffer = None
|
||||
|
||||
def add_input(self, x_row, condition_row):
|
||||
"""Compute the output for a row and update the buffer.
|
||||
|
||||
Args:
|
||||
x_row (Tensor): shape(batch_size, channel, 1, width), a row of the input.
|
||||
condition_row (Tensor): shape(batch_size, condition_channel, 1, width), a row of the input.
|
||||
|
||||
Returns:
|
||||
res (Tensor): shape(batch_size, channel, 1, width), the residual output.
|
||||
res (Tensor): shape(batch_size, channel, 1, width), the skip output.
|
||||
"""
|
||||
if self._conv_buffer is None:
|
||||
self._init_buffer(x_row)
|
||||
self._update_buffer(x_row)
|
||||
|
@ -175,6 +201,17 @@ class ResidualNet(nn.LayerList):
|
|||
self.append(layer)
|
||||
|
||||
def forward(self, x, condition):
|
||||
"""Comput the output of given the input and the condition.
|
||||
|
||||
Args:
|
||||
x (Tensor): shape(batch_size, channel, height, width), the input.
|
||||
condition (Tensor): shape(batch_size, condition_channel, height, width),
|
||||
the local condition.
|
||||
|
||||
Returns:
|
||||
Tensor: shape(batch_size, channel, height, width), the output, which
|
||||
is an aggregation of all the skip outputs.
|
||||
"""
|
||||
skip_connections = []
|
||||
for layer in self:
|
||||
x, skip = layer(x, condition)
|
||||
|
@ -183,11 +220,21 @@ class ResidualNet(nn.LayerList):
|
|||
return out
|
||||
|
||||
def start_sequence(self):
|
||||
"""Prepare the layer for incremental computation."""
|
||||
for layer in self:
|
||||
layer.start_sequence()
|
||||
|
||||
def add_input(self, x_row, condition_row):
|
||||
# in reversed order
|
||||
"""Compute the output for a row and update the buffer.
|
||||
|
||||
Args:
|
||||
x_row (Tensor): shape(batch_size, channel, 1, width), a row of the input.
|
||||
condition_row (Tensor): shape(batch_size, condition_channel, 1, width), a row of the input.
|
||||
|
||||
Returns:
|
||||
Tensor: shape(batch_size, channel, 1, width), the output, which is
|
||||
an aggregation of all the skip outputs.
|
||||
"""
|
||||
skip_connections = []
|
||||
for layer in self:
|
||||
x_row, skip = layer.add_input(x_row, condition_row)
|
||||
|
@ -198,8 +245,11 @@ class ResidualNet(nn.LayerList):
|
|||
|
||||
class Flow(nn.Layer):
|
||||
"""
|
||||
A Layer that merges the condition and predict scale and bias given a random
|
||||
variable X.
|
||||
A bijection (Reversable layer) that transform a density of latent variables
|
||||
p(Z) into a complex data distribution p(X).
|
||||
|
||||
It's a auto regressive flow. The `forward` method implements the probability
|
||||
density estimation. The `inverse` method implements the sampling.
|
||||
"""
|
||||
dilations_dict = {
|
||||
8: [1, 1, 1, 1, 1, 1, 1, 1],
|
||||
|
@ -244,6 +294,19 @@ class Flow(nn.Layer):
|
|||
return z_out
|
||||
|
||||
def forward(self, x, condition):
|
||||
"""Probability density estimation. It is done by inversely transform a sample
|
||||
from p(X) back into a sample from p(Z).
|
||||
|
||||
Args:
|
||||
x (Tensor): shape(batch, 1, height, width), a input sample of the distribution p(X).
|
||||
condition (Tensor): shape(batch, condition_channel, height, width), the local condition.
|
||||
|
||||
Returns:
|
||||
(z, (logs, b))
|
||||
z (Tensor): shape(batch, 1, height, width), the transformed sample.
|
||||
logs (Tensor): shape(batch, 1, height - 1, width), the log scale of the inverse transformation.
|
||||
b (Tensor): shape(batch, 1, height - 1, width), the shift of the inverse transformation.
|
||||
"""
|
||||
# (B, C, H-1, W)
|
||||
logs, b = self._predict_parameters(
|
||||
x[:, :, :-1, :], condition[:, :, 1:, :])
|
||||
|
@ -270,6 +333,19 @@ class Flow(nn.Layer):
|
|||
self.resnet.start_sequence()
|
||||
|
||||
def inverse(self, z, condition):
|
||||
"""Sampling from the the distrition p(X). It is done by sample form p(Z)
|
||||
and transform the sample. It is a auto regressive transformation.
|
||||
|
||||
Args:
|
||||
z (Tensor): shape(batch, 1, height, width), a input sample of the distribution p(Z).
|
||||
condition (Tensor): shape(batch, condition_channel, height, width), the local condition.
|
||||
|
||||
Returns:
|
||||
(x, (logs, b))
|
||||
x (Tensor): shape(batch, 1, height, width), the transformed sample.
|
||||
logs (Tensor): shape(batch, 1, height - 1, width), the log scale of the inverse transformation.
|
||||
b (Tensor): shape(batch, 1, height - 1, width), the shift of the inverse transformation.
|
||||
"""
|
||||
z_0 = z[:, :, :1, :]
|
||||
x = []
|
||||
logs_list = []
|
||||
|
@ -290,11 +366,11 @@ class Flow(nn.Layer):
|
|||
x = paddle.concat(x, 2)
|
||||
logs = paddle.concat(logs_list, 2)
|
||||
b = paddle.concat(b_list, 2)
|
||||
|
||||
return x, (logs, b)
|
||||
|
||||
|
||||
class WaveFlow(nn.LayerList):
|
||||
"""An Deep Reversible layer that is composed of a stack of auto regressive flows.s"""
|
||||
def __init__(self, n_flows, n_layers, n_group, channels, mel_bands, kernel_size):
|
||||
if n_group % 2 or n_flows % 2:
|
||||
raise ValueError("number of flows and number of group must be even "
|
||||
|
@ -333,6 +409,16 @@ class WaveFlow(nn.LayerList):
|
|||
return x, condition
|
||||
|
||||
def forward(self, x, condition):
|
||||
"""Probability density estimation.
|
||||
|
||||
Args:
|
||||
x (Tensor): shape(batch_size, time_steps), the audio.
|
||||
condition (Tensor): shape(batch_size, condition channel, time_steps), the local condition.
|
||||
|
||||
Returns:
|
||||
z: (Tensor): shape(batch_size, time_steps), the transformed sample.
|
||||
log_det_jacobian: (Tensor), shape(1,), the log determinant of the jacobian of (dz/dx).
|
||||
"""
|
||||
# x: (B, T)
|
||||
# condition: (B, C, T) upsampled condition
|
||||
x, condition = self._trim(x, condition)
|
||||
|
@ -350,14 +436,28 @@ class WaveFlow(nn.LayerList):
|
|||
x = geo.shuffle_dim(x, 2, perm=self.perms[i])
|
||||
condition = geo.shuffle_dim(condition, 2, perm=self.perms[i])
|
||||
|
||||
z = paddle.squeeze(x, 1)
|
||||
return z, logs_list
|
||||
z = paddle.squeeze(x, 1) # (B, H, W)
|
||||
batch_size = z.shape[0]
|
||||
z = paddle.reshape(paddle.transpose(z, [0, 2, 1]), [batch_size, -1])
|
||||
|
||||
log_det_jacobian = paddle.sum(paddle.stack(logs_list))
|
||||
return z, log_det_jacobian
|
||||
|
||||
def start_sequence(self):
|
||||
for layer in self:
|
||||
layer.start_sequence()
|
||||
|
||||
def inverse(self, z, condition):
|
||||
"""Sampling from the the distrition p(X). It is done by sample form p(Z)
|
||||
and transform the sample. It is a auto regressive transformation.
|
||||
|
||||
Args:
|
||||
z (Tensor): shape(batch, 1, height, width), a input sample of the distribution p(Z).
|
||||
condition (Tensor): shape(batch, condition_channel, height, width), the local condition.
|
||||
|
||||
Returns:
|
||||
x: (Tensor): shape(batch_size, time_steps), the transformed sample.
|
||||
"""
|
||||
self.start_sequence()
|
||||
|
||||
z, condition = self._trim(z, condition)
|
||||
|
@ -371,8 +471,22 @@ class WaveFlow(nn.LayerList):
|
|||
z = geo.shuffle_dim(z, 2, perm=self.perms[i])
|
||||
condition = geo.shuffle_dim(condition, 2, perm=self.perms[i])
|
||||
z, (logs, b) = self[i].inverse(z, condition)
|
||||
x = paddle.squeeze(z, 1)
|
||||
|
||||
x = paddle.squeeze(z, 1) # (B, H, W)
|
||||
batch_size = x.shape[0]
|
||||
x = paddle.reshape(paddle.transpose(x, [0, 2, 1]), [batch_size, -1])
|
||||
return x
|
||||
|
||||
|
||||
# TODO(chenfeiyu): WaveFlowLoss
|
||||
class WaveFlowLoss(nn.Layer):
|
||||
def __init__(self, sigma=1.0):
|
||||
super().__init__()
|
||||
self.sigma = sigma
|
||||
self.const = 0.5 * np.log(2 * np.pi) + np.log(self.sigma)
|
||||
|
||||
def forward(self, model_output):
|
||||
z, log_det_jacobian = model_output
|
||||
|
||||
loss = paddle.sum(z * z) / (2 * self.sigma * self.sigma) - log_det_jacobian
|
||||
loss = loss / np.prod(z.shape)
|
||||
return loss + self.const
|
|
@ -114,10 +114,10 @@ class TestWaveFlow(unittest.TestCase):
|
|||
x = paddle.randn([4, 32 * 8 ])
|
||||
condition = paddle.randn([4, 7, 32 * 8])
|
||||
net = waveflow.WaveFlow(2, 8, 8, 16, 7, (3, 3))
|
||||
z, logs_list = net(x, condition)
|
||||
z, logs_det_jacobian = net(x, condition)
|
||||
|
||||
self.assertTupleEqual(z.numpy().shape, (4, 8, 32))
|
||||
self.assertTupleEqual(logs_list[0].numpy().shape, (4, 1, 7, 32))
|
||||
self.assertTupleEqual(z.numpy().shape, (4, 32 * 8))
|
||||
self.assertTupleEqual(logs_det_jacobian.numpy().shape, (1,))
|
||||
|
||||
def test_inverse(self):
|
||||
z = paddle.randn([4, 32 * 8 ])
|
||||
|
@ -128,7 +128,7 @@ class TestWaveFlow(unittest.TestCase):
|
|||
|
||||
with paddle.no_grad():
|
||||
x = net.inverse(z, condition)
|
||||
self.assertTupleEqual(x.numpy().shape, (4, 8, 32))
|
||||
self.assertTupleEqual(x.numpy().shape, (4, 32 * 8))
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue