Parakeet/parakeet/modules/cbhg.py

91 lines
3.0 KiB
Python

import math
import paddle
from paddle import nn
from paddle.nn import functional as F
from paddle.nn import initializer as I
from parakeet.modules.conv import Conv1dBatchNorm
class Highway(nn.Layer):
    """Single highway layer that mixes a transformed input with the input itself.

    The output is ``H(x) * T(x) + x * (1 - T(x))`` where ``H`` is a linear
    layer followed by ReLU and ``T`` is a linear layer followed by a sigmoid
    (the gate). Both linear layers map ``num_features`` to ``num_features``,
    so the output has the same feature size as the input.

    Args:
        num_features: Size of the last dimension of the input and the output.
    """

    def __init__(self, num_features):
        super().__init__()
        self.num_features = num_features
        # Candidate transform branch.
        self.H = nn.Linear(num_features, num_features)
        # Gate branch; its bias is initialized to the constant -1.
        gate_bias_init = I.Constant(-1.0)
        self.T = nn.Linear(
            num_features, num_features, bias_attr=gate_bias_init)

    def forward(self, x):
        """Apply the gated mix of the transformed input and the raw input.

        Args:
            x: Tensor whose last dimension is ``num_features``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        candidate = F.relu(self.H(x))
        gate = F.sigmoid(self.T(x))
        transformed_part = candidate * gate
        carried_part = x * (1.0 - gate)
        return transformed_part + carried_part
class CBHG(nn.Layer):
    """CBHG block: 1-D convolution bank, highway network, bidirectional GRU.

    Pipeline implemented by ``forward``:

    1. a bank of ``max_kernel_size`` convolutions (kernel sizes
       ``1 .. max_kernel_size``) whose outputs are concatenated on axis 1
       and passed through ReLU;
    2. max pooling with window 2 and stride 1;
    3. a stack of kernel-size-3 ``Conv1D`` projections, with ReLU between
       them but not after the last one, whose final output has
       ``in_channels`` channels so it can be added to the input (residual);
    4. an optional linear layer to ``highway_features`` (only created when
       ``in_channels != highway_features``) followed by ``num_highways``
       ``Highway`` layers;
    5. a bidirectional GRU.

    Args:
        in_channels: Number of channels of the input.
        out_channels_per_conv: Output channels of every convolution in the
            bank.
        max_kernel_size: Largest kernel size in the bank; the bank holds one
            convolution per size from 1 up to and including this value.
        projection_channels: Iterable with the output channels of every
            projection except the last one (the last one always outputs
            ``in_channels``).
        num_highways: Number of stacked ``Highway`` layers.
        highway_features: Feature size used by the highway layers.
        gru_features: Hidden size of the GRU (per direction).
    """

    def __init__(self, in_channels, out_channels_per_conv, max_kernel_size,
                 projection_channels,
                 num_highways, highway_features,
                 gru_features):
        super(CBHG, self).__init__()

        # The padding pair sums to k - 1, which keeps the sequence length
        # for odd and even kernel sizes alike (assuming Conv1dBatchNorm uses
        # stride 1 by default -- TODO confirm). Equal lengths are required
        # for the concatenation on axis 1 in forward().
        self.conv1d_banks = nn.LayerList(
            [Conv1dBatchNorm(in_channels, out_channels_per_conv, (k,),
                             padding=((k - 1) // 2, k // 2))
             for k in range(1, 1 + max_kernel_size)])

        # Projection i maps proj_in_channels[i] -> proj_out_channels[i]. The
        # first one consumes the concatenated bank output, the last one
        # returns to in_channels so the residual add is well defined.
        self.projections = nn.LayerList()
        projection_channels = list(projection_channels)
        proj_in_channels = [max_kernel_size *
                            out_channels_per_conv] + projection_channels
        proj_out_channels = projection_channels + [in_channels]
        for c_in, c_out in zip(proj_in_channels, proj_out_channels):
            conv = nn.Conv1D(c_in, c_out, (3,), padding=(1, 1))
            self.projections.append(conv)

        # Only created when a change of feature size is needed; forward()
        # checks for the attribute with hasattr().
        if in_channels != highway_features:
            self.pre_highway = nn.Linear(in_channels, highway_features)

        self.highways = nn.LayerList(
            [Highway(highway_features) for _ in range(num_highways)])

        self.gru = nn.GRU(highway_features, gru_features,
                          direction="bidirectional")

        self.in_channels = in_channels
        self.out_channels_per_conv = out_channels_per_conv
        self.max_kernel_size = max_kernel_size
        self.num_projections = 1 + len(projection_channels)
        self.num_highways = num_highways
        self.highway_features = highway_features
        self.gru_features = gru_features

    def forward(self, x):
        """Run the CBHG pipeline.

        Args:
            x: Input tensor. The code treats axis 1 as channels (it must be
                ``in_channels`` wide) and axis 2 as the sequence axis, i.e.
                presumably ``(batch, in_channels, time)`` -- confirm against
                callers.

        Returns:
            The GRU output sequence. Axes 1 and 2 are swapped before the
            linear/highway/GRU stages, so the sequence axis comes before the
            feature axis in the result; with a bidirectional GRU the feature
            size should be ``2 * gru_features`` (per the Paddle GRU API --
            confirm).
        """
        # Named `residual` rather than `input` to avoid shadowing the builtin.
        residual = x

        # conv banks
        bank_outputs = [conv(x) for conv in self.conv1d_banks]
        x = F.relu(paddle.concat(bank_outputs, 1))

        # max pool; one extra frame of padding at the end keeps the length
        # unchanged for window 2 / stride 1
        x = F.max_pool1d(x, 2, stride=1, padding=(0, 1))

        # conv1d projections: ReLU between projections, but the last one is
        # left linear. (The previous check compared the index against
        # len(projections), which never matches, so ReLU was applied after
        # the last projection as well.)
        last_index = len(self.projections) - 1
        for i, conv in enumerate(self.projections):
            x = conv(x)
            if i != last_index:
                x = F.relu(x)

        # residual connection (out-of-place so no layer output is mutated)
        x = x + residual

        # highway: swap the channel and sequence axes so the linear layers
        # act on the feature (last) axis
        x = paddle.transpose(x, [0, 2, 1])
        if hasattr(self, "pre_highway"):
            x = self.pre_highway(x)
        # The highway stack was constructed in __init__ but never applied
        # before; run it here, between the optional projection and the GRU.
        for highway in self.highways:
            x = highway(x)

        # gru
        x, _ = self.gru(x)
        return x