WIP:update hifigan
This commit is contained in:
parent
68497f89a4
commit
e06c6cdfe1
|
@ -9,7 +9,10 @@ from itertools import chain
|
|||
|
||||
class ResidualBlock1(nn.Layer):
|
||||
def __init__(self, channels, kernel_size=3, dilations=(1, 3, 5)):
|
||||
"""number of dilations defines the number of layers of convolutions
|
||||
in the residual block"""
|
||||
super().__init__()
|
||||
# convolutions with dilation
|
||||
self.convs1 = nn.LayerList([
|
||||
weight_norm(
|
||||
nn.Conv1D(channels,
|
||||
|
@ -30,6 +33,7 @@ class ResidualBlock1(nn.Layer):
|
|||
dilation=dilations[2],
|
||||
padding="same"))
|
||||
])
|
||||
# convolutions without dilation
|
||||
self.convs2 = nn.LayerList([
|
||||
weight_norm(
|
||||
nn.Conv1D(channels, channels, kernel_size, padding="same")),
|
||||
|
@ -54,6 +58,7 @@ class ResidualBlock1(nn.Layer):
|
|||
|
||||
|
||||
class ResidualBlock2(nn.Layer):
|
||||
"""A simpler alternative to ResidualBlock1."""
|
||||
def __init__(self, channels, kernel_size=3, dilations=(1, 3)):
|
||||
super().__init__()
|
||||
self.convs = nn.LayerList([
|
||||
|
@ -75,11 +80,11 @@ class ResidualBlock2(nn.Layer):
|
|||
|
||||
class WavGenerator(nn.Layer):
|
||||
def __init__(self,
|
||||
input_size,
|
||||
upsample_init_channels,
|
||||
supsample_rates,
|
||||
resblock_kernel_sizes,
|
||||
resblock_dilation_sizes,
|
||||
input_size: int,
|
||||
upsample_init_channels: int,
|
||||
upsample_rates: List[int],
|
||||
resblock_kernel_sizes: List[int],
|
||||
resblock_dilation_sizes: List[List[int]],
|
||||
resblock_type="1"):
|
||||
super().__init__()
|
||||
self.num_upsample_layers = len(upsample_rates)
|
||||
|
@ -89,6 +94,8 @@ class WavGenerator(nn.Layer):
|
|||
self.conv_pre = weight_norm(nn.Conv1D(input_size, upsample_init_channels, 7, padding="same"))
|
||||
|
||||
resblock = ResidualBlock1 if resblock_type == "1" else ResidualBlock2
|
||||
|
||||
# Upsampling convtranspose
|
||||
self.ups = nn.LayerList()
|
||||
for i, u in enumerate(upsample_rates):
|
||||
self.ups.append(
|
||||
|
@ -97,7 +104,7 @@ class WavGenerator(nn.Layer):
|
|||
2*u,
|
||||
u,
|
||||
padding=u // 2)))
|
||||
|
||||
# Multi-Receptive Field Fusion (MRF)
|
||||
self.resblocks = nn.LayerList()
|
||||
for i in range(num_upsample_layers):
|
||||
ch = upsample_init_channels // (2**(i+1))
|
||||
|
@ -135,24 +142,26 @@ class WavGenerator(nn.Layer):
|
|||
|
||||
|
||||
class DiscriminatorP(nn.Layer):
|
||||
def __init__(self, period, kernel_size=5, stride=3):
|
||||
def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
|
||||
super().__init__()
|
||||
self.period = period
|
||||
norm_fn = spectral_norm if use_spectral_norm else weight_norm
|
||||
self.convs = nn.LayerList([
|
||||
weight_norm(nn.Conv2D(1, 32, [kernel_size, 1], [stride, 1], padding=[(kernel_size -1) // 2, 0])),
|
||||
weight_norm(nn.Conv2D(32, 128, [kernel_size, 1], [stride, 1], padding=[(kernel_size -1) // 2, 0])),
|
||||
weight_norm(nn.Conv2D(128, 512, [kernel_size, 1], [stride, 1], padding=[(kernel_size -1) // 2, 0])),
|
||||
weight_norm(nn.Conv2D(512, 1024, [kernel_size, 1], [stride, 1], padding=[(kernel_size -1) // 2, 0])),
|
||||
weight_norm(nn.Conv2D(1024, 1024, [kernel_size, 1], [stride, 1], padding=[(kernel_size -1) // 2, 0])),
|
||||
norm_fn(nn.Conv2D(1, 32, [kernel_size, 1], [stride, 1], padding=[(kernel_size -1) // 2, 0])),
|
||||
norm_fn(nn.Conv2D(32, 128, [kernel_size, 1], [stride, 1], padding=[(kernel_size -1) // 2, 0])),
|
||||
norm_fn(nn.Conv2D(128, 512, [kernel_size, 1], [stride, 1], padding=[(kernel_size -1) // 2, 0])),
|
||||
norm_fn(nn.Conv2D(512, 1024, [kernel_size, 1], [stride, 1], padding=[(kernel_size -1) // 2, 0])),
|
||||
norm_fn(nn.Conv2D(1024, 1024, [kernel_size, 1], [stride, 1], padding=[(kernel_size -1) // 2, 0])),
|
||||
])
|
||||
self.conv_post = weight_norm(nn.Conv2D(1024, 1, kernel_size=[3, 1], padding=[1, 0]))
|
||||
|
||||
def forward(seld, x):
|
||||
def forward(self, x):
|
||||
# (B, 1, T) -> (B, T), [multiple (B, C, T/p, p)] time scale shrinks
|
||||
fmap = []
|
||||
b, c, t = x.shape
|
||||
if t % self.period != 0:
|
||||
n_pad = self.period - (t % self.period)
|
||||
x = F.pad(x, [0, n_pad], mode='reflect')
|
||||
x = F.pad(x, [0, n_pad], mode='reflect', data_format="NCL")
|
||||
t += n_pad
|
||||
x = paddle.reshape(x, [b, c, t // self.period, self.period])
|
||||
|
||||
|
@ -192,16 +201,17 @@ class MultiPeriodDiscriminator(nn.Layer):
|
|||
|
||||
|
||||
class DiscriminatorS(nn.Layer):
|
||||
def __init__(self):
|
||||
def __init__(self, use_spectral_norm=False):
|
||||
super().__init__()
|
||||
norm_fn = spectral_norm if use_spectral_norm else weight_norm
|
||||
self.convs = nn.LayerList([
|
||||
weight_norm(nn.Conv1D(1, 128, 15, 1, padding='same')),
|
||||
weight_norm(nn.Conv1D(128, 128, 41, 2, groups=4, padding='same')),
|
||||
weight_norm(nn.Conv1D(128, 256, 41, 2, groups=16, padding='same')),
|
||||
weight_norm(nn.Conv1D(256, 512, 41, 4, groups=16, padding='same')),
|
||||
weight_norm(nn.Conv1D(512, 1024, 41, 4, groups=16, padding='same')),
|
||||
weight_norm(nn.Conv1D(1024, 1024, 41, 1, groups=16, padding='same')),
|
||||
weight_norm(nn.Conv1D(1024, 1024, 5, 1, padding='same'))])
|
||||
norm_fn(nn.Conv1D(1, 128, 15, 1, padding='same')),
|
||||
norm_fn(nn.Conv1D(128, 128, 41, 2, groups=4, padding='same')),
|
||||
norm_fn(nn.Conv1D(128, 256, 41, 2, groups=16, padding='same')),
|
||||
norm_fn(nn.Conv1D(256, 512, 41, 4, groups=16, padding='same')),
|
||||
norm_fn(nn.Conv1D(512, 1024, 41, 4, groups=16, padding='same')),
|
||||
norm_fn(nn.Conv1D(1024, 1024, 41, 1, groups=16, padding='same')),
|
||||
norm_fn(nn.Conv1D(1024, 1024, 5, 1, padding='same'))])
|
||||
self.conv_post = weight_norm(nn.Conv1D(1024, 1, 3, 1, padding='same'))
|
||||
|
||||
def forward(self, x):
|
||||
|
@ -217,12 +227,74 @@ class DiscriminatorS(nn.Layer):
|
|||
return x, fmap
|
||||
|
||||
|
||||
def MultiScaleDiscriminator(nn.Layer):
|
||||
class MultiScaleDiscriminator(nn.Layer):
    """Run three ``DiscriminatorS`` instances on successively pooled audio.

    The first sub-discriminator is built with ``use_spectral_norm=True``;
    the other two use the ``DiscriminatorS`` default. Before the second and
    third sub-discriminators, both inputs are passed through an
    ``nn.AvgPool1D(4, 2, padding=2)`` layer.
    """

    def __init__(self):
        super().__init__()
        # One spectrally-normalised discriminator followed by two default ones.
        scale_discriminators = [DiscriminatorS(use_spectral_norm=True)]
        scale_discriminators.extend(DiscriminatorS() for _ in range(2))
        self.discriminators = nn.LayerList(scale_discriminators)
        # One pooling layer in front of each discriminator except the first.
        self.meanpools = nn.LayerList(
            [nn.AvgPool1D(4, 2, padding=2) for _ in range(2)])

    def forward(self, y, y_hat):
        """Score a real and a generated signal with every sub-discriminator.

        Args:
            y: real input; presumably a (B, 1, T) waveform batch -- confirm
                against the caller.
            y_hat: generated input, handled exactly like ``y``.

        Returns:
            A 4-tuple of lists with one entry per sub-discriminator:
            outputs for ``y``, outputs for ``y_hat``, feature maps for ``y``
            and feature maps for ``y_hat``.
        """
        real_scores, fake_scores = [], []
        real_fmaps, fake_fmaps = [], []
        for idx, disc in enumerate(self.discriminators):
            if idx > 0:
                # Pooling is cumulative: each stage pools the previous
                # stage's (already pooled) signals once more.
                pool = self.meanpools[idx - 1]
                y = pool(y)
                y_hat = pool(y_hat)
            score_real, fmap_real = disc(y)
            score_fake, fmap_fake = disc(y_hat)
            real_scores.append(score_real)
            real_fmaps.append(fmap_real)
            fake_scores.append(score_fake)
            fake_fmaps.append(fmap_fake)

        return real_scores, fake_scores, real_fmaps, fake_fmaps
|
||||
|
||||
|
||||
|
||||
def feature_loss(fmap_r, fmap_g):
    """Sum the mean absolute differences of paired feature maps, times two.

    Args:
        fmap_r: iterable with one entry per discriminator; each entry is an
            iterable of feature maps computed from the real signal.
        fmap_g: same structure as ``fmap_r``, computed from the generated
            signal.

    Returns:
        Twice the accumulated sum of ``paddle.mean(paddle.abs(real - gen))``
        over every paired feature map (``0`` when there are no pairs).
    """
    total = 0
    for real_layers, gen_layers in zip(fmap_r, fmap_g):
        for real_feat, gen_feat in zip(real_layers, gen_layers):
            l1_distance = paddle.mean(paddle.abs(real_feat - gen_feat))
            total += l1_distance

    return total * 2
|
||||
|
||||
|
||||
def discriminator_loss(disc_real_outputs, disc_generated_outputs):
    """Least-squares discriminator loss accumulated over discriminators.

    For every (real, generated) output pair the real term is
    ``mean((1 - real) ** 2)`` and the generated term is
    ``mean(generated ** 2)``.

    Args:
        disc_real_outputs: iterable of discriminator outputs for real input.
        disc_generated_outputs: iterable of discriminator outputs for
            generated input, paired positionally with the real outputs.

    Returns:
        A 3-tuple ``(loss, r_losses, g_losses)``: the sum of all real and
        generated terms, the list of per-discriminator real terms, and the
        list of per-discriminator generated terms.
    """
    total = 0
    real_terms, fake_terms = [], []
    for real_out, fake_out in zip(disc_real_outputs, disc_generated_outputs):
        real_term = paddle.mean((1 - real_out) ** 2)
        fake_term = paddle.mean(fake_out ** 2)
        real_terms.append(real_term)
        fake_terms.append(fake_term)
        total += (real_term + fake_term)

    return total, real_terms, fake_terms
|
||||
|
||||
|
||||
def generator_loss(disc_outputs):
    """Least-squares generator loss accumulated over discriminators.

    Args:
        disc_outputs: iterable of discriminator outputs for generated input.

    Returns:
        A 2-tuple ``(loss, gen_losses)``: the sum of the per-discriminator
        terms ``mean((1 - output) ** 2)`` and the list of those terms.
    """
    total = 0
    per_disc_losses = []
    for fake_out in disc_outputs:
        term = paddle.mean((1 - fake_out) ** 2)
        per_disc_losses.append(term)
        total += term

    return total, per_disc_losses
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue