move Conv1dBatchNorm to conv.py
parent f9087ea9a2
commit 40457227e6
@@ -4,8 +4,10 @@ from paddle import nn
 from paddle.nn import functional as F
 
 from parakeet.modules.attention import _split_heads, _concat_heads, drop_head, scaled_dot_product_attention
-from parakeet.modules.transformer import PositionwiseFFN, combine_mask
+from parakeet.modules.transformer import PositionwiseFFN
+from parakeet.modules import masking
 from parakeet.modules.cbhg import Conv1dBatchNorm
+from parakeet.modules import positional_encoding as pe
 
 # Transformer TTS's own implementation of transformer
 class MultiheadAttention(nn.Layer):
@@ -160,7 +162,7 @@ class TransformerDecoderLayer(nn.Layer):
         """
         tq = q.shape[1]
         no_future_mask = paddle.tril(paddle.ones([tq, tq])) #(tq, tq)
-        combined_mask = combine_mask(decoder_mask, no_future_mask)
+        combined_mask = masking.combine_mask(decoder_mask.unsqueeze(1), no_future_mask)
 
         # pre norm
         q_in = q
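For context, a minimal sketch of what the combined decoder mask looks like; the elementwise-product semantics of masking.combine_mask and the 1/0 mask convention are assumptions here, not shown in this diff.

import paddle

tq = 4                                                  # hypothetical target length
decoder_mask = paddle.ones([2, tq])                     # (B, T), assumed 1 = real frame, 0 = padding
no_future_mask = paddle.tril(paddle.ones([tq, tq]))     # (T, T), allow attending to current and past steps only

# assumed combine_mask semantics: broadcasted elementwise product
combined = decoder_mask.unsqueeze(1) * no_future_mask   # (B, 1, T) * (T, T) -> (B, T, T)
print(combined.shape)                                   # [2, 4, 4]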
@@ -234,25 +236,57 @@ class PostNet(nn.Layer):
         self.last_norm = nn.BatchNorm1d(d_output)
 
     def forward(self, x):
+        x_in = x
         for layer in self.convs:
             x = paddle.tanh(layer(x))
-        x = self.last_norm(x)
+        x = self.last_norm(x + x_in)
         return x
 
 
 class TransformerTTS(nn.Layer):
-    def __init__(self, vocab_size, padding_idx, d_model, d_mel, n_heads, d_ffn,
+    def __init__(self, vocab_size, padding_idx, d_model, d_mel, n_heads, d_ffn, positional_encoding_scalar,
                  encoder_layers, decoder_layers, d_prenet, d_postnet, postnet_layers,
-                 postnet_kernel_size, reduction_factor, dropout):
+                 postnet_kernel_size, max_reduction_factor, dropout):
         self.encoder_prenet = nn.Embedding(vocab_size, d_model, padding_idx)
         self.encoder_pe = pe.positional_encoding(0, 1000, d_model) # it may be extended later
         self.encoder = TransformerEncoder(d_model, n_heads, d_ffn, encoder_layers, dropout)
 
         self.decoder_prenet = DecoderPreNet(d_model, d_prenet, dropout)
         self.decoder_pe = pe.positional_encoding(0, 1000, d_model) # it may be extended later
         self.decoder = TransformerDecoder(d_model, n_heads, d_ffn, decoder_layers, dropout)
-        self.final_proj = nn.Linear(d_model, reduction_factor * d_mel)
+        self.final_proj = nn.Linear(d_model, max_reduction_factor * d_mel)
         self.decoder_postnet = PostNet(d_mel, d_postnet, d_mel, postnet_kernel_size, postnet_layers)
 
         # specs
         self.padding_idx = padding_idx
         self.d_model = d_model
+        self.pe_scalar = positional_encoding_scalar
 
-    def forward(self):
+    def forward(self, text, mel, stop):
         pass
 
+    def encode(self, text):
+        T_enc = text.shape[-1]
+        embed = self.encoder_prenet(text)
+        pe = self.encoder_pe[:T_enc, :] # (T, C)
+        x = embed.scale(math.sqrt(self.d_model)) + pe.scale(self.pe_scalar)
+        encoder_padding_mask = masking.id_mask(text, self.padding_idx)
+
+        x = F.dropout(x, training=self.training)
+        x, attention_weights = self.encoder(x, encoder_padding_mask)
+        return x, attention_weights, encoder_padding_mask
+
+    def decode(self, ):
+        pass
+
+    def infer(self):
+        pass
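A rough usage sketch of the updated TransformerTTS interface; every hyperparameter value below is made up for illustration, the module path is assumed, and forward() and decode() are still stubs in this commit, so only encode() is exercised.

import paddle
from parakeet.models import transformer_tts as tts      # module path assumed

net = tts.TransformerTTS(
    vocab_size=128, padding_idx=0, d_model=64, d_mel=80, n_heads=4, d_ffn=128,
    positional_encoding_scalar=1.0,
    encoder_layers=2, decoder_layers=2, d_prenet=64, d_postnet=128, postnet_layers=3,
    postnet_kernel_size=5, max_reduction_factor=5, dropout=0.1)

text = paddle.randint(1, 128, [4, 32], dtype="int64")   # (B, T_text) token ids
encoded, attn_weights, padding_mask = net.encode(text)  # encode() returns features, attention weights and the padding mask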
@@ -4,21 +4,7 @@ from paddle import nn
 from paddle.nn import functional as F
 from paddle.nn import initializer as I
 
-
-class Conv1dBatchNorm(nn.Layer):
-    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0,
-                 weight_attr=None, bias_attr=None):
-        super(Conv1dBatchNorm, self).__init__()
-        # TODO(chenfeiyu): carefully initialize Conv1d's weight
-        self.conv = nn.Conv1d(in_channels, out_channels, kernel_size, stride,
-                              padding=padding,
-                              weight_attr=weight_attr,
-                              bias_attr=bias_attr)
-        # TODO: channel last, but BatchNorm1d does not support channel last layout
-        self.bn = nn.BatchNorm1d(out_channels)
-
-    def forward(self, x):
-        return self.bn(self.conv(x))
+from parakeet.modules.conv import Conv1dBatchNorm
 
 
 class Highway(nn.Layer):
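The re-export keeps old call sites working: a quick compatibility sketch, assuming both modules import cleanly.

from parakeet.modules import cbhg, conv

# both names should now resolve to the single class defined in conv.py
assert cbhg.Conv1dBatchNorm is conv.Conv1dBatchNorm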
@@ -79,3 +79,20 @@ class Conv1dCell(nn.Conv1d):
         y_t = paddle.matmul(input, self._reshaped_weight, transpose_y=True)
         y_t = y_t + self.bias
         return y_t
+
+
+class Conv1dBatchNorm(nn.Layer):
+    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0,
+                 weight_attr=None, bias_attr=None):
+        super(Conv1dBatchNorm, self).__init__()
+        # TODO(chenfeiyu): carefully initialize Conv1d's weight
+        self.conv = nn.Conv1d(in_channels, out_channels, kernel_size, stride,
+                              padding=padding,
+                              weight_attr=weight_attr,
+                              bias_attr=bias_attr)
+        # TODO: channel last, but BatchNorm1d does not support channel last layout
+        self.bn = nn.BatchNorm1d(out_channels)
+
+    def forward(self, x):
+        return self.bn(self.conv(x))
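For reference, a small usage sketch of the relocated layer in channel-first layout; the shapes follow the unit test further down, and the (2, 2) padding is just the explicit "same" padding for a kernel of 5.

import paddle
from parakeet.modules.conv import Conv1dBatchNorm

net = Conv1dBatchNorm(4, 6, (5,), 1, padding=(2, 2))    # in_channels=4, out_channels=6, kernel=5, stride=1
x = paddle.randn([4, 4, 16])                            # (batch, channels, time)
y = net(x)                                              # Conv1d followed by BatchNorm1d
print(y.shape)                                          # [4, 6, 16]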
@@ -26,4 +26,4 @@ def shuffle_dim(x, axis, perm=None):
 
     perm = paddle.to_tensor(perm)
     out = paddle.gather(x, perm, axis)
-    return out
+    return out
@@ -2,7 +2,7 @@ import math
 import paddle
 from paddle.nn import functional as F
 
-def positional_encoding(start_index, length, size, dtype="float32"):
+def positional_encoding(start_index, length, size, dtype=None):
     """
     Generate standard positional encoding.
 
@@ -22,6 +22,7 @@ def positional_encoding(start_index, length, size, dtype="float32"):
     """
     if (size % 2 != 0):
        raise ValueError("size should be divisible by 2")
+    dtype = dtype or paddle.get_default_dtype()
    channel = paddle.arange(0, size, 2, dtype=dtype)
    index = paddle.arange(start_index, start_index + length, 1, dtype=dtype)
    p = paddle.unsqueeze(index, -1) / (10000 ** (channel / float(size)))
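With dtype=None the encoding now follows the global default dtype instead of a hard-coded float32. A sketch of the effect, using the same float64 default the test suite below sets; the exact returned dtype assumes the rest of the function preserves it.

import paddle
from parakeet.modules import positional_encoding as pe

paddle.set_default_dtype("float64")
enc = pe.positional_encoding(0, 1000, 512)              # dtype=None falls back to paddle.get_default_dtype()
print(enc.dtype)                                        # expected float64; previously always float32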
@@ -98,7 +98,4 @@ def load_tests(loader, standard_tests, pattern):
     suite.addTest(TestMultiheadAttention("test_io", same_qk=True))
     suite.addTest(TestMultiheadAttention("test_io", same_qk=False))
 
-    suite.addTest(TestDropHeadMultiheadAttention("test_io", same_qk=True))
-    suite.addTest(TestDropHeadMultiheadAttention("test_io", same_qk=False))
-
     return suite
@@ -4,25 +4,6 @@ paddle.set_default_dtype("float64")
 paddle.disable_static(paddle.CPUPlace())
 from parakeet.modules import cbhg
 
-class TestConv1dBatchNorm(unittest.TestCase):
-    def __init__(self, methodName="runTest", causal=False):
-        super(TestConv1dBatchNorm, self).__init__(methodName)
-        self.causal = causal
-
-    def setUp(self):
-        k = 5
-        paddding = (k - 1, 0) if self.causal else ((k-1) // 2, k //2)
-        self.net = cbhg.Conv1dBatchNorm(4, 6, (k,), 1, padding=paddding)
-
-    def test_input_output(self):
-        x = paddle.randn([4, 4, 16])
-        out = self.net(x)
-        out_np = out.numpy()
-        self.assertTupleEqual(out_np.shape, (4, 6, 16))
-
-    def runTest(self):
-        self.test_input_output()
-
 
 class TestHighway(unittest.TestCase):
     def test_io(self):
@@ -47,8 +28,6 @@ class TestCBHG(unittest.TestCase):
 
 def load_tests(loader, standard_tests, pattern):
     suite = unittest.TestSuite()
-    suite.addTest(TestConv1dBatchNorm("runTest", True))
-    suite.addTest(TestConv1dBatchNorm("runTest", False))
 
     suite.addTest(TestHighway("test_io"))
     suite.addTest(TestCBHG("test_io"))
@@ -29,4 +29,33 @@ class TestConv1dCell(unittest.TestCase):
         y2 = self.forward_incremental(x)
 
         np.testing.assert_allclose(y2.numpy(), y1.numpy())
 
+
+class TestConv1dBatchNorm(unittest.TestCase):
+    def __init__(self, methodName="runTest", causal=False):
+        super(TestConv1dBatchNorm, self).__init__(methodName)
+        self.causal = causal
+
+    def setUp(self):
+        k = 5
+        paddding = (k - 1, 0) if self.causal else ((k-1) // 2, k //2)
+        self.net = conv.Conv1dBatchNorm(4, 6, (k,), 1, padding=paddding)
+
+    def test_input_output(self):
+        x = paddle.randn([4, 4, 16])
+        out = self.net(x)
+        out_np = out.numpy()
+        self.assertTupleEqual(out_np.shape, (4, 6, 16))
+
+    def runTest(self):
+        self.test_input_output()
+
+
+def load_tests(loader, standard_tests, pattern):
+    suite = unittest.TestSuite()
+    suite.addTest(TestConv1dBatchNorm("runTest", True))
+    suite.addTest(TestConv1dBatchNorm("runTest", False))
+
+    suite.addTest(TestConv1dCell("test_equality"))
+
+    return suite
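load_tests is the standard unittest hook, so both the causal and non-causal Conv1dBatchNorm cases are picked up by the regular loader; a sketch of running them programmatically, assuming the file is importable as tests.test_conv:

import unittest

suite = unittest.defaultTestLoader.loadTestsFromName("tests.test_conv")   # honors load_tests()
unittest.TextTestRunner(verbosity=2).run(suite)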
@@ -53,4 +53,12 @@ class TestTransformerDecoderLayer(unittest.TestCase):
 
         self.assertTupleEqual(y.numpy().shape, (4, 32, 64))
         self.assertTupleEqual(self_attn_weights.numpy().shape, (4, 8, 32, 32))
-        self.assertTupleEqual(cross_attn_weights.numpy().shape, (4, 8, 32, 24))
+        self.assertTupleEqual(cross_attn_weights.numpy().shape, (4, 8, 32, 24))
+
+
+class TestTransformerTTS(unittest.TestCase):
+    def test_io(self):
+        return
+        net = tts.TransformerTTS(vocab_size, padding_idx, d_model, d_mel, n_heads, d_ffn, positional_encoding_scalar,
+                                 encoder_layers, decoder_layers, d_prenet, d_postnet, postnet_layers,
+                                 postnet_kernel_size, max_reduction_factor, dropout)