waveflow refactor: add prediction functionalities
This commit is contained in:
parent
8094578f6d
commit
e07441c193
|
@ -95,6 +95,9 @@ class ResidualBlock(nn.Layer):
|
|||
weight_attr=init,
|
||||
bias_attr=init)
|
||||
self.conv = nn.utils.weight_norm(conv)
|
||||
self.rh = rh
|
||||
self.rw = rw
|
||||
self.dilations = dilations
|
||||
|
||||
# condition projection
|
||||
std = math.sqrt(1 / cond_channels)
|
||||
|
@ -121,6 +124,41 @@ class ResidualBlock(nn.Layer):
|
|||
res, skip = paddle.chunk(x, 2, axis=1)
|
||||
return res, skip
|
||||
|
||||
def start_sequence(self):
|
||||
if self.training:
|
||||
raise ValueError("Only use start sequence at evaluation mode.")
|
||||
self._conv_buffer = None
|
||||
|
||||
def add_input(self, x_row, condition_row):
|
||||
if self._conv_buffer is None:
|
||||
self._init_buffer(x_row)
|
||||
self._update_buffer(x_row)
|
||||
|
||||
rw = self.rw
|
||||
x_row = F.conv2d(
|
||||
self._conv_buffer,
|
||||
self.conv.weight,
|
||||
self.conv.bias,
|
||||
padding=[0, 0, rw // 2, (rw - 1) // 2],
|
||||
dilation=self.dilations)
|
||||
x_row += self.condition_proj(condition_row)
|
||||
|
||||
content, gate = paddle.chunk(x_row, 2, axis=1)
|
||||
x_row = paddle.tanh(content) * F.sigmoid(gate)
|
||||
|
||||
x_row = self.out_proj(x_row)
|
||||
res, skip = paddle.chunk(x_row, 2, axis=1)
|
||||
return res, skip
|
||||
|
||||
def _init_buffer(self, input):
|
||||
batch_size, channels, _, width = input.shape
|
||||
self._conv_buffer = paddle.zeros(
|
||||
[batch_size, channels, self.rh, width], dtype=input.dtype)
|
||||
|
||||
def _update_buffer(self, input):
|
||||
self._conv_buffer = paddle.concat(
|
||||
[self._conv_buffer[:, :, 1:, :], input], axis=2)
|
||||
|
||||
|
||||
class ResidualNet(nn.LayerList):
|
||||
"""
|
||||
|
@ -144,6 +182,19 @@ class ResidualNet(nn.LayerList):
|
|||
out = paddle.sum(paddle.stack(skip_connections, 0), 0)
|
||||
return out
|
||||
|
||||
def start_sequence(self):
|
||||
for layer in self:
|
||||
layer.start_sequence()
|
||||
|
||||
def add_input(self, x_row, condition_row):
|
||||
# in reversed order
|
||||
skip_connections = []
|
||||
for layer in self:
|
||||
x_row, skip = layer.add_input(x_row, condition_row)
|
||||
skip_connections.append(skip)
|
||||
out = paddle.sum(paddle.stack(skip_connections, 0), 0)
|
||||
return out
|
||||
|
||||
|
||||
class Flow(nn.Layer):
|
||||
"""
|
||||
|
@ -176,9 +227,71 @@ class Flow(nn.Layer):
|
|||
weight_attr=I.Constant(0.),
|
||||
bias_attr=I.Constant(0.)))
|
||||
|
||||
# specs
|
||||
self.n_group = n_group
|
||||
|
||||
def _predict_parameters(self, x, condition):
|
||||
x = self.first_conv(x)
|
||||
x = self.resnet(x, condition)
|
||||
bijection_params = self.last_conv(x)
|
||||
logs, b = paddle.chunk(bijection_params, 2, axis=1)
|
||||
return logs, b
|
||||
|
||||
def _transform(self, x, logs, b):
|
||||
z_0 = x[:, :, :1, :] # the first row, just copy it
|
||||
z_out = x[:, :, 1:, :] * paddle.exp(logs) + b
|
||||
z_out = paddle.concat([z_0, z_out], axis=2)
|
||||
return z_out
|
||||
|
||||
def forward(self, x, condition):
|
||||
# TODO(chenfeiyu): it is better to implement the transformation here
|
||||
return self.last_conv(self.resnet(self.first_conv(x), condition))
|
||||
# (B, C, H-1, W)
|
||||
logs, b = self._predict_parameters(
|
||||
x[:, :, :-1, :], condition[:, :, 1:, :])
|
||||
z = self._transform(x, logs, b)
|
||||
return z, (logs, b)
|
||||
|
||||
def _predict_row_parameters(self, x_row, condition_row):
|
||||
x_row = self.first_conv(x_row)
|
||||
x_row = self.resnet.add_input(x_row, condition_row)
|
||||
bijection_params = self.last_conv(x_row)
|
||||
logs, b = paddle.chunk(bijection_params, 2, axis=1)
|
||||
return logs, b
|
||||
|
||||
def _inverse_transform_row(self, z_row, logs, b):
|
||||
x_row = (z_row - b) / paddle.exp(logs)
|
||||
return x_row
|
||||
|
||||
def _inverse_row(self, z_row, x_row, condition_row):
|
||||
logs, b = self._predict_row_parameters(x_row, condition_row)
|
||||
x_next_row = self._inverse_transform_row(z_row, logs, b)
|
||||
return x_next_row, (logs, b)
|
||||
|
||||
def start_sequence(self):
|
||||
self.resnet.start_sequence()
|
||||
|
||||
def inverse(self, z, condition):
|
||||
z_0 = z[:, :, :1, :]
|
||||
x = []
|
||||
logs_list = []
|
||||
b_list = []
|
||||
x.append(z_0)
|
||||
|
||||
self.start_sequence()
|
||||
for i in range(1, self.n_group):
|
||||
x_row = x[-1] # actuallt i-1
|
||||
z_row = z[:, :, i:i+1, :]
|
||||
condition_row = condition[:, :, i:i+1, :]
|
||||
|
||||
x_next_row, (logs, b) = self._inverse_row(z_row, x_row, condition_row)
|
||||
x.append(x_next_row)
|
||||
logs_list.append(logs)
|
||||
b_list.append(b)
|
||||
|
||||
x = paddle.concat(x, 2)
|
||||
logs = paddle.concat(logs_list, 2)
|
||||
b = paddle.concat(b_list, 2)
|
||||
|
||||
return x, (logs, b)
|
||||
|
||||
|
||||
class WaveFlow(nn.LayerList):
|
||||
|
@ -231,17 +344,8 @@ class WaveFlow(nn.LayerList):
|
|||
# flows
|
||||
logs_list = []
|
||||
for i, layer in enumerate(self):
|
||||
# shiting: z[i, j] depends only on x[<i, :]
|
||||
input = x[:, :, :-1, :]
|
||||
cond = condition[:, :, 1:, :]
|
||||
output = layer(input, cond)
|
||||
logs, b = paddle.chunk(output, 2, axis=1)
|
||||
x, (logs, b) = layer(x, condition)
|
||||
logs_list.append(logs)
|
||||
# the transformation
|
||||
x_0 = x[:, :, :1, :] # the first row, just copy it
|
||||
x_out = x[:, :, 1:, :] * paddle.exp(logs) + b
|
||||
x = paddle.concat([x_0, x_out], axis=2)
|
||||
|
||||
# permute paddle has no shuffle dim
|
||||
x = geo.shuffle_dim(x, 2, perm=self.perms[i])
|
||||
condition = geo.shuffle_dim(condition, 2, perm=self.perms[i])
|
||||
|
@ -249,5 +353,26 @@ class WaveFlow(nn.LayerList):
|
|||
z = paddle.squeeze(x, 1)
|
||||
return z, logs_list
|
||||
|
||||
def start_sequence(self):
|
||||
for layer in self:
|
||||
layer.start_sequence()
|
||||
|
||||
def inverse(self, z, condition):
|
||||
self.start_sequence()
|
||||
|
||||
z, condition = self._trim(z, condition)
|
||||
# to (B, C, h, T//h) layout
|
||||
z = paddle.unsqueeze(paddle.transpose(fold(z, self.n_group), [0, 2, 1]), 1)
|
||||
condition = paddle.transpose(fold(condition, self.n_group), [0, 1, 3, 2])
|
||||
|
||||
# reverse it flow by flow
|
||||
self.n_flows
|
||||
for i in reversed(range(self.n_flows)):
|
||||
z = geo.shuffle_dim(z, 2, perm=self.perms[i])
|
||||
condition = geo.shuffle_dim(condition, 2, perm=self.perms[i])
|
||||
z, (logs, b) = self[i].inverse(z, condition)
|
||||
x = paddle.squeeze(z, 1)
|
||||
return x
|
||||
|
||||
|
||||
# TODO(chenfeiyu): WaveFlowLoss
|
|
@ -36,6 +36,18 @@ class TestResidualBlock(unittest.TestCase):
|
|||
self.assertTupleEqual(res.numpy().shape, (4, 4, 16, 32))
|
||||
self.assertTupleEqual(skip.numpy().shape, (4, 4, 16, 32))
|
||||
|
||||
def test_add_input(self):
|
||||
net = waveflow.ResidualBlock(4, 6, (3, 3), (2, 2))
|
||||
net.eval()
|
||||
net.start_sequence()
|
||||
|
||||
x_row = paddle.randn([4, 4, 1, 32])
|
||||
condition_row = paddle.randn([4, 6, 1, 32])
|
||||
|
||||
res, skip = net.add_input(x_row, condition_row)
|
||||
self.assertTupleEqual(res.numpy().shape, (4, 4, 1, 32))
|
||||
self.assertTupleEqual(skip.numpy().shape, (4, 4, 1, 32))
|
||||
|
||||
|
||||
class TestResidualNet(unittest.TestCase):
|
||||
def test_io(self):
|
||||
|
@ -45,20 +57,78 @@ class TestResidualNet(unittest.TestCase):
|
|||
y = net(x, condition)
|
||||
self.assertTupleEqual(y.numpy().shape, (4, 6, 8, 32))
|
||||
|
||||
def test_add_input(self):
|
||||
net = waveflow.ResidualNet(8, 6, 8, (3, 3), [1, 1, 1, 1, 1, 1, 1, 1])
|
||||
net.eval()
|
||||
net.start_sequence()
|
||||
|
||||
x_row = paddle.randn([4, 6, 1, 32])
|
||||
condition_row = paddle.randn([4, 8, 1, 32])
|
||||
|
||||
y_row = net.add_input(x_row, condition_row)
|
||||
self.assertTupleEqual(y_row.numpy().shape, (4, 6, 1, 32))
|
||||
|
||||
|
||||
class TestFlow(unittest.TestCase):
|
||||
def test_io(self):
|
||||
net = waveflow.Flow(8, 16, 7, (3, 3), 8)
|
||||
|
||||
x = paddle.randn([4, 1, 8, 32])
|
||||
condition = paddle.randn([4, 7, 8, 32])
|
||||
z, (logs, b) = net(x, condition)
|
||||
self.assertTupleEqual(z.numpy().shape, (4, 1, 8, 32))
|
||||
self.assertTupleEqual(logs.numpy().shape, (4, 1, 7, 32))
|
||||
self.assertTupleEqual(b.numpy().shape, (4, 1, 7, 32))
|
||||
|
||||
def test_inverse_row(self):
|
||||
net = waveflow.Flow(8, 16, 7, (3, 3), 8)
|
||||
y = net(x, condition)
|
||||
self.assertTupleEqual(y.numpy().shape, (4, 2, 8, 32))
|
||||
net.eval()
|
||||
net.start_sequence()
|
||||
|
||||
x_row = paddle.randn([4, 1, 1, 32]) # last row
|
||||
condition_row = paddle.randn([4, 7, 1, 32])
|
||||
z_row = paddle.randn([4, 1, 1, 32])
|
||||
x_next_row, (logs, b) = net._inverse_row(z_row, x_row, condition_row)
|
||||
|
||||
self.assertTupleEqual(x_next_row.numpy().shape, (4, 1, 1, 32))
|
||||
self.assertTupleEqual(logs.numpy().shape, (4, 1, 1, 32))
|
||||
self.assertTupleEqual(b.numpy().shape, (4, 1, 1, 32))
|
||||
|
||||
def test_inverse(self):
|
||||
net = waveflow.Flow(8, 16, 7, (3, 3), 8)
|
||||
net.eval()
|
||||
net.start_sequence()
|
||||
|
||||
z = paddle.randn([4, 1, 8, 32])
|
||||
condition = paddle.randn([4, 7, 8, 32])
|
||||
|
||||
with paddle.no_grad():
|
||||
x, (logs, b) = net.inverse(z, condition)
|
||||
self.assertTupleEqual(x.numpy().shape, (4, 1, 8, 32))
|
||||
self.assertTupleEqual(logs.numpy().shape, (4, 1, 7, 32))
|
||||
self.assertTupleEqual(b.numpy().shape, (4, 1, 7, 32))
|
||||
|
||||
|
||||
class TestWaveflow(unittest.TestCase):
|
||||
class TestWaveFlow(unittest.TestCase):
|
||||
def test_io(self):
|
||||
x = paddle.randn([4, 32 * 8 ])
|
||||
condition = paddle.randn([4, 7, 32 * 8])
|
||||
net = waveflow.WaveFlow(2, 8, 8, 16, 7, (3, 3))
|
||||
z, logs = net(x, condition)
|
||||
z, logs_list = net(x, condition)
|
||||
|
||||
self.assertTupleEqual(z.numpy().shape, (4, 8, 32))
|
||||
self.assertTupleEqual(logs_list[0].numpy().shape, (4, 1, 7, 32))
|
||||
|
||||
def test_inverse(self):
|
||||
z = paddle.randn([4, 32 * 8 ])
|
||||
condition = paddle.randn([4, 7, 32 * 8])
|
||||
|
||||
net = waveflow.WaveFlow(2, 8, 8, 16, 7, (3, 3))
|
||||
net.eval()
|
||||
|
||||
with paddle.no_grad():
|
||||
x = net.inverse(z, condition)
|
||||
self.assertTupleEqual(x.numpy().shape, (4, 8, 32))
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue