update waveflow to 2.0 APIs
parent 0cdad602e2
commit 8094578f6d
@@ -29,6 +29,11 @@ def fold(x, n_group):
     return paddle.reshape(x, new_shape)
 
 
 class UpsampleNet(nn.LayerList):
+    """
+    Layer to upsample a mel spectrogram to the same temporal resolution as
+    the corresponding waveform. It consists of several Conv2DTranspose layers,
+    each performing deconvolution over the mel and time dimensions.
+    """
     def __init__(self, upsample_factors: Sequence[int]):
         super(UpsampleNet, self).__init__()
         for factor in upsample_factors:
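Note: a quick shape check, not part of the commit, showing how a stack of Conv2DTranspose layers with these kernel/stride/padding settings stretches only the time axis. The upsample factors, batch size, and mel dimensions below are assumed for illustration:

    import paddle
    import paddle.nn as nn

    upsample_factors = (16, 16)   # hypothetical; the product should match the hop length
    x = paddle.unsqueeze(paddle.randn([4, 80, 50]), 1)   # (B, 1, n_mels, T_mel)
    for factor in upsample_factors:
        deconv = nn.Conv2DTranspose(1, 1, (3, 2 * factor),
                                    padding=(1, factor // 2),
                                    stride=(1, factor))
        x = deconv(x)
    print(x.shape)  # [4, 1, 80, 12800]: time stretched by 16 * 16, mel axis unchanged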
@@ -36,7 +41,7 @@ class UpsampleNet(nn.LayerList):
             init = I.Uniform(-std, std)
             self.append(
                 nn.utils.weight_norm(
-                    nn.ConvTranspose2d(1, 1, (3, 2 * factor),
+                    nn.Conv2DTranspose(1, 1, (3, 2 * factor),
                         padding=(1, factor // 2),
                         stride=(1, factor),
                         weight_attr=init,
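Note: aside from the rename, this hunk is unchanged. As far as I can tell this reflects the Paddle 2.0 class renaming, where the dimensionality suffix is capitalized; the mapping relevant to this commit is:

    # pre-2.0 name            Paddle 2.0 name
    # nn.Conv2d           ->  nn.Conv2D
    # nn.ConvTranspose2d  ->  nn.Conv2DTranspose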
@@ -71,39 +76,42 @@ class UpsampleNet(nn.LayerList):
 
 
 class ResidualBlock(nn.Layer):
     """
     ResidualBlock that merges information from the condition and outputs residual
     and skip. It has a Conv2D layer with causal padding in the height dimension
     and same padding in the width dimension.
     """
     def __init__(self, channels, cond_channels, kernel_size, dilations):
         super(ResidualBlock, self).__init__()
         # input conv
         std = math.sqrt(1 / channels * np.prod(kernel_size))
         init = I.Uniform(-std, std)
-        conv = nn.Conv2d(channels, 2 * channels, kernel_size, dilation=dilations,
-                         weight_attr=init, bias_attr=init)
+        receptive_field = [1 + (k - 1) * d for (k, d) in zip(kernel_size, dilations)]
+        rh, rw = receptive_field
+        paddings = [rh - 1, 0, rw // 2, (rw - 1) // 2]  # causal & same
+        conv = nn.Conv2D(channels, 2 * channels, kernel_size,
+                         padding=paddings,
+                         dilation=dilations,
+                         weight_attr=init,
+                         bias_attr=init)
         self.conv = nn.utils.weight_norm(conv)
 
         # condition projection
         std = math.sqrt(1 / cond_channels)
         init = I.Uniform(-std, std)
-        condition_proj = nn.Conv2d(cond_channels, 2 * channels, (1, 1),
+        condition_proj = nn.Conv2D(cond_channels, 2 * channels, (1, 1),
                                    weight_attr=init, bias_attr=init)
         self.condition_proj = nn.utils.weight_norm(condition_proj)
 
         # parametric residual & skip connection
         std = math.sqrt(1 / channels)
         init = I.Uniform(-std, std)
-        out_proj = nn.Conv2d(channels, 2 * channels, (1, 1),
-                             weight_attr=init, bias_attr=init)
+        out_proj = nn.Conv2D(channels, 2 * channels, (1, 1),
+                             weight_attr=init, bias_attr=init)
         self.out_proj = nn.utils.weight_norm(out_proj)
 
         # specs
         self.kernel_size = self.conv._kernel_size
         self.dilations = self.conv._dilation
 
     def forward(self, x, condition):
-        receptive_field = tuple(
-            [1 + (k - 1) * d for (k, d) in zip(self.kernel_size, self.dilations)])
-        rh, rw = receptive_field
-        paddings = (rh - 1, 0, (rw - 1) // 2, (rw - 1) // 2)
-        x = self.conv(F.pad2d(x, paddings))
+        x = self.conv(x)
         x += self.condition_proj(condition)
 
         content, gate = paddle.chunk(x, 2, axis=1)
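Note: a minimal check, illustrative only, that the padding precomputed in __init__ is causal along height and "same" along width, which is why forward no longer needs F.pad2d. The kernel size and dilations are assumed; if I read the Conv2D docs right, a 4-element padding list means [top, bottom, left, right]:

    kernel_size = (3, 3)
    dilations = (2, 1)
    rh, rw = [1 + (k - 1) * d for (k, d) in zip(kernel_size, dilations)]
    paddings = [rh - 1, 0, rw // 2, (rw - 1) // 2]
    assert paddings == [4, 0, 1, 1]
    # height: H + (rh - 1) + 0 - (rh - 1) = H, and all padding sits on top,
    # so output row t depends only on input rows <= t (causal).
    # width:  W + 1 + 1 - (rw - 1) = W ("same" padding).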
@@ -112,9 +120,13 @@ class ResidualBlock(nn.Layer):
         x = self.out_proj(x)
         res, skip = paddle.chunk(x, 2, axis=1)
         return res, skip
 
 
 class ResidualNet(nn.LayerList):
+    """
+    A stack of several ResidualBlocks. It merges the condition at each layer.
+    All skip outputs are collected.
+    """
     def __init__(self, n_layer, residual_channels, condition_channels, kernel_size, dilations_h):
         if len(dilations_h) != n_layer:
             raise ValueError("number of dilations_h should equal the number of layers")
@@ -131,9 +143,13 @@ class ResidualNet(nn.LayerList):
             skip_connections.append(skip)
         out = paddle.sum(paddle.stack(skip_connections, 0), 0)
         return out
 
 
 class Flow(nn.Layer):
+    """
+    A Layer that merges the condition and predicts a scale and bias given a
+    random variable x.
+    """
     dilations_dict = {
         8: [1, 1, 1, 1, 1, 1, 1, 1],
         16: [1, 1, 1, 1, 1, 1, 1, 1],
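Note: the loop head is elided in the hunk above; a minimal sketch, assumed but consistent with the visible lines, of how ResidualNet.forward threads the residual through the stack and sums the collected skips:

    import paddle

    def residual_net_forward(layers, x, condition):
        # each block returns (res, skip); the residual feeds the next block
        skip_connections = []
        for layer in layers:
            x, skip = layer(x, condition)
            skip_connections.append(skip)
        # sum all skip outputs elementwise
        out = paddle.sum(paddle.stack(skip_connections, 0), 0)
        return out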
@@ -146,7 +162,7 @@ class Flow(nn.Layer):
         super(Flow, self).__init__()
         # input projection
         self.first_conv = nn.utils.weight_norm(
-            nn.Conv2d(1, channels, (1, 1),
+            nn.Conv2D(1, channels, (1, 1),
                       weight_attr=I.Uniform(-1., 1.),
                       bias_attr=I.Uniform(-1., 1.)))
@@ -156,11 +172,12 @@ class Flow(nn.Layer):
 
         # output projection
         self.last_conv = nn.utils.weight_norm(
-            nn.Conv2d(channels, 2, (1, 1),
+            nn.Conv2D(channels, 2, (1, 1),
                       weight_attr=I.Constant(0.),
                       bias_attr=I.Constant(0.)))
 
     def forward(self, x, condition):
+        # TODO(chenfeiyu): it is better to implement the transformation here
         return self.last_conv(self.resnet(self.first_conv(x), condition))
@@ -170,23 +187,29 @@ class WaveFlow(nn.LayerList):
             raise ValueError("number of flows and number of group must be even "
                              "since a permutation along group among flows is used.")
         super(WaveFlow, self).__init__()
-        for i in range(n_flows):
+        for _ in range(n_flows):
             self.append(Flow(n_layers, channels, mel_bands, kernel_size, n_group))
 
+        # permutations in h
+        self.perms = self._create_perm(n_group, n_flows)
+
         # specs
         self.n_group = n_group
         self.n_flows = n_flows
 
+    def _create_perm(self, n_group, n_flows):
         indices = list(range(n_group))
         half = n_group // 2
-        self.perms = []
+        perms = []
         for i in range(n_flows):
             if i < n_flows // 2:
-                self.perms.append(indices[::-1])
+                perms.append(indices[::-1])
             else:
                 perm = list(reversed(indices[:half])) + list(reversed(indices[half:]))
-                self.perms.append(perm)
-
-        self.n_group = n_group
+                perms.append(perm)
+        return perms
 
-    def trim(self, x, condition):
+    def _trim(self, x, condition):
         assert condition.shape[-1] >= x.shape[-1]
         pruned_len = int(x.shape[-1] // self.n_group * self.n_group)
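Note: what _create_perm produces, traced by hand for n_group=8 and n_flows=4 (values assumed): the first half of the flows reverse the whole group axis, while the later flows reverse each half independently, so rows are mixed differently across flows:

    indices = list(range(8))
    half = 8 // 2
    perms = []
    for i in range(4):
        if i < 4 // 2:
            perms.append(indices[::-1])
        else:
            perms.append(list(reversed(indices[:half])) + list(reversed(indices[half:])))
    print(perms[0])  # [7, 6, 5, 4, 3, 2, 1, 0]
    print(perms[2])  # [3, 2, 1, 0, 7, 6, 5, 4]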
@@ -199,9 +222,9 @@ class WaveFlow(nn.LayerList):
     def forward(self, x, condition):
         # x: (B, T)
         # condition: (B, C, T) upsampled condition
-        x, condition = self.trim(x, condition)
+        x, condition = self._trim(x, condition)
 
-        # transpose to (B, C, h, T // h) layout
+        # to (B, C, h, T // h) layout
         x = paddle.unsqueeze(paddle.transpose(fold(x, self.n_group), [0, 2, 1]), 1)
         condition = paddle.transpose(fold(condition, self.n_group), [0, 1, 3, 2])
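Note: a shape walk-through of this re-layout with toy sizes. fold's reshape target is not shown in the diff; it is assumed here to split the last axis into (T // n_group, n_group), which is consistent with the transposes above:

    import paddle
    B, C, T, h = 2, 80, 64, 8
    x = paddle.reshape(paddle.randn([B, T]), [B, T // h, h])            # fold(x, h), assumed
    x = paddle.unsqueeze(paddle.transpose(x, [0, 2, 1]), 1)             # (B, 1, h, T // h)
    cond = paddle.reshape(paddle.randn([B, C, T]), [B, C, T // h, h])   # fold(cond, h), assumed
    cond = paddle.transpose(cond, [0, 1, 3, 2])                         # (B, C, h, T // h)
    print(x.shape, cond.shape)  # [2, 1, 8, 8] [2, 80, 8, 8]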
@@ -214,15 +237,15 @@ class WaveFlow(nn.LayerList):
             output = layer(input, cond)
             logs, b = paddle.chunk(output, 2, axis=1)
             logs_list.append(logs)
 
-            x_0 = x[:, :, :1, :]  # the first row, just copy
+            # the transformation
+            x_0 = x[:, :, :1, :]  # the first row, just copy it
             x_out = x[:, :, 1:, :] * paddle.exp(logs) + b
             x = paddle.concat([x_0, x_out], axis=2)
 
             # permute (paddle has no shuffle-dim primitive)
             x = geo.shuffle_dim(x, 2, perm=self.perms[i])
             condition = geo.shuffle_dim(condition, 2, perm=self.perms[i])
 
         z = paddle.squeeze(x, 1)
         return z, logs_list
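Note: the per-flow transform z = x * exp(logs) + b (applied to rows 1..h-1, with row 0 copied) is affine and therefore exactly invertible, which is what synthesis relies on. A toy numeric check with fixed logs and b; in the real model they are predicted from the previous rows, so inversion proceeds row by row:

    import paddle
    x = paddle.randn([2, 1, 4, 10])
    logs = paddle.randn([2, 1, 3, 10])  # predicted for rows 1..3
    b = paddle.randn([2, 1, 3, 10])
    z = paddle.concat([x[:, :, :1, :],
                       x[:, :, 1:, :] * paddle.exp(logs) + b], axis=2)
    x_rec = paddle.concat([z[:, :, :1, :],
                           (z[:, :, 1:, :] - b) * paddle.exp(-logs)], axis=2)
    print(float(paddle.abs(x - x_rec).max()))  # ~0, up to float error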