add numerical testing for PWGGenerator

This commit is contained in:
chenfeiyu 2021-06-21 11:17:01 +00:00
parent 83c9f0aeae
commit b5f99a925f
1 changed files with 42 additions and 4 deletions

View File

@ -66,9 +66,14 @@ def test_convin_upsample_net():
out2.sum().backward()
print(f"torch conv_in_upsample_net backward takes {t.elapse}s.")
print("forward check")
print(out.numpy()[0])
print(out2.data.cpu().numpy()[0])
print("backward check")
print(net.conv_in.weight.numpy()[0])
print(net2.conv_in.weight.data.cpu().numpy()[0])
def test_residual_block():
net = ResidualBlock(dilation=9)
@ -137,14 +142,19 @@ def test_pwg_generator():
torch.cuda.synchronize()
print(f"torch generator backward takes {t.elapse}s.")
print("test forward:")
print(out.numpy()[0])
print(out2.data.cpu().numpy()[0])
print("test backward:")
print(net.first_conv.weight.numpy()[0])
print(net2.first_conv.weight.data.cpu().numpy()[0])
# print(out.shape)
def test_pwg_discriminator():
net = PWGDiscriminator()
net2 = pwgan.ParallelWaveGANDiscriminator()
net2 = pwgan.ParallelWaveGANDiscriminator().to(device)
summary(net)
summary(net2)
for k, v in net2.named_parameters():
@ -154,11 +164,39 @@ def test_pwg_discriminator():
else:
p.set_value(v.data.cpu().numpy())
x = paddle.randn([4, 1, 180 * 256])
y = net(x)
y2 = net2(torch.as_tensor(x.numpy()))
synchronize()
with timer() as t:
y = net(x)
synchronize()
print(f"forward takes {t.elapse}s.")
synchronize()
with timer() as t:
y.sum().backward()
synchronize()
print(f"backward takes {t.elapse}s.")
x_torch = torch.as_tensor(x.numpy()).to(device)
torch.cuda.synchronize()
with timer() as t:
y2 = net2(x_torch)
torch.cuda.synchronize()
print(f"forward takes {t.elapse}s.")
torch.cuda.synchronize()
with timer() as t:
y2.sum().backward()
torch.cuda.synchronize()
print(f"backward takes {t.elapse}s.")
print("test forward:")
print(y.numpy()[0])
print(y2.data.cpu().numpy()[0])
print(y.shape)
print("test backward:")
print(net.conv_layers[0].weight.numpy()[0])
print(net2.conv_layers[0].weight.data.cpu().numpy()[0])
def test_residual_pwg_discriminator():