diff --git a/tests/test_pwg.py b/tests/test_pwg.py index 7f2205d..41027b5 100644 --- a/tests/test_pwg.py +++ b/tests/test_pwg.py @@ -66,9 +66,14 @@ def test_convin_upsample_net(): out2.sum().backward() print(f"torch conv_in_upsample_net backward takes {t.elapse}s.") + print("forward check") print(out.numpy()[0]) print(out2.data.cpu().numpy()[0]) + print("backward check") + print(net.conv_in.weight.numpy()[0]) + print(net2.conv_in.weight.data.cpu().numpy()[0]) + def test_residual_block(): net = ResidualBlock(dilation=9) @@ -137,14 +142,19 @@ def test_pwg_generator(): torch.cuda.synchronize() print(f"torch generator backward takes {t.elapse}s.") + print("test forward:") print(out.numpy()[0]) print(out2.data.cpu().numpy()[0]) + + print("test backward:") + print(net.first_conv.weight.numpy()[0]) + print(net2.first_conv.weight.data.cpu().numpy()[0]) # print(out.shape) def test_pwg_discriminator(): net = PWGDiscriminator() - net2 = pwgan.ParallelWaveGANDiscriminator() + net2 = pwgan.ParallelWaveGANDiscriminator().to(device) summary(net) summary(net2) for k, v in net2.named_parameters(): @@ -154,11 +164,39 @@ def test_pwg_discriminator(): else: p.set_value(v.data.cpu().numpy()) x = paddle.randn([4, 1, 180 * 256]) - y = net(x) - y2 = net2(torch.as_tensor(x.numpy())) + + synchronize() + with timer() as t: + y = net(x) + synchronize() + print(f"forward takes {t.elapse}s.") + + synchronize() + with timer() as t: + y.sum().backward() + synchronize() + print(f"backward takes {t.elapse}s.") + + x_torch = torch.as_tensor(x.numpy()).to(device) + torch.cuda.synchronize() + with timer() as t: + y2 = net2(x_torch) + torch.cuda.synchronize() + print(f"forward takes {t.elapse}s.") + + torch.cuda.synchronize() + with timer() as t: + y2.sum().backward() + torch.cuda.synchronize() + print(f"backward takes {t.elapse}s.") + + print("test forward:") print(y.numpy()[0]) print(y2.data.cpu().numpy()[0]) - print(y.shape) + + print("test backward:") + print(net.conv_layers[0].weight.numpy()[0]) + print(net2.conv_layers[0].weight.data.cpu().numpy()[0]) def test_residual_pwg_discriminator():