1. update code for waveflow's probability density estimation and sampling;

2. add WaveFlowLoss.
chenfeiyu 2020-11-04 23:22:45 +08:00
parent e07441c193
commit af4da7dd9e
2 changed files with 129 additions and 15 deletions

View File

@@ -77,9 +77,9 @@ class UpsampleNet(nn.LayerList):
class ResidualBlock(nn.Layer):
    """
    ResidualBlock, the basic unit of ResidualNet. It has a conv2d layer, which
    has causal padding in the height dimension and same padding in the width
    dimension. It also has projections for the condition and the output.
    """
    def __init__(self, channels, cond_channels, kernel_size, dilations):
        super(ResidualBlock, self).__init__()
@@ -114,6 +114,17 @@ class ResidualBlock(nn.Layer):
        self.out_proj = nn.utils.weight_norm(out_proj)

    def forward(self, x, condition):
        """Compute output for a whole folded sequence.

        Args:
            x (Tensor): shape(batch_size, channel, height, width), the input.
            condition (Tensor): shape(batch_size, condition_channel, height, width),
                the local condition.

        Returns:
            res (Tensor): shape(batch_size, channel, height, width), the residual output.
            skip (Tensor): shape(batch_size, channel, height, width), the skip output.
        """
        x = self.conv(x)
        x += self.condition_proj(condition)
@@ -125,11 +136,26 @@ class ResidualBlock(nn.Layer):
        return res, skip

    def start_sequence(self):
        """Prepare the layer for incremental computation of causal convolution.
        Reset the buffer for causal convolution.

        Raises:
            ValueError: If not in evaluation mode.
        """
        if self.training:
            raise ValueError("Only use start_sequence in evaluation mode.")
        self._conv_buffer = None

    def add_input(self, x_row, condition_row):
        """Compute the output for a row and update the buffer.

        Args:
            x_row (Tensor): shape(batch_size, channel, 1, width), a row of the input.
            condition_row (Tensor): shape(batch_size, condition_channel, 1, width), a row of
                the local condition.

        Returns:
            res (Tensor): shape(batch_size, channel, 1, width), the residual output.
            skip (Tensor): shape(batch_size, channel, 1, width), the skip output.
        """
        if self._conv_buffer is None:
            self._init_buffer(x_row)
        self._update_buffer(x_row)
@@ -175,6 +201,17 @@ class ResidualNet(nn.LayerList):
        self.append(layer)

    def forward(self, x, condition):
        """Compute the output given the input and the condition.

        Args:
            x (Tensor): shape(batch_size, channel, height, width), the input.
            condition (Tensor): shape(batch_size, condition_channel, height, width),
                the local condition.

        Returns:
            Tensor: shape(batch_size, channel, height, width), the output, which
                is an aggregation of all the skip outputs.
        """
        skip_connections = []
        for layer in self:
            x, skip = layer(x, condition)
@@ -183,11 +220,21 @@ class ResidualNet(nn.LayerList):
        return out

    def start_sequence(self):
        """Prepare the layer for incremental computation."""
        for layer in self:
            layer.start_sequence()

    def add_input(self, x_row, condition_row):
        """Compute the output for a row and update the buffer.

        Args:
            x_row (Tensor): shape(batch_size, channel, 1, width), a row of the input.
            condition_row (Tensor): shape(batch_size, condition_channel, 1, width), a row of
                the local condition.

        Returns:
            Tensor: shape(batch_size, channel, 1, width), the output, which is
                an aggregation of all the skip outputs.
        """
        skip_connections = []
        for layer in self:
            x_row, skip = layer.add_input(x_row, condition_row)
@@ -198,8 +245,11 @@ class ResidualNet(nn.LayerList):

class Flow(nn.Layer):
    """
    A bijection (reversible layer) that transforms a density of latent
    variables p(Z) into a complex data distribution p(X).

    It is an autoregressive flow. The `forward` method implements the
    probability density estimation, and the `inverse` method implements
    the sampling.
    """
    dilations_dict = {
        8: [1, 1, 1, 1, 1, 1, 1, 1],
@@ -244,6 +294,19 @@ class Flow(nn.Layer):
        return z_out

    def forward(self, x, condition):
        """Probability density estimation. It is done by inversely transforming
        a sample from p(X) into a sample from p(Z).

        Args:
            x (Tensor): shape(batch, 1, height, width), an input sample of the distribution p(X).
            condition (Tensor): shape(batch, condition_channel, height, width), the local condition.

        Returns:
            (z, (logs, b))
            z (Tensor): shape(batch, 1, height, width), the transformed sample.
            logs (Tensor): shape(batch, 1, height - 1, width), the log scale of the inverse transformation.
            b (Tensor): shape(batch, 1, height - 1, width), the shift of the inverse transformation.
        """
        # (B, C, H-1, W)
        logs, b = self._predict_parameters(
            x[:, :, :-1, :], condition[:, :, 1:, :])
@@ -270,6 +333,19 @@ class Flow(nn.Layer):
        self.resnet.start_sequence()

    def inverse(self, z, condition):
        """Sampling from the distribution p(X). It is done by drawing a sample
        from p(Z) and transforming it. It is an autoregressive transformation.

        Args:
            z (Tensor): shape(batch, 1, height, width), an input sample of the distribution p(Z).
            condition (Tensor): shape(batch, condition_channel, height, width), the local condition.

        Returns:
            (x, (logs, b))
            x (Tensor): shape(batch, 1, height, width), the transformed sample.
            logs (Tensor): shape(batch, 1, height - 1, width), the log scale of the inverse transformation.
            b (Tensor): shape(batch, 1, height - 1, width), the shift of the inverse transformation.
        """
        z_0 = z[:, :, :1, :]
        x = []
        logs_list = []
@@ -290,11 +366,11 @@ class Flow(nn.Layer):
        x = paddle.concat(x, 2)
        logs = paddle.concat(logs_list, 2)
        b = paddle.concat(b_list, 2)
        return x, (logs, b)


class WaveFlow(nn.LayerList):
    """A deep reversible layer composed of a stack of autoregressive flows."""
    def __init__(self, n_flows, n_layers, n_group, channels, mel_bands, kernel_size):
        if n_group % 2 or n_flows % 2:
            raise ValueError("number of flows and number of group must be even "
@@ -333,6 +409,16 @@ class WaveFlow(nn.LayerList):
        return x, condition

    def forward(self, x, condition):
        """Probability density estimation.

        Args:
            x (Tensor): shape(batch_size, time_steps), the audio.
            condition (Tensor): shape(batch_size, condition_channel, time_steps), the local condition.

        Returns:
            z (Tensor): shape(batch_size, time_steps), the transformed sample.
            log_det_jacobian (Tensor): shape(1,), the log determinant of the jacobian of (dz/dx).
        """
        # x: (B, T)
        # condition: (B, C, T) upsampled condition
        x, condition = self._trim(x, condition)
@@ -350,14 +436,28 @@ class WaveFlow(nn.LayerList):
            x = geo.shuffle_dim(x, 2, perm=self.perms[i])
            condition = geo.shuffle_dim(condition, 2, perm=self.perms[i])

        z = paddle.squeeze(x, 1)  # (B, H, W)
        batch_size = z.shape[0]
        z = paddle.reshape(paddle.transpose(z, [0, 2, 1]), [batch_size, -1])

        log_det_jacobian = paddle.sum(paddle.stack(logs_list))
        return z, log_det_jacobian

    def start_sequence(self):
        for layer in self:
            layer.start_sequence()

    def inverse(self, z, condition):
        """Sampling from the distribution p(X). It is done by drawing a sample
        from p(Z) and transforming it. It is an autoregressive transformation.

        Args:
            z (Tensor): shape(batch_size, time_steps), an input sample of the distribution p(Z).
            condition (Tensor): shape(batch_size, condition_channel, time_steps), the local condition.

        Returns:
            x (Tensor): shape(batch_size, time_steps), the transformed sample.
        """
        self.start_sequence()
        z, condition = self._trim(z, condition)
@@ -371,8 +471,22 @@ class WaveFlow(nn.LayerList):
            z = geo.shuffle_dim(z, 2, perm=self.perms[i])
            condition = geo.shuffle_dim(condition, 2, perm=self.perms[i])
            z, (logs, b) = self[i].inverse(z, condition)

        x = paddle.squeeze(z, 1)  # (B, H, W)
        batch_size = x.shape[0]
        x = paddle.reshape(paddle.transpose(x, [0, 2, 1]), [batch_size, -1])
        return x


class WaveFlowLoss(nn.Layer):
    def __init__(self, sigma=1.0):
        super().__init__()
        self.sigma = sigma
        self.const = 0.5 * np.log(2 * np.pi) + np.log(self.sigma)

    def forward(self, model_output):
        z, log_det_jacobian = model_output
        loss = paddle.sum(z * z) / (2 * self.sigma * self.sigma) - log_det_jacobian
        loss = loss / np.prod(z.shape)
        return loss + self.const
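
For orientation, here is a minimal sketch of how the two pieces changed by this commit fit together: the density-estimation pass (WaveFlow.forward) feeds the new WaveFlowLoss, which is the per-element negative log-likelihood of z under N(0, sigma^2) minus the log-determinant term, plus the constant 0.5 * log(2 * pi) + log(sigma); sampling goes through WaveFlow.inverse. The hyperparameters and tensor shapes below are illustrative only (borrowed from the unit test further down), not part of the commit.

import paddle

# Illustrative hyperparameters, mirroring the shapes used in the test below.
net = WaveFlow(n_flows=2, n_layers=8, n_group=8, channels=16,
               mel_bands=7, kernel_size=(3, 3))
criterion = WaveFlowLoss(sigma=1.0)

audio = paddle.randn([4, 32 * 8])         # (batch_size, time_steps)
condition = paddle.randn([4, 7, 32 * 8])  # upsampled local condition, (B, C, T)

# Training direction: x -> z, then the negative log-likelihood loss.
z, log_det_jacobian = net(audio, condition)
loss = criterion((z, log_det_jacobian))
loss.backward()

# Sampling direction: z ~ p(Z) -> x, run autoregressively without gradients.
with paddle.no_grad():
    x = net.inverse(paddle.randn([4, 32 * 8]), condition)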

View File

@@ -114,10 +114,10 @@ class TestWaveFlow(unittest.TestCase):
        x = paddle.randn([4, 32 * 8])
        condition = paddle.randn([4, 7, 32 * 8])
        net = waveflow.WaveFlow(2, 8, 8, 16, 7, (3, 3))
        z, logs_det_jacobian = net(x, condition)
        self.assertTupleEqual(z.numpy().shape, (4, 32 * 8))
        self.assertTupleEqual(logs_det_jacobian.numpy().shape, (1,))

    def test_inverse(self):
        z = paddle.randn([4, 32 * 8])
@@ -128,7 +128,7 @@ class TestWaveFlow(unittest.TestCase):
        with paddle.no_grad():
            x = net.inverse(z, condition)
        self.assertTupleEqual(x.numpy().shape, (4, 32 * 8))
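
The visible test changes cover forward and inverse but not the new WaveFlowLoss. A possible additional check (not part of this commit), assuming WaveFlowLoss is exported from the same waveflow module as WaveFlow and reusing the shapes above, could look like this:

    def test_loss(self):
        x = paddle.randn([4, 32 * 8])
        condition = paddle.randn([4, 7, 32 * 8])
        net = waveflow.WaveFlow(2, 8, 8, 16, 7, (3, 3))
        # assumes WaveFlowLoss lives next to WaveFlow in the waveflow module
        criterion = waveflow.WaveFlowLoss(sigma=1.0)
        loss = criterion(net(x, condition))
        # the criterion reduces everything to a single scalar value
        self.assertEqual(loss.numpy().size, 1)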