1. update code for waveflow's probability density estimation and sampling;

2. add WaveFlowLoss.
chenfeiyu 2020-11-04 23:22:45 +08:00
parent e07441c193
commit af4da7dd9e
2 changed files with 129 additions and 15 deletions


@@ -77,9 +77,9 @@ class UpsampleNet(nn.LayerList):
class ResidualBlock(nn.Layer):
"""
ResidualBlock that merges information from the condition and outputs residual
and skip. It has a conv2d layer, which has causal padding in height dimension
and same padding in width dimension.
ResidualBlock, the basic unit of ResidualNet. It has a conv2d layer, which
has causal padding in the height dimension and same padding in the width dimension.
It also has projections for the condition and the output.
"""
def __init__(self, channels, cond_channels, kernel_size, dilations):
super(ResidualBlock, self).__init__()
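
"Causal padding in height, same padding in width" means each output row depends only on the current and earlier input rows, while the width is left unchanged. A minimal sketch of such a conv2d, assuming Paddle's four-element padding convention [top, bottom, left, right]; the kernel and dilation values here are illustrative, not the committed ones:

import paddle
import paddle.nn as nn

kh, kw = 3, 3                    # kernel_size (height, width)
dh, dw = 1, 1                    # dilation (height, width)
pad_top = (kh - 1) * dh          # causal: pad above only, never below
pad_side = ((kw - 1) * dw) // 2  # same: symmetric padding on width
conv = nn.Conv2D(16, 16, (kh, kw), dilation=(dh, dw),
                 padding=[pad_top, 0, pad_side, pad_side])
y = conv(paddle.randn([4, 16, 8, 32]))  # height and width preserved: [4, 16, 8, 32]
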
@@ -114,6 +114,17 @@ class ResidualBlock(nn.Layer):
self.out_proj = nn.utils.weight_norm(out_proj)
def forward(self, x, condition):
"""Compute output for a whole folded sequence.
Args:
x (Tensor): shape(batch_size, channel, height, width), the input.
condition (Tensor): shape(batch_size, condition_channel, height, width),
the local condition.
Returns:
res (Tensor): shape(batch_size, channel, height, width), the residual output.
skip (Tensor): shape(batch_size, channel, height, width), the skip output.
"""
x = self.conv(x)
x += self.condition_proj(condition)
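
The lines elided from this hunk apply a WaveNet-style gated activation before producing the residual and skip outputs; a sketch of that standard pattern (an assumption about the omitted code, not a copy of it):

import paddle
import paddle.nn.functional as F

def gated_unit(x):
    # split channels into content and gate halves, then gate with tanh/sigmoid
    content, gate = paddle.chunk(x, 2, axis=1)
    return paddle.tanh(content) * F.sigmoid(gate)
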
@@ -125,11 +136,26 @@ class ResidualBlock(nn.Layer):
return res, skip
def start_sequence(self):
"""Prepare the layer for incremental computation of causal convolution. Reset the buffer for causal convolution.
Raises:
ValueError: If not in evaluation mode.
"""
if self.training:
raise ValueError("Only use start sequence at evaluation mode.")
self._conv_buffer = None
def add_input(self, x_row, condition_row):
"""Compute the output for a row and update the buffer.
Args:
x_row (Tensor): shape(batch_size, channel, 1, width), a row of the input.
condition_row (Tensor): shape(batch_size, condition_channel, 1, width), a row of the condition.
Returns:
res (Tensor): shape(batch_size, channel, 1, width), the residual output.
skip (Tensor): shape(batch_size, channel, 1, width), the skip output.
"""
if self._conv_buffer is None:
self._init_buffer(x_row)
self._update_buffer(x_row)
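
Together, start_sequence and add_input implement row-by-row autoregressive inference. A hedged usage sketch; the constructor arguments and shapes are illustrative:

import paddle

block = ResidualBlock(channels=16, cond_channels=7, kernel_size=(3, 3), dilations=(1, 1))
block.eval()            # start_sequence raises in training mode
block.start_sequence()  # reset the causal-convolution buffer
for _ in range(8):      # generate the rows one at a time
    x_row = paddle.randn([4, 16, 1, 32])
    cond_row = paddle.randn([4, 7, 1, 32])
    res_row, skip_row = block.add_input(x_row, cond_row)  # both (4, 16, 1, 32)
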
@@ -175,6 +201,17 @@ class ResidualNet(nn.LayerList):
self.append(layer)
def forward(self, x, condition):
"""Comput the output of given the input and the condition.
Args:
x (Tensor): shape(batch_size, channel, height, width), the input.
condition (Tensor): shape(batch_size, condition_channel, height, width),
the local condition.
Returns:
Tensor: shape(batch_size, channel, height, width), the output, which
is an aggregation of all the skip outputs.
"""
skip_connections = []
for layer in self:
x, skip = layer(x, condition)
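
The "aggregation of all the skip outputs" named in the docstring is typically a (scaled) sum over layers; a sketch of that pattern, assuming a plain sum with variance-preserving rescaling (the committed scaling, if any, is not visible in this hunk):

import math
import paddle

def aggregate_skips(skip_connections):
    # sum the per-layer skip outputs; rescale to keep the output variance stable
    out = paddle.add_n(skip_connections)
    return out * math.sqrt(1.0 / len(skip_connections))
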
@@ -183,11 +220,21 @@ class ResidualNet(nn.LayerList):
return out
def start_sequence(self):
"""Prepare the layer for incremental computation."""
for layer in self:
layer.start_sequence()
def add_input(self, x_row, condition_row):
# in reversed order
"""Compute the output for a row and update the buffer.
Args:
x_row (Tensor): shape(batch_size, channel, 1, width), a row of the input.
condition_row (Tensor): shape(batch_size, condition_channel, 1, width), a row of the condition.
Returns:
Tensor: shape(batch_size, channel, 1, width), the output, which is
an aggregation of all the skip outputs.
"""
skip_connections = []
for layer in self:
x_row, skip = layer.add_input(x_row, condition_row)
@@ -198,8 +245,11 @@ class ResidualNet(nn.LayerList):
class Flow(nn.Layer):
"""
A Layer that merges the condition and predicts scale and bias given a random
variable X.
A bijection (reversible layer) that transforms a density of latent variables
p(Z) into a complex data distribution p(X).
It is an autoregressive flow. The `forward` method implements the probability
density estimation. The `inverse` method implements the sampling.
"""
dilations_dict = {
8: [1, 1, 1, 1, 1, 1, 1, 1],
@@ -244,6 +294,19 @@ class Flow(nn.Layer):
return z_out
def forward(self, x, condition):
"""Probability density estimation. It is done by inversely transform a sample
from p(X) back into a sample from p(Z).
Args:
x (Tensor): shape(batch, 1, height, width), an input sample of the distribution p(X).
condition (Tensor): shape(batch, condition_channel, height, width), the local condition.
Returns:
(z, (logs, b))
z (Tensor): shape(batch, 1, height, width), the transformed sample.
logs (Tensor): shape(batch, 1, height - 1, width), the log scale of the inverse transformation.
b (Tensor): shape(batch, 1, height - 1, width), the shift of the inverse transformation.
"""
# (B, C, H-1, W)
logs, b = self._predict_parameters(
x[:, :, :-1, :], condition[:, :, 1:, :])
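
The transform applied with these predicted parameters is element-wise affine, so the log determinant of its Jacobian is just the sum of logs. A sketch of that step; the direction z = x * exp(logs) + b is an assumption, chosen to be consistent with how WaveFlowLoss below subtracts log_det_jacobian:

import paddle

def affine_transform(x, logs, b):
    # first row passes through unchanged; the remaining rows are transformed
    z_first = x[:, :, :1, :]
    z_rest = x[:, :, 1:, :] * paddle.exp(logs) + b  # (B, C, H-1, W)
    # log |det(dz/dx)| of this element-wise map is paddle.sum(logs)
    return paddle.concat([z_first, z_rest], axis=2)
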
@@ -270,6 +333,19 @@ class Flow(nn.Layer):
self.resnet.start_sequence()
def inverse(self, z, condition):
"""Sampling from the the distrition p(X). It is done by sample form p(Z)
and transform the sample. It is a auto regressive transformation.
Args:
z (Tensor): shape(batch, 1, height, width), an input sample of the distribution p(Z).
condition (Tensor): shape(batch, condition_channel, height, width), the local condition.
Returns:
(x, (logs, b))
x (Tensor): shape(batch, 1, height, width), the transformed sample.
logs (Tensor): shape(batch, 1, height - 1, width), the log scale of the inverse transformation.
b (Tensor): shape(batch, 1, height - 1, width), the shift of the inverse transformation.
"""
z_0 = z[:, :, :1, :]
x = []
logs_list = []
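
The loop body elided from this hunk performs the row-by-row autoregressive inversion; a sketch of the idea, where predict_row_parameters is a hypothetical stand-in for the model's incremental parameter prediction:

import paddle

def autoregressive_inverse(z, condition, predict_row_parameters):
    # z, condition: (B, C, H, W); invert z = x * exp(logs) + b one row at a time
    height = z.shape[2]
    x_rows = [z[:, :, :1, :]]  # the first row is copied unchanged
    logs_list, b_list = [], []
    for i in range(1, height):
        # predict (logs, b) for row i from the rows generated so far
        logs_row, b_row = predict_row_parameters(x_rows[-1], condition[:, :, i:i + 1, :])
        x_rows.append((z[:, :, i:i + 1, :] - b_row) * paddle.exp(-logs_row))
        logs_list.append(logs_row)
        b_list.append(b_row)
    return paddle.concat(x_rows, 2), (paddle.concat(logs_list, 2), paddle.concat(b_list, 2))
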
@@ -290,11 +366,11 @@ class Flow(nn.Layer):
x = paddle.concat(x, 2)
logs = paddle.concat(logs_list, 2)
b = paddle.concat(b_list, 2)
return x, (logs, b)
class WaveFlow(nn.LayerList):
"""An Deep Reversible layer that is composed of a stack of auto regressive flows.s"""
def __init__(self, n_flows, n_layers, n_group, channels, mel_bands, kernel_size):
if n_group % 2 or n_flows % 2:
raise ValueError("number of flows and number of group must be even "
@@ -333,6 +409,16 @@ class WaveFlow(nn.LayerList):
return x, condition
def forward(self, x, condition):
"""Probability density estimation.
Args:
x (Tensor): shape(batch_size, time_steps), the audio.
condition (Tensor): shape(batch_size, condition_channel, time_steps), the local condition.
Returns:
z (Tensor): shape(batch_size, time_steps), the transformed sample.
log_det_jacobian (Tensor): shape(1,), the log determinant of the Jacobian of (dz/dx).
"""
# x: (B, T)
# condition: (B, C, T) upsampled condition
x, condition = self._trim(x, condition)
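
After trimming, the audio is folded from (B, T) into a grid (B, 1, H, W) with H = n_group. A sketch of that folding, written as the exact inverse of the flattening shown at the end of forward (transpose then reshape); the helper name is an assumption, not the committed code:

import paddle

def fold(x, n_group):
    # (B, T) -> (B, 1, H, W), the inverse of the flattening in forward
    batch_size, time_steps = x.shape
    width = time_steps // n_group
    x = paddle.reshape(x, [batch_size, width, n_group])  # (B, W, H)
    x = paddle.transpose(x, [0, 2, 1])                   # (B, H, W)
    return paddle.unsqueeze(x, 1)                        # (B, 1, H, W)
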
@@ -350,14 +436,28 @@ class WaveFlow(nn.LayerList):
x = geo.shuffle_dim(x, 2, perm=self.perms[i])
condition = geo.shuffle_dim(condition, 2, perm=self.perms[i])
z = paddle.squeeze(x, 1)
return z, logs_list
z = paddle.squeeze(x, 1) # (B, H, W)
batch_size = z.shape[0]
z = paddle.reshape(paddle.transpose(z, [0, 2, 1]), [batch_size, -1])
log_det_jacobian = paddle.sum(paddle.stack(logs_list))
return z, log_det_jacobian
def start_sequence(self):
for layer in self:
layer.start_sequence()
def inverse(self, z, condition):
"""Sampling from the the distrition p(X). It is done by sample form p(Z)
and transform the sample. It is a auto regressive transformation.
Args:
z (Tensor): shape(batch, 1, height, width), a input sample of the distribution p(Z).
condition (Tensor): shape(batch, condition_channel, height, width), the local condition.
Returns:
x: (Tensor): shape(batch_size, time_steps), the transformed sample.
"""
self.start_sequence()
z, condition = self._trim(z, condition)
@@ -371,8 +471,22 @@ class WaveFlow(nn.LayerList):
z = geo.shuffle_dim(z, 2, perm=self.perms[i])
condition = geo.shuffle_dim(condition, 2, perm=self.perms[i])
z, (logs, b) = self[i].inverse(z, condition)
x = paddle.squeeze(z, 1)
x = paddle.squeeze(z, 1) # (B, H, W)
batch_size = x.shape[0]
x = paddle.reshape(paddle.transpose(x, [0, 2, 1]), [batch_size, -1])
return x
# TODO(chenfeiyu): WaveFlowLoss
class WaveFlowLoss(nn.Layer):
"""Negative log likelihood of an isotropic Gaussian with standard deviation sigma, for the output of WaveFlow, averaged over the elements of z."""
def __init__(self, sigma=1.0):
super().__init__()
self.sigma = sigma
self.const = 0.5 * np.log(2 * np.pi) + np.log(self.sigma)
def forward(self, model_output):
z, log_det_jacobian = model_output
loss = paddle.sum(z * z) / (2 * self.sigma * self.sigma) - log_det_jacobian
loss = loss / np.prod(z.shape)
return loss + self.const
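
This is the Gaussian negative log likelihood under the change of variables, -log p(x) = sum(z^2) / (2 * sigma^2) - log|det(dz/dx)| + D * (0.5 * log(2 * pi) + log(sigma)), averaged over the D elements of z. A quick sanity check, with shapes mirroring the unit test below:

import paddle

# with z ~ N(0, 1) and a zero Jacobian term, the per-element loss should be
# near 0.5 + 0.5 * log(2 * pi), i.e. about 1.419
criterion = WaveFlowLoss(sigma=1.0)
z = paddle.randn([4, 32 * 8])
log_det_jacobian = paddle.to_tensor([0.0])
loss = criterion((z, log_det_jacobian))
print(float(loss))
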


@@ -114,10 +114,10 @@ class TestWaveFlow(unittest.TestCase):
x = paddle.randn([4, 32 * 8 ])
condition = paddle.randn([4, 7, 32 * 8])
net = waveflow.WaveFlow(2, 8, 8, 16, 7, (3, 3))
z, logs_list = net(x, condition)
z, log_det_jacobian = net(x, condition)
self.assertTupleEqual(z.numpy().shape, (4, 8, 32))
self.assertTupleEqual(logs_list[0].numpy().shape, (4, 1, 7, 32))
self.assertTupleEqual(z.numpy().shape, (4, 32 * 8))
self.assertTupleEqual(log_det_jacobian.numpy().shape, (1,))
def test_inverse(self):
z = paddle.randn([4, 32 * 8 ])
@@ -128,7 +128,7 @@ class TestWaveFlow(unittest.TestCase):
with paddle.no_grad():
x = net.inverse(z, condition)
self.assertTupleEqual(x.numpy().shape, (4, 8, 32))
self.assertTupleEqual(x.numpy().shape, (4, 32 * 8))