33 lines
996 B
Python
33 lines
996 B
Python
import unittest
|
|
import paddle
|
|
from paddle import nn
|
|
# Import-time configuration shared by every test in this module:
# turn off Paddle's static mode, passing a CPU place as the device.
paddle.disable_static(paddle.CPUPlace())
|
|
# Set Paddle's default dtype to "float64" before the tests build layers/tensors.
paddle.set_default_dtype("float64")
|
|
|
|
from parakeet.modules import connections as conn
|
|
|
|
class TestPreLayerNormWrapper(unittest.TestCase):
    """Input/output shape check for ``conn.PreLayerNormWrapper``."""

    def test_io(self):
        """Wrapping an 8->8 linear layer must leave the tensor shape unchanged."""
        # Build the inner layer first, then wrap it (second argument is 8,
        # matching the linear layer's feature size).
        wrapped = conn.PreLayerNormWrapper(nn.Linear(8, 8), 8)

        inputs = paddle.randn([4, 8])
        outputs = wrapped(inputs)

        # Compare via numpy so both shapes are plain tuples.
        input_shape = inputs.numpy().shape
        output_shape = outputs.numpy().shape
        self.assertTupleEqual(input_shape, output_shape)
|
|
|
|
|
|
class TestPostLayerNormWrapper(unittest.TestCase):
    """Input/output shape check for ``conn.PostLayerNormWrapper``."""

    def test_io(self):
        """Wrapping an 8->8 linear layer must leave the tensor shape unchanged."""
        # Build the inner layer first, then wrap it (second argument is 8,
        # matching the linear layer's feature size).
        wrapped = conn.PostLayerNormWrapper(nn.Linear(8, 8), 8)

        inputs = paddle.randn([4, 8])
        outputs = wrapped(inputs)

        # Compare via numpy so both shapes are plain tuples.
        input_shape = inputs.numpy().shape
        output_shape = outputs.numpy().shape
        self.assertTupleEqual(input_shape, output_shape)
|
|
|
|
|
|
class TestResidualWrapper(unittest.TestCase):
    """Input/output shape check for ``conn.ResidualWrapper``."""

    def test_io(self):
        """Wrapping an 8->8 linear layer must leave the tensor shape unchanged."""
        # Build the inner layer first, then hand it to the wrapper.
        wrapped = conn.ResidualWrapper(nn.Linear(8, 8))

        inputs = paddle.randn([4, 8])
        outputs = wrapped(inputs)

        # Compare via numpy so both shapes are plain tuples.
        input_shape = inputs.numpy().shape
        output_shape = outputs.numpy().shape
        self.assertTupleEqual(input_shape, output_shape)