WIP: update HiFi-GAN

This commit is contained in:
iclementine 2021-04-15 17:23:42 +08:00
parent 68497f89a4
commit e06c6cdfe1
1 changed files with 98 additions and 26 deletions

View File

@ -9,7 +9,10 @@ from itertools import chain
class ResidualBlock1(nn.Layer):
def __init__(self, channels, kernel_size=3, dilations=(1, 3, 5)):
"""number of dilations defines the number of layers of convolutions
in the residual block"""
super().__init__()
# convolutions with dilation
self.convs1 = nn.LayerList([
weight_norm(
nn.Conv1D(channels,
@ -30,6 +33,7 @@ class ResidualBlock1(nn.Layer):
dilation=dilations[2],
padding="same"))
])
# convolutions without dilation
self.convs2 = nn.LayerList([
weight_norm(
nn.Conv1D(channels, channels, kernel_size, padding="same")),
@ -54,6 +58,7 @@ class ResidualBlock1(nn.Layer):
class ResidualBlock2(nn.Layer):
"""A simpler alternative to ResidualBlock1."""
def __init__(self, channels, kernel_size=3, dilations=(1, 3)):
super().__init__()
self.convs = nn.LayerList([
@ -75,11 +80,11 @@ class ResidualBlock2(nn.Layer):
class WavGenerator(nn.Layer):
def __init__(self,
input_size,
upsample_init_channels,
supsample_rates,
resblock_kernel_sizes,
resblock_dilation_sizes,
input_size: int,
upsample_init_channels: int,
upsample_rates: List[int],
resblock_kernel_sizes: List[int],
resblock_dilation_sizes: List[List[int]],
resblock_type="1"):
super().__init__()
self.num_upsample_layers = len(upsample_rates)
@ -89,6 +94,8 @@ class WavGenerator(nn.Layer):
self.conv_pre = weight_norm(nn.Conv1D(input_size, upsample_init_channels, 7, padding="same"))
resblock = ResidualBlock1 if resblock_type == "1" else ResidualBlock2
# Upsampling convtranspose
self.ups = nn.LayerList()
for i, u in enumerate(upsample_rates):
self.ups.append(
@ -97,7 +104,7 @@ class WavGenerator(nn.Layer):
2*u,
u,
padding=u // 2)))
# Multi-Receptive Field Fusion (MRF)
self.resblocks = nn.LayerList()
for i in range(num_upsample_layers):
ch = upsample_init_channels // (2**(i+1))
@ -135,24 +142,26 @@ class WavGenerator(nn.Layer):
class DiscriminatorP(nn.Layer):
def __init__(self, period, kernel_size=5, stride=3):
def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
super().__init__()
self.period = period
norm_fn = spectral_norm if use_spectral_norm else weight_norm
self.convs = nn.LayerList([
weight_norm(nn.Conv2D(1, 32, [kernel_size, 1], [stride, 1], padding=[(kernel_size -1) // 2, 0])),
weight_norm(nn.Conv2D(32, 128, [kernel_size, 1], [stride, 1], padding=[(kernel_size -1) // 2, 0])),
weight_norm(nn.Conv2D(128, 512, [kernel_size, 1], [stride, 1], padding=[(kernel_size -1) // 2, 0])),
weight_norm(nn.Conv2D(512, 1024, [kernel_size, 1], [stride, 1], padding=[(kernel_size -1) // 2, 0])),
weight_norm(nn.Conv2D(1024, 1024, [kernel_size, 1], [stride, 1], padding=[(kernel_size -1) // 2, 0])),
norm_fn(nn.Conv2D(1, 32, [kernel_size, 1], [stride, 1], padding=[(kernel_size -1) // 2, 0])),
norm_fn(nn.Conv2D(32, 128, [kernel_size, 1], [stride, 1], padding=[(kernel_size -1) // 2, 0])),
norm_fn(nn.Conv2D(128, 512, [kernel_size, 1], [stride, 1], padding=[(kernel_size -1) // 2, 0])),
norm_fn(nn.Conv2D(512, 1024, [kernel_size, 1], [stride, 1], padding=[(kernel_size -1) // 2, 0])),
norm_fn(nn.Conv2D(1024, 1024, [kernel_size, 1], [stride, 1], padding=[(kernel_size -1) // 2, 0])),
])
self.conv_post = weight_norm(nn.Conv2D(1024, 1, kernel_size=[3, 1], padding=[1, 0]))
def forward(seld, x):
def forward(self, x):
# (B, 1, T) -> (B, T), [multiple (B, C, T/p, p)] time scale shrinks
fmap = []
b, c, t = x.shape
if t % self.period != 0:
n_pad = self.period - (t % self.period)
x = F.pad(x, [0, n_pad], mode='reflect')
x = F.pad(x, [0, n_pad], mode='reflect', data_format="NCL")
t += n_pad
x = paddle.reshape(x, [b, c, t // self.period, self.period])
@ -192,16 +201,17 @@ class MultiPeriodDiscriminator(nn.Layer):
class DiscriminatorS(nn.Layer):
def __init__(self):
def __init__(self, use_spectral_norm=False):
super().__init__()
norm_fn = spectral_norm if use_spectral_norm else weight_norm
self.convs = nn.LayerList([
weight_norm(nn.Conv1D(1, 128, 15, 1, padding='same')),
weight_norm(nn.Conv1D(128, 128, 41, 2, groups=4, padding='same')),
weight_norm(nn.Conv1D(128, 256, 41, 2, groups=16, padding='same')),
weight_norm(nn.Conv1D(256, 512, 41, 4, groups=16, padding='same')),
weight_norm(nn.Conv1D(512, 1024, 41, 4, groups=16, padding='same')),
weight_norm(nn.Conv1D(1024, 1024, 41, 1, groups=16, padding='same')),
weight_norm(nn.Conv1D(1024, 1024, 5, 1, padding='same'))])
norm_fn(nn.Conv1D(1, 128, 15, 1, padding='same')),
norm_fn(nn.Conv1D(128, 128, 41, 2, groups=4, padding='same')),
norm_fn(nn.Conv1D(128, 256, 41, 2, groups=16, padding='same')),
norm_fn(nn.Conv1D(256, 512, 41, 4, groups=16, padding='same')),
norm_fn(nn.Conv1D(512, 1024, 41, 4, groups=16, padding='same')),
norm_fn(nn.Conv1D(1024, 1024, 41, 1, groups=16, padding='same')),
norm_fn(nn.Conv1D(1024, 1024, 5, 1, padding='same'))])
self.conv_post = weight_norm(nn.Conv1D(1024, 1, 3, 1, padding='same'))
def forward(self, x):
@ -217,12 +227,74 @@ class DiscriminatorS(nn.Layer):
return x, fmap
def MultiScaleDiscriminator(nn.Layer):
class MultiScaleDiscriminator(nn.Layer):
    """Run three ``DiscriminatorS`` instances on progressively pooled inputs.

    The first sub-discriminator is built with ``use_spectral_norm=True`` and
    sees the inputs unchanged; each later one sees the inputs after one more
    ``AvgPool1D(4, 2, padding=2)`` has been applied on top of the previous
    stage.
    """

    def __init__(self):
        super().__init__()
        # Construction order is kept as-is so sublayer/parameter naming stays
        # the same: one spectral-norm discriminator, then two default ones.
        self.discriminators = nn.LayerList([
            DiscriminatorS(use_spectral_norm=True),
            DiscriminatorS(),
            DiscriminatorS(),
        ])
        # One pooling layer per transition between consecutive discriminators.
        self.meanpools = nn.LayerList(
            [nn.AvgPool1D(4, 2, padding=2) for _ in range(2)])

    def forward(self, y, y_hat):
        """Score ``y`` and ``y_hat`` with every sub-discriminator.

        Both inputs are pooled cumulatively: the pooled tensors from one
        stage are the starting point for the next stage.
        # NOTE(review): presumably ``y`` is the real signal and ``y_hat`` the
        # generated one - confirm against the training loop.

        Returns
        -------
        tuple of four lists, each with one entry per sub-discriminator:
        outputs for ``y``, outputs for ``y_hat``, feature maps for ``y``,
        feature maps for ``y_hat``.
        """
        real_scores, fake_scores = [], []
        real_fmaps, fake_fmaps = [], []

        for idx, disc in enumerate(self.discriminators):
            if idx > 0:
                # Pool before every discriminator except the first one.
                pool = self.meanpools[idx - 1]
                y = pool(y)
                y_hat = pool(y_hat)

            score_r, feats_r = disc(y)
            score_g, feats_g = disc(y_hat)

            real_scores.append(score_r)
            real_fmaps.append(feats_r)
            fake_scores.append(score_g)
            fake_fmaps.append(feats_g)

        return real_scores, fake_scores, real_fmaps, fake_fmaps
def feature_loss(fmap_r, fmap_g):
    """Accumulate the mean absolute difference between paired feature maps.

    ``fmap_r`` and ``fmap_g`` are walked in lockstep at two levels: the outer
    level pairs the entries of the two arguments, the inner level pairs the
    items inside each entry. The mean of ``|real - generated|`` is summed over
    every inner pair and the sum is returned multiplied by 2.

    With no pairs to compare, the plain integer ``0`` is returned.
    """
    total = 0
    for real_group, fake_group in zip(fmap_r, fmap_g):
        for real_feat, fake_feat in zip(real_group, fake_group):
            abs_diff = paddle.abs(real_feat - fake_feat)
            total += paddle.mean(abs_diff)
    return total * 2
def discriminator_loss(disc_real_outputs, disc_generated_outputs):
    """Compute the discriminator loss over paired discriminator outputs.

    For each ``(real, generated)`` output pair the real term is
    ``mean((1 - real) ** 2)`` and the generated term is
    ``mean(generated ** 2)``.

    Returns
    -------
    tuple ``(loss, r_losses, g_losses)``: the sum of all real and generated
    terms, the list of per-pair real terms, and the list of per-pair
    generated terms. With no pairs, ``loss`` is the plain integer ``0``.
    """
    real_terms = []
    fake_terms = []
    total = 0
    for real_out, fake_out in zip(disc_real_outputs, disc_generated_outputs):
        real_term = paddle.mean((1 - real_out) ** 2)
        fake_term = paddle.mean(fake_out ** 2)
        real_terms.append(real_term)
        fake_terms.append(fake_term)
        total += (real_term + fake_term)
    return total, real_terms, fake_terms
def generator_loss(disc_outputs):
    """Compute the generator's adversarial loss from discriminator outputs.

    Each element of ``disc_outputs`` contributes ``mean((1 - output) ** 2)``.

    Returns
    -------
    tuple ``(loss, gen_losses)``: the sum of all per-output terms and the
    list of those terms in input order. With no outputs, ``loss`` is the
    plain integer ``0``.
    """
    per_disc_terms = []
    total = 0
    for fake_out in disc_outputs:
        term = paddle.mean((1 - fake_out) ** 2)
        per_disc_terms.append(term)
        total += term
    return total, per_disc_terms