91 lines
3.0 KiB
Python
91 lines
3.0 KiB
Python
import math
|
|
import paddle
|
|
from paddle import nn
|
|
from paddle.nn import functional as F
|
|
from paddle.nn import initializer as I
|
|
|
|
from parakeet.modules.conv import Conv1dBatchNorm
|
|
|
|
|
|
class Highway(nn.Layer):
    """Highway layer operating on the last axis of its input.

    With ``t = sigmoid(T(x))`` as the gate, the output is
    ``relu(H(x)) * t + x * (1 - t)``. This is an elementwise blend of a
    transformed path and the untouched input (carry path).

    Args:
        num_features (int): size of the last axis of ``x``. It is both the
            input and the output width of the ``H`` and ``T`` projections.
    """

    def __init__(self, num_features):
        super(Highway, self).__init__()
        # Transform path.
        self.H = nn.Linear(num_features, num_features)
        # Gate path. Its bias is initialised to -1, which shifts the initial
        # gate towards the carry path (x) rather than the transform path.
        self.T = nn.Linear(
            num_features, num_features, bias_attr=I.Constant(-1.))
        self.num_features = num_features

    def forward(self, x):
        """Blend ``relu(H(x))`` and ``x`` using the learned sigmoid gate.

        Args:
            x (Tensor): input whose last axis has ``num_features`` entries.

        Returns:
            Tensor: same shape as ``x``.
        """
        transformed = F.relu(self.H(x))
        gate = F.sigmoid(self.T(x))
        carry = 1.0 - gate
        return transformed * gate + x * carry
|
|
|
|
|
|
class CBHG(nn.Layer):
    """CBHG block: Conv1D bank, max pooling, projections, highways, BiGRU.

    Args:
        in_channels (int): channel count of the input. The last projection
            maps back to this size so the residual connection lines up.
        out_channels_per_conv (int): output channels of each bank convolution.
        max_kernel_size (int): the bank holds one convolution per kernel size
            in ``1 .. max_kernel_size``.
        projection_channels (iterable of int): output channels of the
            intermediate projection convolutions. A final projection back to
            ``in_channels`` is always appended after them.
        num_highways (int): number of stacked :class:`Highway` layers.
        highway_features (int): feature width of the highway stack. If it
            differs from ``in_channels``, a Linear layer maps to it first.
        gru_features (int): hidden size of each direction of the
            bidirectional GRU.
    """

    def __init__(self, in_channels, out_channels_per_conv, max_kernel_size,
                 projection_channels,
                 num_highways, highway_features,
                 gru_features):
        super(CBHG, self).__init__()
        # Each bank conv pads k - 1 frames in total (split left/right), so
        # (assuming Conv1dBatchNorm's default stride of 1) every bank output
        # keeps the input length and the outputs can be concatenated.
        self.conv1d_banks = nn.LayerList(
            [Conv1dBatchNorm(in_channels, out_channels_per_conv, (k,),
                             padding=((k - 1) // 2, k // 2))
             for k in range(1, 1 + max_kernel_size)])

        self.projections = nn.LayerList()
        projection_channels = list(projection_channels)
        proj_in_channels = [max_kernel_size *
                            out_channels_per_conv] + projection_channels
        # The final projection outputs in_channels so the residual add works.
        proj_out_channels = projection_channels + [in_channels]
        for c_in, c_out in zip(proj_in_channels, proj_out_channels):
            conv = nn.Conv1D(c_in, c_out, (3,), padding=(1, 1))
            self.projections.append(conv)

        # Only created when needed; forward() checks for it with hasattr.
        if in_channels != highway_features:
            self.pre_highway = nn.Linear(in_channels, highway_features)

        self.highways = nn.LayerList(
            [Highway(highway_features) for _ in range(num_highways)])

        self.gru = nn.GRU(highway_features, gru_features,
                          direction="bidirectional")

        self.in_channels = in_channels
        self.out_channels_per_conv = out_channels_per_conv
        self.max_kernel_size = max_kernel_size
        self.num_projections = 1 + len(projection_channels)
        self.num_highways = num_highways
        self.highway_features = highway_features
        self.gru_features = gru_features

    def forward(self, x):
        """Run the CBHG stack.

        Args:
            x (Tensor): channel-first input. Bank outputs are concatenated on
                axis 1 and the result is later transposed with perm
                ``[0, 2, 1]``, so presumably the shape is
                (batch, in_channels, time). TODO confirm against callers.

        Returns:
            Tensor: output sequence of the bidirectional GRU. Its last axis
            has ``2 * gru_features`` entries (one half per direction).
        """
        residual = x

        # Conv1D bank: every conv sees the same input. Stack the outputs
        # along the channel axis.
        x = F.relu(paddle.concat([conv(x) for conv in self.conv1d_banks], 1))

        # Kernel 2, stride 1, padding (0, 1): the time length is preserved.
        x = F.max_pool1d(x, 2, stride=1, padding=(0, 1))

        # Conv1D projections. The last one stays linear, because its output
        # is added to the residual.
        last_index = len(self.projections) - 1
        for i, conv in enumerate(self.projections):
            x = conv(x)
            if i != last_index:
                x = F.relu(x)
        x = x + residual  # residual connection

        # Highway stack, applied on the last (feature) axis after the
        # transpose.
        x = paddle.transpose(x, [0, 2, 1])
        if hasattr(self, "pre_highway"):
            x = self.pre_highway(x)
        for highway in self.highways:
            x = highway(x)

        # Bidirectional GRU. The final hidden state is discarded.
        x, _ = self.gru(x)
        return x
|