# ParakeetRebeccaRosario/tests/test_conv.py
#
# Tests for parakeet.modules.conv (Conv1dCell and Conv1dBatchNorm).
# Listing metadata: 67 lines, 2.2 KiB, Python; earliest blame date
# 2020-10-10 15:51:54 +08:00.
import unittest

import numpy as np
import paddle

# Global paddle configuration for every test in this module:
# - default floating dtype is float64 (presumably so that the default,
#   tight tolerance of np.testing.assert_allclose used below is meaningful);
# - imperative (dynamic-graph) mode on the CPU.
# NOTE(review): these are process-wide settings applied at import time and
# may affect other test modules run in the same process.
paddle.set_default_dtype("float64")
paddle.disable_static(paddle.CPUPlace())

# Project module under test; imported after the configuration above,
# preserving the original ordering.
from parakeet.modules import conv
class TestConv1dCell(unittest.TestCase):
    """Check that step-by-step ``Conv1dCell`` inference matches a full pass."""

    def setUp(self):
        # in_channels=4, out_channels=6, kernel_size=5, dilation=2
        self.net = conv.Conv1dCell(4, 6, 5, dilation=2)

    def forward_incremental(self, x):
        """Feed ``x`` to ``self.net`` one time step at a time.

        Args:
            x: input tensor whose last axis is time; the test below passes a
                ``[batch, in_channels, time]`` tensor.

        Returns:
            The outputs of ``self.net.add_input`` for each step, stacked
            along a new last axis.
        """
        outs = []
        # Begin a new sequence before feeding any steps.
        self.net.start_sequence()
        with paddle.no_grad():
            for i in range(x.shape[-1]):
                outs.append(self.net.add_input(x[:, :, i]))
        return paddle.stack(outs, axis=-1)

    def test_equality(self):
        x = paddle.randn([2, 4, 16])
        # Switch to eval mode before *both* passes so the reference output
        # and the incremental output are produced under the same layer mode.
        self.net.eval()
        with paddle.no_grad():
            y1 = self.net(x)
        y2 = self.forward_incremental(x)
        np.testing.assert_allclose(y2.numpy(), y1.numpy())
class TestConv1dBatchNorm(unittest.TestCase):
    """Output-shape test for ``Conv1dBatchNorm`` across padding/layout configs.

    The configuration is passed through the constructor, so the configured
    instances are built explicitly (see ``load_tests``) rather than by
    default test discovery.
    """

    def __init__(self, methodName="runTest", causal=False, channel_last=False):
        """Create a test case for one configuration.

        Args:
            methodName: name of the test method to run (unittest protocol).
            causal: if True, all ``k - 1`` padding goes on the left;
                otherwise it is split between both sides.
            channel_last: if True, use the ``"NLC"`` data format; otherwise
                ``"NCL"``.
        """
        super(TestConv1dBatchNorm, self).__init__(methodName)
        self.causal = causal
        self.channel_last = channel_last

    def setUp(self):
        k = 5
        # Total padding is always k - 1: causal -> all on the left; otherwise
        # split as evenly as possible (the extra element goes right for even k).
        padding = (k - 1, 0) if self.causal else ((k - 1) // 2, k // 2)
        # Positional args: in_channels=4, out_channels=6, kernel_size=(k,),
        # then 1 (presumably the stride — confirm against Conv1dBatchNorm).
        self.net = conv.Conv1dBatchNorm(
            4, 6, (k,), 1, padding=padding,
            data_format="NLC" if self.channel_last else "NCL")

    def test_input_output(self):
        # batch=4, in_channels=4, length=16 laid out per data_format; the
        # length must be preserved and the channels become 6.
        if self.channel_last:
            x = paddle.randn([4, 16, 4])
            expected_shape = (4, 16, 6)
        else:
            x = paddle.randn([4, 4, 16])
            expected_shape = (4, 6, 16)

        out = self.net(x)
        self.assertTupleEqual(
            out.numpy().shape, expected_shape,
            msg="causal={}, channel_last={}".format(self.causal,
                                                    self.channel_last))

    def runTest(self):
        self.test_input_output()
def load_tests(loader, standard_tests, pattern):
    """Build this module's suite (unittest ``load_tests`` protocol).

    ``TestConv1dBatchNorm`` is parametrized through its constructor, so each
    ``(causal, channel_last)`` combination is added explicitly via
    ``runTest``. ``loader`` and ``pattern`` are unused, and ``standard_tests``
    (the loader's default discovery result) is discarded, so any new test
    case must be registered here to run under unittest.
    """
    suite = unittest.TestSuite()
    # (causal, channel_last) configurations, in the original order.
    configs = [(True, True), (False, False), (True, False), (False, True)]
    for causal, channel_last in configs:
        suite.addTest(TestConv1dBatchNorm("runTest", causal, channel_last))
    suite.addTest(TestConv1dCell("test_equality"))
    return suite