waveflow refactor: add prediction functionalities

This commit is contained in:
chenfeiyu 2020-11-04 19:31:36 +08:00
parent 8094578f6d
commit e07441c193
2 changed files with 215 additions and 20 deletions

View File

@ -95,6 +95,9 @@ class ResidualBlock(nn.Layer):
weight_attr=init,
bias_attr=init)
self.conv = nn.utils.weight_norm(conv)
self.rh = rh
self.rw = rw
self.dilations = dilations
# condition projection
std = math.sqrt(1 / cond_channels)
@ -121,6 +124,41 @@ class ResidualBlock(nn.Layer):
res, skip = paddle.chunk(x, 2, axis=1)
return res, skip
def start_sequence(self):
if self.training:
raise ValueError("Only use start sequence at evaluation mode.")
self._conv_buffer = None
def add_input(self, x_row, condition_row):
if self._conv_buffer is None:
self._init_buffer(x_row)
self._update_buffer(x_row)
rw = self.rw
x_row = F.conv2d(
self._conv_buffer,
self.conv.weight,
self.conv.bias,
padding=[0, 0, rw // 2, (rw - 1) // 2],
dilation=self.dilations)
x_row += self.condition_proj(condition_row)
content, gate = paddle.chunk(x_row, 2, axis=1)
x_row = paddle.tanh(content) * F.sigmoid(gate)
x_row = self.out_proj(x_row)
res, skip = paddle.chunk(x_row, 2, axis=1)
return res, skip
def _init_buffer(self, input):
batch_size, channels, _, width = input.shape
self._conv_buffer = paddle.zeros(
[batch_size, channels, self.rh, width], dtype=input.dtype)
def _update_buffer(self, input):
self._conv_buffer = paddle.concat(
[self._conv_buffer[:, :, 1:, :], input], axis=2)
class ResidualNet(nn.LayerList):
"""
@ -144,6 +182,19 @@ class ResidualNet(nn.LayerList):
out = paddle.sum(paddle.stack(skip_connections, 0), 0)
return out
def start_sequence(self):
for layer in self:
layer.start_sequence()
def add_input(self, x_row, condition_row):
# in reversed order
skip_connections = []
for layer in self:
x_row, skip = layer.add_input(x_row, condition_row)
skip_connections.append(skip)
out = paddle.sum(paddle.stack(skip_connections, 0), 0)
return out
class Flow(nn.Layer):
"""
@ -175,12 +226,74 @@ class Flow(nn.Layer):
nn.Conv2D(channels, 2, (1, 1),
weight_attr=I.Constant(0.),
bias_attr=I.Constant(0.)))
# specs
self.n_group = n_group
def _predict_parameters(self, x, condition):
x = self.first_conv(x)
x = self.resnet(x, condition)
bijection_params = self.last_conv(x)
logs, b = paddle.chunk(bijection_params, 2, axis=1)
return logs, b
def _transform(self, x, logs, b):
z_0 = x[:, :, :1, :] # the first row, just copy it
z_out = x[:, :, 1:, :] * paddle.exp(logs) + b
z_out = paddle.concat([z_0, z_out], axis=2)
return z_out
def forward(self, x, condition):
# TODO(chenfeiyu): it is better to implement the transformation here
return self.last_conv(self.resnet(self.first_conv(x), condition))
# (B, C, H-1, W)
logs, b = self._predict_parameters(
x[:, :, :-1, :], condition[:, :, 1:, :])
z = self._transform(x, logs, b)
return z, (logs, b)
def _predict_row_parameters(self, x_row, condition_row):
x_row = self.first_conv(x_row)
x_row = self.resnet.add_input(x_row, condition_row)
bijection_params = self.last_conv(x_row)
logs, b = paddle.chunk(bijection_params, 2, axis=1)
return logs, b
def _inverse_transform_row(self, z_row, logs, b):
x_row = (z_row - b) / paddle.exp(logs)
return x_row
def _inverse_row(self, z_row, x_row, condition_row):
logs, b = self._predict_row_parameters(x_row, condition_row)
x_next_row = self._inverse_transform_row(z_row, logs, b)
return x_next_row, (logs, b)
def start_sequence(self):
self.resnet.start_sequence()
def inverse(self, z, condition):
z_0 = z[:, :, :1, :]
x = []
logs_list = []
b_list = []
x.append(z_0)
self.start_sequence()
for i in range(1, self.n_group):
x_row = x[-1] # actuallt i-1
z_row = z[:, :, i:i+1, :]
condition_row = condition[:, :, i:i+1, :]
x_next_row, (logs, b) = self._inverse_row(z_row, x_row, condition_row)
x.append(x_next_row)
logs_list.append(logs)
b_list.append(b)
x = paddle.concat(x, 2)
logs = paddle.concat(logs_list, 2)
b = paddle.concat(b_list, 2)
return x, (logs, b)
class WaveFlow(nn.LayerList):
def __init__(self, n_flows, n_layers, n_group, channels, mel_bands, kernel_size):
if n_group % 2 or n_flows % 2:
@ -224,24 +337,15 @@ class WaveFlow(nn.LayerList):
# condition: (B, C, T) upsampled condition
x, condition = self._trim(x, condition)
# to (B, C, h, T //h) layout
# to (B, C, h, T//h) layout
x = paddle.unsqueeze(paddle.transpose(fold(x, self.n_group), [0, 2, 1]), 1)
condition = paddle.transpose(fold(condition, self.n_group), [0, 1, 3, 2])
# flows
logs_list = []
for i, layer in enumerate(self):
# shiting: z[i, j] depends only on x[<i, :]
input = x[:, :, :-1, :]
cond = condition[:, :, 1:, :]
output = layer(input, cond)
logs, b = paddle.chunk(output, 2, axis=1)
x, (logs, b) = layer(x, condition)
logs_list.append(logs)
# the transformation
x_0 = x[:, :, :1, :] # the first row, just copy it
x_out = x[:, :, 1:, :] * paddle.exp(logs) + b
x = paddle.concat([x_0, x_out], axis=2)
# permute paddle has no shuffle dim
x = geo.shuffle_dim(x, 2, perm=self.perms[i])
condition = geo.shuffle_dim(condition, 2, perm=self.perms[i])
@ -249,5 +353,26 @@ class WaveFlow(nn.LayerList):
z = paddle.squeeze(x, 1)
return z, logs_list
def start_sequence(self):
for layer in self:
layer.start_sequence()
# TODO(chenfeiyu): WaveFlowLoss
def inverse(self, z, condition):
self.start_sequence()
z, condition = self._trim(z, condition)
# to (B, C, h, T//h) layout
z = paddle.unsqueeze(paddle.transpose(fold(z, self.n_group), [0, 2, 1]), 1)
condition = paddle.transpose(fold(condition, self.n_group), [0, 1, 3, 2])
# reverse it flow by flow
self.n_flows
for i in reversed(range(self.n_flows)):
z = geo.shuffle_dim(z, 2, perm=self.perms[i])
condition = geo.shuffle_dim(condition, 2, perm=self.perms[i])
z, (logs, b) = self[i].inverse(z, condition)
x = paddle.squeeze(z, 1)
return x
# TODO(chenfeiyu): WaveFlowLoss

View File

@ -35,6 +35,18 @@ class TestResidualBlock(unittest.TestCase):
res, skip = net(x, condition)
self.assertTupleEqual(res.numpy().shape, (4, 4, 16, 32))
self.assertTupleEqual(skip.numpy().shape, (4, 4, 16, 32))
def test_add_input(self):
net = waveflow.ResidualBlock(4, 6, (3, 3), (2, 2))
net.eval()
net.start_sequence()
x_row = paddle.randn([4, 4, 1, 32])
condition_row = paddle.randn([4, 6, 1, 32])
res, skip = net.add_input(x_row, condition_row)
self.assertTupleEqual(res.numpy().shape, (4, 4, 1, 32))
self.assertTupleEqual(skip.numpy().shape, (4, 4, 1, 32))
class TestResidualNet(unittest.TestCase):
@ -44,21 +56,79 @@ class TestResidualNet(unittest.TestCase):
condition = paddle.randn([4, 8, 8, 32])
y = net(x, condition)
self.assertTupleEqual(y.numpy().shape, (4, 6, 8, 32))
def test_add_input(self):
net = waveflow.ResidualNet(8, 6, 8, (3, 3), [1, 1, 1, 1, 1, 1, 1, 1])
net.eval()
net.start_sequence()
x_row = paddle.randn([4, 6, 1, 32])
condition_row = paddle.randn([4, 8, 1, 32])
y_row = net.add_input(x_row, condition_row)
self.assertTupleEqual(y_row.numpy().shape, (4, 6, 1, 32))
class TestFlow(unittest.TestCase):
def test_io(self):
net = waveflow.Flow(8, 16, 7, (3, 3), 8)
x = paddle.randn([4, 1, 8, 32])
condition = paddle.randn([4, 7, 8, 32])
z, (logs, b) = net(x, condition)
self.assertTupleEqual(z.numpy().shape, (4, 1, 8, 32))
self.assertTupleEqual(logs.numpy().shape, (4, 1, 7, 32))
self.assertTupleEqual(b.numpy().shape, (4, 1, 7, 32))
def test_inverse_row(self):
net = waveflow.Flow(8, 16, 7, (3, 3), 8)
y = net(x, condition)
self.assertTupleEqual(y.numpy().shape, (4, 2, 8, 32))
class TestWaveflow(unittest.TestCase):
net.eval()
net.start_sequence()
x_row = paddle.randn([4, 1, 1, 32]) # last row
condition_row = paddle.randn([4, 7, 1, 32])
z_row = paddle.randn([4, 1, 1, 32])
x_next_row, (logs, b) = net._inverse_row(z_row, x_row, condition_row)
self.assertTupleEqual(x_next_row.numpy().shape, (4, 1, 1, 32))
self.assertTupleEqual(logs.numpy().shape, (4, 1, 1, 32))
self.assertTupleEqual(b.numpy().shape, (4, 1, 1, 32))
def test_inverse(self):
net = waveflow.Flow(8, 16, 7, (3, 3), 8)
net.eval()
net.start_sequence()
z = paddle.randn([4, 1, 8, 32])
condition = paddle.randn([4, 7, 8, 32])
with paddle.no_grad():
x, (logs, b) = net.inverse(z, condition)
self.assertTupleEqual(x.numpy().shape, (4, 1, 8, 32))
self.assertTupleEqual(logs.numpy().shape, (4, 1, 7, 32))
self.assertTupleEqual(b.numpy().shape, (4, 1, 7, 32))
class TestWaveFlow(unittest.TestCase):
def test_io(self):
x = paddle.randn([4, 32 * 8 ])
condition = paddle.randn([4, 7, 32 * 8])
net = waveflow.WaveFlow(2, 8, 8, 16, 7, (3, 3))
z, logs = net(x, condition)
z, logs_list = net(x, condition)
self.assertTupleEqual(z.numpy().shape, (4, 8, 32))
self.assertTupleEqual(logs_list[0].numpy().shape, (4, 1, 7, 32))
def test_inverse(self):
z = paddle.randn([4, 32 * 8 ])
condition = paddle.randn([4, 7, 32 * 8])
net = waveflow.WaveFlow(2, 8, 8, 16, 7, (3, 3))
net.eval()
with paddle.no_grad():
x = net.inverse(z, condition)
self.assertTupleEqual(x.numpy().shape, (4, 8, 32))