ParakeetEricRoss/tests/test_clarinet.py

43 lines
1.6 KiB
Python

import unittest
import numpy as np
import paddle
# Global Paddle configuration for this test module.
# NOTE(review): these calls run before the parakeet imports below; that
# ordering looks deliberate, so keep it unless confirmed otherwise.
# Tensors created without an explicit dtype default to float64.
paddle.set_default_dtype("float64")
# Leave static-graph mode and run in dynamic (imperative) mode on the CPU.
paddle.disable_static(paddle.CPUPlace())
from parakeet.models import clarinet
from parakeet.modules import stft
class TestParallelWaveNet(unittest.TestCase):
    """Input/output shape checks for ``clarinet.ParallelWaveNet``."""

    def test_io(self):
        # Feed a random batch through the network and verify that each of
        # the three returned tensors has the same (batch, samples) shape as
        # the input waveform.
        batch_size, cond_channels, num_samples = 4, 12, 6073

        model = clarinet.ParallelWaveNet([8, 8, 8], [1, 1, 1], 16, 12, 2)
        waveform = paddle.randn([batch_size, num_samples])
        condition = paddle.randn([batch_size, cond_channels, num_samples])

        # The forward pass must return exactly three tensors.
        z, out_mu, out_log_std = model(waveform, condition)

        expected_shape = (batch_size, num_samples)
        for output in (z, out_mu, out_log_std):
            self.assertTupleEqual(output.numpy().shape, expected_shape)
class TestClariNet(unittest.TestCase):
    """Smoke and shape tests for the assembled ``clarinet.Clarinet`` model."""

    def setUp(self):
        # Assemble a small Clarinet out of its four parts: an upsampling
        # encoder, a WaveNet teacher, a ParallelWaveNet student and an STFT
        # module. A fresh model is built for every test method.
        num_student_loops = 6

        upsampler = clarinet.UpsampleNet([2, 2])
        teacher = clarinet.WaveNet(8, 3, 16, 3, 12, 2, "mog", -9.0)
        student = clarinet.ParallelWaveNet(
            [8] * num_student_loops,
            [1] * num_student_loops,
            16,
            12,
            2,
        )
        stft_module = stft.STFT(16, 4, 8)
        model = clarinet.Clarinet(
            upsampler, teacher, student, stft_module, -6.0, lmd=4
        )

        print("context size is: ", teacher.context_size)
        self.net = model

    def test_io(self):
        # Smoke test: the training forward pass must run and yield a "loss"
        # entry that can be materialised as a numpy array. No value is
        # asserted; the test only fails if something raises.
        batch_size = 4

        audio = paddle.randn([batch_size, 1366])
        # 512 mel frames; presumably upsampled 4x (512 * 4 = 2048) by the
        # [2, 2] UpsampleNet -- TODO confirm against UpsampleNet.
        mel = paddle.randn([batch_size, 12, 512])
        audio_start = paddle.zeros([batch_size], dtype="int64")

        losses = self.net(audio, mel, audio_start, clip_kl=True)
        losses["loss"].numpy()

    def test_synthesis(self):
        # Synthesis from 512 mel frames is expected to give 2048 audio
        # samples per batch item, i.e. 4 samples per frame.
        batch_size, num_frames = 4, 512

        mel = paddle.randn([batch_size, 12, num_frames])
        synthesized = self.net.synthesis(mel)

        expected_shape = (batch_size, num_frames * 4)
        self.assertTupleEqual(synthesized.numpy().shape, expected_shape)