ParakeetEricRoss/parakeet/modules/post_convnet.py

80 lines
3.1 KiB
Python
Raw Normal View History

2020-01-03 16:25:17 +08:00
import paddle.fluid.dygraph as dg
import paddle.fluid as fluid
import paddle.fluid.layers as layers
from parakeet.modules.layers import Conv
2020-01-03 16:25:17 +08:00
class PostConvNet(dg.Layer):
    """Post-processing convolution network applied to a (B, T, C) sequence.

    The network stacks ``num_conv`` convolutions (all built with
    ``data_format="NCT"``):
      * the first maps ``n_mels * outputs_per_step`` channels to ``num_hidden``;
      * the middle ones keep ``num_hidden`` channels;
      * the last maps ``num_hidden`` back to ``n_mels * outputs_per_step``.
    Every convolution except the last is followed by batch norm, tanh and
    dropout. The last one is followed by batch norm and dropout only when
    ``batchnorm_last`` is True.

    Args:
        n_mels (int, optional): mel channels per frame. Defaults to 80.
        num_hidden (int, optional): channels of the hidden convolutions.
            Defaults to 512.
        filter_size (int, optional): convolution kernel size. Defaults to 5.
        padding (int, optional): padding passed to every ``Conv``.
            Defaults to 0.
        num_conv (int, optional): total number of convolution layers; must be
            at least 2. Defaults to 5.
        outputs_per_step (int, optional): multiplier on the channel count; the
            input and output have ``n_mels * outputs_per_step`` channels.
            Defaults to 1.
        use_cudnn (bool, optional): passed to every ``Conv``. Defaults to True.
        dropout (float, optional): dropout probability used after each
            normalized layer. Defaults to 0.1.
        batchnorm_last (bool, optional): apply batch norm and dropout after
            the last convolution. Defaults to False.

    Raises:
        ValueError: if ``num_conv`` is less than 2.
    """

    def __init__(self,
                 n_mels=80,
                 num_hidden=512,
                 filter_size=5,
                 padding=0,
                 num_conv=5,
                 outputs_per_step=1,
                 use_cudnn=True,
                 dropout=0.1,
                 batchnorm_last=False):
        super(PostConvNet, self).__init__()
        # With fewer than two layers the first conv (mel -> hidden) would also
        # be treated as the last one by forward(), giving the wrong output
        # channel count (and a mismatched batch norm when batchnorm_last).
        if num_conv < 2:
            raise ValueError(
                "num_conv must be at least 2, got {}".format(num_conv))
        self.dropout = dropout
        self.num_conv = num_conv
        self.batchnorm_last = batchnorm_last
        mel_channels = n_mels * outputs_per_step

        # Channel plan: mel -> hidden -> ... -> hidden -> mel. Conv i maps
        # channels[i] -> channels[i + 1]. Layers are created in the same
        # order as before (first, middles, last).
        channels = [mel_channels] + [num_hidden] * (num_conv - 1) + [mel_channels]
        self.conv_list = [
            Conv(in_channels=channels[i],
                 out_channels=channels[i + 1],
                 filter_size=filter_size,
                 padding=padding,
                 use_cudnn=use_cudnn,
                 data_format="NCT") for i in range(num_conv)
        ]
        # Register each conv explicitly so its parameters are tracked by the
        # Layer. Names are kept unchanged: they may be part of saved
        # parameter keys.
        for i, layer in enumerate(self.conv_list):
            self.add_sublayer("conv_list_{}".format(i), layer)

        # One batch norm per non-final conv (num_hidden channels), plus an
        # optional one over the last conv's mel_channels output.
        self.batch_norm_list = [
            dg.BatchNorm(num_hidden, data_layout='NCHW')
            for _ in range(num_conv - 1)
        ]
        if self.batchnorm_last:
            self.batch_norm_list.append(
                dg.BatchNorm(mel_channels, data_layout='NCHW'))
        for i, layer in enumerate(self.batch_norm_list):
            self.add_sublayer("batch_norm_list_{}".format(i), layer)

    def forward(self, input):
        """
        Post Conv Net.

        Args:
            input (Variable): Shape(B, T, C), dtype: float32. The input value,
                where C = n_mels * outputs_per_step.
        Returns:
            output (Variable), Shape(B, T, C), the result after postconvnet.
                Each conv output is cut to at most T frames along time; this
                presumably relies on ``padding`` making every conv output at
                least T frames long (TODO: confirm against ``Conv``).
        """
        # The convolutions operate on (B, C, T) ("NCT").
        x = layers.transpose(input, [0, 2, 1])
        seq_len = x.shape[-1]

        # zip pairs each non-final conv with its batch norm; an extra
        # batch norm for the last conv (batchnorm_last) is left out here.
        for conv, batch_norm in zip(self.conv_list[:-1], self.batch_norm_list):
            # Trim frames beyond the input length before normalizing.
            x = layers.dropout(
                layers.tanh(batch_norm(conv(x)[:, :, :seq_len])),
                self.dropout)

        # The last conv has no tanh; it is optionally normalized.
        x = self.conv_list[-1](x)[:, :, :seq_len]
        if self.batchnorm_last:
            x = layers.dropout(self.batch_norm_list[-1](x), self.dropout)

        # Back to (B, T, C).
        output = layers.transpose(x, [0, 2, 1])
        return output