update waveflow to 2.0 APIs

chenfeiyu 2020-11-04 01:37:49 +08:00
parent 0cdad602e2
commit 8094578f6d
1 changed file with 55 additions and 32 deletions


@@ -29,6 +29,11 @@ def fold(x, n_group):
return paddle.reshape(x, new_shape)
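Note: judging from the reshape above, fold only folds the trailing time axis into rows of length n_group. A minimal standalone sketch of that behavior (numpy stands in for paddle here):

    import numpy as np

    def fold_np(x, n_group):
        # (B, T) -> (B, T // n_group, n_group); T must be divisible by n_group
        *batch, t = x.shape
        return x.reshape(*batch, t // n_group, n_group)

    x = np.arange(16).reshape(1, 16)    # B=1, T=16
    print(fold_np(x, n_group=8).shape)  # (1, 2, 8)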
class UpsampleNet(nn.LayerList):
"""
Layer to upsample mel spectrogram to the same temporal resolution with
the corresponding waveform. It consists of several conv2dtranspose layers
which perform de convolution on mel and time dimension.
"""
def __init__(self, upsample_factors: Sequence[int]):
super(UpsampleNet, self).__init__()
for factor in upsample_factors:
@@ -36,7 +41,7 @@ class UpsampleNet(nn.LayerList):
init = I.Uniform(-std, std)
self.append(
nn.utils.weight_norm(
- nn.ConvTranspose2d(1, 1, (3, 2 * factor),
+ nn.Conv2DTranspose(1, 1, (3, 2 * factor),
padding=(1, factor // 2),
stride=(1, factor),
weight_attr=init,
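Note: with kernel width 2 * factor, stride factor, and padding factor // 2, the standard transposed-convolution length formula (W - 1) * stride - 2 * pad + kernel gives exactly W * factor when factor is even (odd factors would leave an off-by-one), so each layer upsamples the time axis by its factor and the stack by the product of upsample_factors. A quick arithmetic check:

    def convtranspose_out_len(w, factor):
        # (W - 1) * stride - 2 * pad + kernel, with the arguments used above
        return (w - 1) * factor - 2 * (factor // 2) + 2 * factor

    for factor in (2, 4, 16):
        print(convtranspose_out_len(100, factor))  # 200, 400, 1600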
@@ -71,39 +76,42 @@ class UpsampleNet(nn.LayerList):
class ResidualBlock(nn.Layer):
"""
ResidualBlock that merges information from the condition and outputs residual
and skip. It has a Conv2D layer with causal padding in the height dimension
and same padding in the width dimension.
"""
def __init__(self, channels, cond_channels, kernel_size, dilations):
super(ResidualBlock, self).__init__()
# input conv
std = math.sqrt(1 / channels * np.prod(kernel_size))
init = I.Uniform(-std, std)
- conv = nn.Conv2d(channels, 2 * channels, kernel_size, dilation=dilations,
-                  weight_attr=init, bias_attr=init)
+ receptive_field = [1 + (k - 1) * d for (k, d) in zip(kernel_size, dilations)]
+ rh, rw = receptive_field
+ paddings = [rh - 1, 0, rw // 2, (rw - 1) // 2]  # causal & same
+ conv = nn.Conv2D(channels, 2 * channels, kernel_size,
+                  padding=paddings,
+                  dilation=dilations,
+                  weight_attr=init,
+                  bias_attr=init)
self.conv = nn.utils.weight_norm(conv)
# condition projection
std = math.sqrt(1 / cond_channels)
init = I.Uniform(-std, std)
- condition_proj = nn.Conv2d(cond_channels, 2 * channels, (1, 1),
+ condition_proj = nn.Conv2D(cond_channels, 2 * channels, (1, 1),
weight_attr=init, bias_attr=init)
self.condition_proj = nn.utils.weight_norm(condition_proj)
# parametric residual & skip connection
std = math.sqrt(1 / channels)
init = I.Uniform(-std, std)
- out_proj = nn.Conv2d(channels, 2 * channels, (1, 1),
-                      weight_attr=init, bias_attr=init)
+ out_proj = nn.Conv2D(channels, 2 * channels, (1, 1),
+                      weight_attr=init, bias_attr=init)
self.out_proj = nn.utils.weight_norm(out_proj)
# specs
self.kernel_size = self.conv._kernel_size
self.dilations = self.conv._dilation
def forward(self, x, condition):
- receptive_field = tuple(
-     [1 + (k -1) * d for (k, d) in zip(self.kernel_size, self.dilations)])
- rh, rw = receptive_field
- paddings = (rh - 1, 0, (rw - 1) // 2, (rw - 1) // 2)
- x = self.conv(F.pad2d(x, paddings))
+ x = self.conv(x)
x += self.condition_proj(condition)
content, gate = paddle.chunk(x, 2, axis=1)
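Note: the new padding argument bakes the old F.pad2d call into the convolution itself. A quick check that [rh - 1, 0] keeps the height unchanged while seeing only past rows (causal), and [rw // 2, (rw - 1) // 2] keeps the width unchanged (same), where rh and rw are the dilated kernel extents:

    def out_len(L, r, pad_a, pad_b):
        # stride-1 convolution: L + pad_total - r + 1, with r = 1 + (k - 1) * d
        return L + pad_a + pad_b - r + 1

    (kh, kw), (dh, dw) = (3, 3), (2, 1)
    rh, rw = 1 + (kh - 1) * dh, 1 + (kw - 1) * dw
    print(out_len(80, rh, rh - 1, 0))                # 80: causal in height
    print(out_len(100, rw, rw // 2, (rw - 1) // 2))  # 100: same in width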
@@ -112,9 +120,13 @@ class ResidualBlock(nn.Layer):
x = self.out_proj(x)
res, skip = paddle.chunk(x, 2, axis=1)
return res, skip
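Note: the lines elided between the chunk and out_proj presumably apply the WaveNet-style gated activation (an assumption here, not shown in this hunk). A sketch of that unit, with numpy standing in for paddle:

    import numpy as np

    def gated_unit(x):
        # first half modulates content (tanh), second half gates it (sigmoid)
        content, gate = np.split(x, 2, axis=1)
        return np.tanh(content) * (1.0 / (1.0 + np.exp(-gate)))

    x = np.random.randn(2, 8, 4, 4)  # (B, 2 * channels, h, w)
    print(gated_unit(x).shape)       # (2, 4, 4, 4)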
class ResidualNet(nn.LayerList):
"""
A stack of several ResidualBlocks. It merges condition at each layer. All
skip outputs are collected.
"""
def __init__(self, n_layer, residual_channels, condition_channels, kernel_size, dilations_h):
if len(dilations_h) != n_layer:
raise ValueError("number of dilations_h should equals num of layers")
@@ -131,9 +143,13 @@ class ResidualNet(nn.LayerList):
skip_connections.append(skip)
out = paddle.sum(paddle.stack(skip_connections, 0), 0)
return out
class Flow(nn.Layer):
"""
A Layer that merges the condition and predict scale and bias given a random
variable X.
"""
dilations_dict = {
8: [1, 1, 1, 1, 1, 1, 1, 1],
16: [1, 1, 1, 1, 1, 1, 1, 1],
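Note: dilations_dict maps n_group (the folded height) to the per-layer height dilations; its entries are truncated by the hunk. For the values shown, eight layers of dilation 1 with a hypothetical kernel height of 3 already give a causal receptive field of 17 rows along the group axis:

    def height_receptive_field(kernel_h, dilations_h):
        # each causal layer adds (kernel_h - 1) * dilation rows of context
        return 1 + sum((kernel_h - 1) * d for d in dilations_h)

    print(height_receptive_field(3, [1] * 8))  # 17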
@@ -146,7 +162,7 @@ class Flow(nn.Layer):
super(Flow, self).__init__()
# input projection
self.first_conv = nn.utils.weight_norm(
- nn.Conv2d(1, channels, (1, 1),
+ nn.Conv2D(1, channels, (1, 1),
weight_attr=I.Uniform(-1., 1.),
bias_attr=I.Uniform(-1., 1.)))
@@ -156,11 +172,12 @@
# output projection
self.last_conv = nn.utils.weight_norm(
- nn.Conv2d(channels, 2, (1, 1),
+ nn.Conv2D(channels, 2, (1, 1),
weight_attr=I.Constant(0.),
bias_attr=I.Constant(0.)))
def forward(self, x, condition):
# TODO(chenfeiyu): it would be better to implement the transformation here
return self.last_conv(self.resnet(self.first_conv(x), condition))
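Note: zero-initializing last_conv makes the predicted (logs, b) start at zero, so every flow begins as the identity map. The transformation the TODO refers to (applied in WaveFlow.forward below) is elementwise affine; a sketch of it and its inverse, with numpy standing in for paddle:

    import numpy as np

    def affine_forward(x, logs, b):
        return x * np.exp(logs) + b      # z = x * exp(logs) + b

    def affine_inverse(z, logs, b):
        return (z - b) * np.exp(-logs)   # recovers x exactly

    x, logs, b = np.random.randn(3), np.random.randn(3), np.random.randn(3)
    print(np.allclose(affine_inverse(affine_forward(x, logs, b), logs, b), x))  # True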
@@ -170,23 +187,29 @@ class WaveFlow(nn.LayerList):
raise ValueError("number of flows and number of group must be even "
"since a permutation along group among flows is used.")
super(WaveFlow, self).__init__()
- for i in range(n_flows):
+ for _ in range(n_flows):
self.append(Flow(n_layers, channels, mel_bands, kernel_size, n_group))
+ # permutations in h
+ self.perms = self._create_perm(n_group, n_flows)
+ # specs
+ self.n_group = n_group
+ self.n_flows = n_flows
def _create_perm(self, n_group, n_flows):
indices = list(range(n_group))
half = n_group // 2
- self.perms = []
+ perms = []
for i in range(n_flows):
if i < n_flows // 2:
- self.perms.append(indices[::-1])
+ perms.append(indices[::-1])
else:
perm = list(reversed(indices[:half])) + list(reversed(indices[half:]))
- self.perms.append(perm)
- self.n_group = n_group
+ perms.append(perm)
+ return perms
- def trim(self, x, condition):
+ def _trim(self, x, condition):
assert condition.shape[-1] >= x.shape[-1]
pruned_len = int(x.shape[-1] // self.n_group * self.n_group)
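Note: a standalone version of _create_perm makes the scheme easier to see: the first half of the flows reverse the whole group axis, the second half reverse each half independently, so rows keep mixing across flows:

    def create_perm(n_group, n_flows):
        indices = list(range(n_group))
        half = n_group // 2
        perms = []
        for i in range(n_flows):
            if i < n_flows // 2:
                perms.append(indices[::-1])
            else:
                perms.append(list(reversed(indices[:half])) +
                             list(reversed(indices[half:])))
        return perms

    print(create_perm(8, 4))
    # [[7, 6, 5, 4, 3, 2, 1, 0], [7, 6, 5, 4, 3, 2, 1, 0],
    #  [3, 2, 1, 0, 7, 6, 5, 4], [3, 2, 1, 0, 7, 6, 5, 4]]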
@@ -199,9 +222,9 @@ class WaveFlow(nn.LayerList):
def forward(self, x, condition):
# x: (B, T)
# condition: (B, C, T) upsampled condition
- x, condition = self.trim(x, condition)
+ x, condition = self._trim(x, condition)
- # transpose to (B, C, h, T //h) layout
+ # to (B, C, h, T //h) layout
x = paddle.unsqueeze(paddle.transpose(fold(x, self.n_group), [0, 2, 1]), 1)
condition = paddle.transpose(fold(condition, self.n_group), [0, 1, 3, 2])
@@ -214,15 +237,15 @@ class WaveFlow(nn.LayerList):
output = layer(input, cond)
logs, b = paddle.chunk(output, 2, axis=1)
logs_list.append(logs)
- x_0 = x[:, :, :1, :] # the first row, just copy
+ # the transformation
+ x_0 = x[:, :, :1, :] # the first row, just copy it
x_out = x[:, :, 1:, :] * paddle.exp(logs) + b
x = paddle.concat([x_0, x_out], axis=2)
# permute along the group dim; paddle has no dim-shuffle op, so shuffle_dim is used
x = geo.shuffle_dim(x, 2, perm=self.perms[i])
condition = geo.shuffle_dim(condition, 2, perm=self.perms[i])
z = paddle.squeeze(x, 1)
return z, logs_list
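Note: the returned z and logs_list would typically feed a flow likelihood computed outside this file. A hedged sketch under a zero-mean Gaussian prior (sigma and the per-element averaging are assumptions, not taken from this diff; additive constants omitted):

    import numpy as np

    def waveflow_nll(z, logs_list, sigma=1.0):
        # negative log-likelihood: prior term minus the log-determinant
        log_det = sum(float(np.sum(logs)) for logs in logs_list)
        prior = 0.5 * float(np.sum(z ** 2)) / sigma ** 2
        return (prior - log_det) / z.size  # per element, up to constants

    z = np.random.randn(2, 8, 100)                         # (B, h, T // h)
    logs_list = [np.zeros((2, 1, 7, 100)) for _ in range(8)]
    print(waveflow_nll(z, logs_list))                      # roughly 0.5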