2020-10-10 15:51:54 +08:00
|
|
|
import paddle
|
|
|
|
# Make float64 the default dtype for newly created tensors and parameters.
# NOTE(review): presumably chosen so the full-convolution and incremental
# outputs compared below can pass numpy's default (tight) assert_allclose
# tolerance — confirm.
paddle.set_default_dtype("float64")

# Switch Paddle to dynamic-graph (imperative) mode on CPU. The tests below call
# layers eagerly and read results with .numpy().
paddle.disable_static(paddle.CPUPlace())
|
|
|
|
import unittest
|
|
|
|
import numpy as np
|
|
|
|
|
|
|
|
from parakeet.modules import conv
|
|
|
|
|
|
|
|
class TestConv1dCell(unittest.TestCase):
    """Check Conv1dCell's step-by-step path against its full forward pass."""

    def setUp(self):
        # Conv1dCell(in_channels=4, out_channels=6, kernel_size=5, dilation=2)
        # NOTE(review): positional meaning inferred from usage — confirm
        # against parakeet.modules.conv.Conv1dCell.
        self.net = conv.Conv1dCell(4, 6, 5, dilation=2)

    def forward_incremental(self, x):
        """Feed ``x`` to the cell one time step at a time and stack the outputs.

        Args:
            x: input tensor indexed as ``x[:, :, t]``; its last axis is
                iterated as time (presumably ``(batch, channels, time)`` —
                confirm).

        Returns:
            The per-step outputs of ``add_input`` stacked along a new last
            axis.
        """
        outs = []
        # NOTE(review): start_sequence presumably resets the cell's internal
        # state before a new sequence — confirm in parakeet.modules.conv.
        self.net.start_sequence()
        with paddle.no_grad():
            for i in range(x.shape[-1]):
                xt = x[:, :, i]
                yt = self.net.add_input(xt)
                outs.append(yt)
        y2 = paddle.stack(outs, axis=-1)
        return y2

    def test_equality(self):
        """Full-sequence and incremental outputs must agree elementwise."""
        x = paddle.randn([2, 4, 16])

        # Switch to eval mode *before* either pass. Both outputs are then
        # produced under the same layer mode; the original computed the
        # reference in training mode. No gradients are needed for the
        # reference output.
        self.net.eval()
        with paddle.no_grad():
            y1 = self.net(x)

        # The reference forward still runs before start_sequence, as before.
        y2 = self.forward_incremental(x)

        np.testing.assert_allclose(y2.numpy(), y1.numpy())
|
2020-10-14 10:05:26 +08:00
|
|
|
|
|
|
|
|
|
|
|
class TestConv1dBatchNorm(unittest.TestCase):
    """Output-shape check for Conv1dBatchNorm.

    The constructor takes ``causal`` and ``channel_last`` so that a
    hand-built suite can create one instance per padding/layout
    configuration.
    """

    def __init__(self, methodName="runTest", causal=False, channel_last=False):
        super(TestConv1dBatchNorm, self).__init__(methodName)
        # True: all padding goes in the first slot of the padding tuple.
        # False: padding is split across both slots.
        self.causal = causal
        # True: tensors use the "NLC" layout; False: "NCL".
        self.channel_last = channel_last

    def setUp(self):
        kernel_size = 5
        if self.causal:
            # NOTE(review): presumably (left, right) padding, i.e. no look-ahead.
            padding = (kernel_size - 1, 0)
        else:
            padding = ((kernel_size - 1) // 2, kernel_size // 2)
        layout = "NLC" if self.channel_last else "NCL"
        self.net = conv.Conv1dBatchNorm(
            4, 6, (kernel_size,), 1, padding=padding, data_format=layout)

    def test_input_output(self):
        """Total padding of k - 1 at stride 1 keeps the sequence length."""
        batch, in_channels, out_channels, length = 4, 4, 6, 16
        if self.channel_last:
            x = paddle.randn([batch, length, in_channels])
            expected = (batch, length, out_channels)
        else:
            x = paddle.randn([batch, in_channels, length])
            expected = (batch, out_channels, length)
        result = self.net(x).numpy()
        self.assertTupleEqual(result.shape, expected)

    def runTest(self):
        self.test_input_output()
|
|
|
|
|
|
|
|
|
|
|
|
def load_tests(loader, standard_tests, pattern):
    """Return this module's test suite, built by hand.

    Implements unittest's ``load_tests`` protocol. The ``loader``,
    ``standard_tests`` and ``pattern`` arguments are ignored. The suite holds
    one TestConv1dBatchNorm per (causal, channel_last) pair below, followed by
    TestConv1dCell's equality test.
    """
    # (causal, channel_last), in the order the cases are added to the suite.
    configurations = (
        (True, True),
        (False, False),
        (True, False),
        (False, True),
    )
    cases = [
        TestConv1dBatchNorm("runTest", causal, channel_last)
        for causal, channel_last in configurations
    ]
    cases.append(TestConv1dCell("test_equality"))
    return unittest.TestSuite(cases)
|