waveflow refactor: add prediction functionalities

parent 8094578f6d
commit e07441c193
@@ -95,6 +95,9 @@ class ResidualBlock(nn.Layer):
                          weight_attr=init,
                          bias_attr=init)
         self.conv = nn.utils.weight_norm(conv)
+        self.rh = rh
+        self.rw = rw
+        self.dilations = dilations
 
         # condition projection
         std = math.sqrt(1 / cond_channels)
@@ -121,6 +124,41 @@ class ResidualBlock(nn.Layer):
         res, skip = paddle.chunk(x, 2, axis=1)
         return res, skip
 
+    def start_sequence(self):
+        if self.training:
+            raise ValueError("Only use start sequence at evaluation mode.")
+        self._conv_buffer = None
+
+    def add_input(self, x_row, condition_row):
+        if self._conv_buffer is None:
+            self._init_buffer(x_row)
+        self._update_buffer(x_row)
+
+        rw = self.rw
+        x_row = F.conv2d(
+            self._conv_buffer,
+            self.conv.weight,
+            self.conv.bias,
+            padding=[0, 0, rw // 2, (rw - 1) // 2],
+            dilation=self.dilations)
+        x_row += self.condition_proj(condition_row)
+
+        content, gate = paddle.chunk(x_row, 2, axis=1)
+        x_row = paddle.tanh(content) * F.sigmoid(gate)
+
+        x_row = self.out_proj(x_row)
+        res, skip = paddle.chunk(x_row, 2, axis=1)
+        return res, skip
+
+    def _init_buffer(self, input):
+        batch_size, channels, _, width = input.shape
+        self._conv_buffer = paddle.zeros(
+            [batch_size, channels, self.rh, width], dtype=input.dtype)
+
+    def _update_buffer(self, input):
+        self._conv_buffer = paddle.concat(
+            [self._conv_buffer[:, :, 1:, :], input], axis=2)
+
 
 class ResidualNet(nn.LayerList):
     """
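Note: the cached-convolution trick behind start_sequence/add_input is easiest to see in isolation. A minimal numpy sketch (toy sizes, not part of the commit): the buffer always holds the last rh rows, which is all a height-rh kernel needs, so each new row costs one small convolution instead of re-running the whole column.

    import numpy as np

    rh, width = 3, 8                  # kernel height and row width (toy sizes)
    buf = np.zeros((rh, width))       # mirrors _init_buffer: zeros, rh rows

    def update_buffer(buf, row):
        # mirrors _update_buffer: drop the oldest row, append the newest
        return np.concatenate([buf[1:], row[None, :]], axis=0)

    for t in range(5):
        buf = update_buffer(buf, np.full(width, float(t)))
    print(buf[:, 0])                  # [2. 3. 4.] -- the three most recent rows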
@@ -144,6 +182,19 @@ class ResidualNet(nn.LayerList):
         out = paddle.sum(paddle.stack(skip_connections, 0), 0)
         return out
 
+    def start_sequence(self):
+        for layer in self:
+            layer.start_sequence()
+
+    def add_input(self, x_row, condition_row):
+        # add a row to each layer, accumulating skip connections
+        skip_connections = []
+        for layer in self:
+            x_row, skip = layer.add_input(x_row, condition_row)
+            skip_connections.append(skip)
+        out = paddle.sum(paddle.stack(skip_connections, 0), 0)
+        return out
+
 
 class Flow(nn.Layer):
     """
@@ -175,12 +226,74 @@ class Flow(nn.Layer):
             nn.Conv2D(channels, 2, (1, 1),
                       weight_attr=I.Constant(0.),
                       bias_attr=I.Constant(0.)))
 
+        # specs
+        self.n_group = n_group
+
+    def _predict_parameters(self, x, condition):
+        x = self.first_conv(x)
+        x = self.resnet(x, condition)
+        bijection_params = self.last_conv(x)
+        logs, b = paddle.chunk(bijection_params, 2, axis=1)
+        return logs, b
+
+    def _transform(self, x, logs, b):
+        z_0 = x[:, :, :1, :]  # the first row, just copy it
+        z_out = x[:, :, 1:, :] * paddle.exp(logs) + b
+        z_out = paddle.concat([z_0, z_out], axis=2)
+        return z_out
+
     def forward(self, x, condition):
-        # TODO(chenfeiyu): it is better to implement the transformation here
-        return self.last_conv(self.resnet(self.first_conv(x), condition))
+        # (B, C, H-1, W)
+        logs, b = self._predict_parameters(
+            x[:, :, :-1, :], condition[:, :, 1:, :])
+        z = self._transform(x, logs, b)
+        return z, (logs, b)
+
+    def _predict_row_parameters(self, x_row, condition_row):
+        x_row = self.first_conv(x_row)
+        x_row = self.resnet.add_input(x_row, condition_row)
+        bijection_params = self.last_conv(x_row)
+        logs, b = paddle.chunk(bijection_params, 2, axis=1)
+        return logs, b
+
+    def _inverse_transform_row(self, z_row, logs, b):
+        x_row = (z_row - b) / paddle.exp(logs)
+        return x_row
+
+    def _inverse_row(self, z_row, x_row, condition_row):
+        logs, b = self._predict_row_parameters(x_row, condition_row)
+        x_next_row = self._inverse_transform_row(z_row, logs, b)
+        return x_next_row, (logs, b)
+
+    def start_sequence(self):
+        self.resnet.start_sequence()
+
+    def inverse(self, z, condition):
+        z_0 = z[:, :, :1, :]
+        x = []
+        logs_list = []
+        b_list = []
+        x.append(z_0)
+
+        self.start_sequence()
+        for i in range(1, self.n_group):
+            x_row = x[-1]  # actually row i - 1
+            z_row = z[:, :, i:i + 1, :]
+            condition_row = condition[:, :, i:i + 1, :]
+
+            x_next_row, (logs, b) = self._inverse_row(z_row, x_row, condition_row)
+            x.append(x_next_row)
+            logs_list.append(logs)
+            b_list.append(b)
+
+        x = paddle.concat(x, 2)
+        logs = paddle.concat(logs_list, 2)
+        b = paddle.concat(b_list, 2)
+
+        return x, (logs, b)
+
 
 class WaveFlow(nn.LayerList):
     def __init__(self, n_flows, n_layers, n_group, channels, mel_bands, kernel_size):
         if n_group % 2 or n_flows % 2:
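Note: forward/_transform and inverse above form the standard affine-coupling pair, with the parameters for row i predicted from rows before i. A scalar toy version (illustrative only; params() is a hypothetical stand-in for _predict_row_parameters) shows why the row-by-row inverse recovers the input exactly:

    import numpy as np

    def params(prev):
        # hypothetical stand-in for _predict_row_parameters: any deterministic
        # function of the previously generated value works
        return 0.1 * prev, 0.5 * prev    # (logs, b)

    rng = np.random.default_rng(0)
    x = rng.normal(size=8)

    # forward: z[0] = x[0]; z[i] = x[i] * exp(logs) + b, params from x[i-1]
    z = x.copy()
    for i in range(1, len(x)):
        logs, b = params(x[i - 1])
        z[i] = x[i] * np.exp(logs) + b

    # inverse: rebuild x left to right, feeding each recovered value back in
    x_rec = z.copy()
    for i in range(1, len(z)):
        logs, b = params(x_rec[i - 1])
        x_rec[i] = (z[i] - b) / np.exp(logs)

    assert np.allclose(x, x_rec)

Since x_rec[0] = z[0] = x[0], each inverse step sees the same parameters as the forward pass; the buffered add_input path makes exactly this row-at-a-time re-prediction cheap.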
@@ -224,24 +337,15 @@ class WaveFlow(nn.LayerList):
         # condition: (B, C, T) upsampled condition
         x, condition = self._trim(x, condition)
 
-        # to (B, C, h, T //h) layout
+        # to (B, C, h, T//h) layout
         x = paddle.unsqueeze(paddle.transpose(fold(x, self.n_group), [0, 2, 1]), 1)
         condition = paddle.transpose(fold(condition, self.n_group), [0, 1, 3, 2])
 
         # flows
         logs_list = []
         for i, layer in enumerate(self):
-            # shiting: z[i, j] depends only on x[<i, :]
-            input = x[:, :, :-1, :]
-            cond = condition[:, :, 1:, :]
-            output = layer(input, cond)
-            logs, b = paddle.chunk(output, 2, axis=1)
+            x, (logs, b) = layer(x, condition)
             logs_list.append(logs)
-            # the transformation
-            x_0 = x[:, :, :1, :]  # the first row, just copy it
-            x_out = x[:, :, 1:, :] * paddle.exp(logs) + b
-            x = paddle.concat([x_0, x_out], axis=2)
 
             # permute paddle has no shuffle dim
             x = geo.shuffle_dim(x, 2, perm=self.perms[i])
             condition = geo.shuffle_dim(condition, 2, perm=self.perms[i])
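Note: the "to (B, C, h, T//h) layout" step folds the time axis into n_group rows before the flows run. Assuming fold groups consecutive samples (as the reshape-then-transpose here suggests), a toy version of that layout:

    import numpy as np

    # toy stand-in for fold(x, n_group) followed by the transpose:
    # a length-T signal becomes an (h, T//h) grid, with row k holding
    # samples k, k + h, k + 2h, ...
    T, h = 12, 4
    x = np.arange(T)
    folded = x.reshape(T // h, h).T    # shape (h, T//h)
    print(folded)
    # [[ 0  4  8]
    #  [ 1  5  9]
    #  [ 2  6 10]
    #  [ 3  7 11]]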
@@ -249,5 +353,26 @@ class WaveFlow(nn.LayerList):
         z = paddle.squeeze(x, 1)
         return z, logs_list
 
+    def start_sequence(self):
+        for layer in self:
+            layer.start_sequence()
+
+    def inverse(self, z, condition):
+        self.start_sequence()
+
+        z, condition = self._trim(z, condition)
+        # to (B, C, h, T//h) layout
+        z = paddle.unsqueeze(paddle.transpose(fold(z, self.n_group), [0, 2, 1]), 1)
+        condition = paddle.transpose(fold(condition, self.n_group), [0, 1, 3, 2])
+
+        # reverse it flow by flow
+        for i in reversed(range(self.n_flows)):
+            z = geo.shuffle_dim(z, 2, perm=self.perms[i])
+            condition = geo.shuffle_dim(condition, 2, perm=self.perms[i])
+            z, (logs, b) = self[i].inverse(z, condition)
+        x = paddle.squeeze(z, 1)
+        return x
 
-# TODO(chenfeiyu): WaveFlowLoss
+
+# TODO(chenfeiyu): WaveFlowLoss
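Note: putting the pieces together, inference with the new API would look roughly like this (a sketch mirroring the tests below; the import path is assumed, it is not shown in this diff):

    import paddle
    from parakeet.models import waveflow  # assumed import path

    # n_flows=2, n_layers=8, n_group=8, channels=16, mel_bands=7, kernel=(3, 3)
    net = waveflow.WaveFlow(2, 8, 8, 16, 7, (3, 3))
    net.eval()  # start_sequence() raises in training mode

    z = paddle.randn([4, 32 * 8])              # latent noise, (B, T)
    condition = paddle.randn([4, 7, 32 * 8])   # upsampled condition, (B, C, T)
    with paddle.no_grad():
        x = net.inverse(z, condition)          # folded audio, shape (4, 8, 32)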
@@ -35,6 +35,18 @@ class TestResidualBlock(unittest.TestCase):
         res, skip = net(x, condition)
         self.assertTupleEqual(res.numpy().shape, (4, 4, 16, 32))
         self.assertTupleEqual(skip.numpy().shape, (4, 4, 16, 32))
 
+    def test_add_input(self):
+        net = waveflow.ResidualBlock(4, 6, (3, 3), (2, 2))
+        net.eval()
+        net.start_sequence()
+
+        x_row = paddle.randn([4, 4, 1, 32])
+        condition_row = paddle.randn([4, 6, 1, 32])
+
+        res, skip = net.add_input(x_row, condition_row)
+        self.assertTupleEqual(res.numpy().shape, (4, 4, 1, 32))
+        self.assertTupleEqual(skip.numpy().shape, (4, 4, 1, 32))
+
 
 class TestResidualNet(unittest.TestCase):
@@ -44,21 +56,79 @@ class TestResidualNet(unittest.TestCase):
         condition = paddle.randn([4, 8, 8, 32])
         y = net(x, condition)
         self.assertTupleEqual(y.numpy().shape, (4, 6, 8, 32))
 
+    def test_add_input(self):
+        net = waveflow.ResidualNet(8, 6, 8, (3, 3), [1, 1, 1, 1, 1, 1, 1, 1])
+        net.eval()
+        net.start_sequence()
+
+        x_row = paddle.randn([4, 6, 1, 32])
+        condition_row = paddle.randn([4, 8, 1, 32])
+
+        y_row = net.add_input(x_row, condition_row)
+        self.assertTupleEqual(y_row.numpy().shape, (4, 6, 1, 32))
+
 
 class TestFlow(unittest.TestCase):
     def test_io(self):
+        net = waveflow.Flow(8, 16, 7, (3, 3), 8)
+
         x = paddle.randn([4, 1, 8, 32])
         condition = paddle.randn([4, 7, 8, 32])
-        net = waveflow.Flow(8, 16, 7, (3, 3), 8)
-        y = net(x, condition)
-        self.assertTupleEqual(y.numpy().shape, (4, 2, 8, 32))
+        z, (logs, b) = net(x, condition)
+        self.assertTupleEqual(z.numpy().shape, (4, 1, 8, 32))
+        self.assertTupleEqual(logs.numpy().shape, (4, 1, 7, 32))
+        self.assertTupleEqual(b.numpy().shape, (4, 1, 7, 32))
+
+    def test_inverse_row(self):
+        net = waveflow.Flow(8, 16, 7, (3, 3), 8)
+        net.eval()
+        net.start_sequence()
+
+        x_row = paddle.randn([4, 1, 1, 32])  # last row
+        condition_row = paddle.randn([4, 7, 1, 32])
+        z_row = paddle.randn([4, 1, 1, 32])
+        x_next_row, (logs, b) = net._inverse_row(z_row, x_row, condition_row)
+
+        self.assertTupleEqual(x_next_row.numpy().shape, (4, 1, 1, 32))
+        self.assertTupleEqual(logs.numpy().shape, (4, 1, 1, 32))
+        self.assertTupleEqual(b.numpy().shape, (4, 1, 1, 32))
+
+    def test_inverse(self):
+        net = waveflow.Flow(8, 16, 7, (3, 3), 8)
+        net.eval()
+        net.start_sequence()
+
+        z = paddle.randn([4, 1, 8, 32])
+        condition = paddle.randn([4, 7, 8, 32])
+
+        with paddle.no_grad():
+            x, (logs, b) = net.inverse(z, condition)
+        self.assertTupleEqual(x.numpy().shape, (4, 1, 8, 32))
+        self.assertTupleEqual(logs.numpy().shape, (4, 1, 7, 32))
+        self.assertTupleEqual(b.numpy().shape, (4, 1, 7, 32))
 
 
-class TestWaveflow(unittest.TestCase):
+class TestWaveFlow(unittest.TestCase):
     def test_io(self):
         x = paddle.randn([4, 32 * 8])
         condition = paddle.randn([4, 7, 32 * 8])
         net = waveflow.WaveFlow(2, 8, 8, 16, 7, (3, 3))
-        z, logs = net(x, condition)
+        z, logs_list = net(x, condition)
+
+        self.assertTupleEqual(z.numpy().shape, (4, 8, 32))
+        self.assertTupleEqual(logs_list[0].numpy().shape, (4, 1, 7, 32))
+
+    def test_inverse(self):
+        z = paddle.randn([4, 32 * 8])
+        condition = paddle.randn([4, 7, 32 * 8])
+
+        net = waveflow.WaveFlow(2, 8, 8, 16, 7, (3, 3))
+        net.eval()
+
+        with paddle.no_grad():
+            x = net.inverse(z, condition)
+        self.assertTupleEqual(x.numpy().shape, (4, 8, 32))