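"""CBHG and its building blocks (Conv1d + BatchNorm, Highway), as used in
Tacotron (Wang et al., 2017)."""
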
import paddle
from paddle import nn
from paddle.nn import functional as F
from paddle.nn import initializer as I


class Conv1dBatchNorm(nn.Layer):
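    """A Conv1d layer followed by BatchNorm1d, both in (batch, channels, time)
    ("NCL") layout."""
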
    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, weight_attr=None, bias_attr=None):
        super(Conv1dBatchNorm, self).__init__()
        # TODO(chenfeiyu): carefully initialize Conv1d's weight
        self.conv = nn.Conv1d(in_channels, out_channels, kernel_size, stride,
                              padding=padding,
                              weight_attr=weight_attr,
                              bias_attr=bias_attr)
        # TODO: channel last, but BatchNorm1d does not support channel last layout
        self.bn = nn.BatchNorm1d(out_channels)

    def forward(self, x):
        return self.bn(self.conv(x))


class Highway(nn.Layer):
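    """A highway layer (Srivastava et al., 2015): computes
    y = H(x) * T(x) + x * (1 - T(x)), where T is a sigmoid gate."""
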
    def __init__(self, num_features):
        super(Highway, self).__init__()
        self.H = nn.Linear(num_features, num_features)
        # the gate's bias starts at -1 so T is initially small and the layer
        # mostly carries its input through unchanged
        self.T = nn.Linear(num_features, num_features,
                           bias_attr=I.Constant(-1.))

        self.num_features = num_features

    def forward(self, x):
        H = F.relu(self.H(x))
        T = F.sigmoid(self.T(x))  # gate
        return H * T + x * (1.0 - T)


class CBHG(nn.Layer):
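    """The CBHG module from Tacotron: a bank of Conv1d layers with kernel
    sizes 1..max_kernel_size, max pooling, Conv1d projections with a
    residual connection, a highway stack, and a bidirectional GRU."""
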
    def __init__(self, in_channels, out_channels_per_conv, max_kernel_size,
                 projection_channels,
                 num_highways, highway_features,
                 gru_features):
        super(CBHG, self).__init__()
        # conv bank: one Conv1dBatchNorm per kernel size 1..max_kernel_size,
        # each padded ((k - 1) // 2 left, k // 2 right) to preserve time length
        self.conv1d_banks = nn.LayerList(
            [Conv1dBatchNorm(in_channels, out_channels_per_conv, (k,),
                             padding=((k - 1) // 2, k // 2))
             for k in range(1, 1 + max_kernel_size)])

        # conv1d projections: map the concatenated bank output back down,
        # ending at in_channels so the residual addition in forward() works
        self.projections = nn.LayerList()
        projection_channels = list(projection_channels)
        proj_in_channels = [max_kernel_size *
                            out_channels_per_conv] + projection_channels
        proj_out_channels = projection_channels + \
            [in_channels]  # ensure residual connection
        for c_in, c_out in zip(proj_in_channels, proj_out_channels):
            conv = nn.Conv1d(c_in, c_out, (3,), padding=(1, 1))
            self.projections.append(conv)

        # only add a projection when the residual width differs from the
        # highway width
        if in_channels != highway_features:
            self.pre_highway = nn.Linear(in_channels, highway_features)

        self.highways = nn.LayerList(
            [Highway(highway_features) for _ in range(num_highways)])

        self.gru = nn.GRU(highway_features, gru_features,
                          direction="bidirectional")

        self.in_channels = in_channels
        self.out_channels_per_conv = out_channels_per_conv
        self.max_kernel_size = max_kernel_size
        self.num_projections = 1 + len(projection_channels)
        self.num_highways = num_highways
        self.highway_features = highway_features
        self.gru_features = gru_features

    def forward(self, x):
        # x: (batch, in_channels, time)
        residual = x

        # conv banks: run every kernel size and concatenate along channels
        conv_outputs = []
        for conv in self.conv1d_banks:
            conv_outputs.append(conv(x))
        x = F.relu(paddle.concat(conv_outputs, 1))

        # max pool with stride 1; padding (0, 1) keeps the time length
        x = F.max_pool1d(x, 2, stride=1, padding=(0, 1))

        # conv1d projections; all but the last projection get a ReLU
        n_projections = len(self.projections)
        for i, conv in enumerate(self.projections):
            x = conv(x)
            if i != n_projections - 1:
                x = F.relu(x)
        x += residual  # residual connection

        # highway: move to (batch, time, features) for the linear layers,
        # then apply the highway stack
        x = paddle.transpose(x, [0, 2, 1])
        if hasattr(self, "pre_highway"):
            x = self.pre_highway(x)
        for highway in self.highways:
            x = highway(x)

        # bidirectional gru: output has 2 * gru_features features
        x, _ = self.gru(x)
        return x
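

if __name__ == "__main__":
    # Minimal smoke test; the hyperparameters below are illustrative
    # (Tacotron-encoder-like sizes), not values taken from this module.
    cbhg = CBHG(in_channels=128,
                out_channels_per_conv=128,
                max_kernel_size=16,
                projection_channels=[128],
                num_highways=4,
                highway_features=128,
                gru_features=128)
    x = paddle.randn([4, 128, 50])  # (batch, channels, time)
    y = cbhg(x)
    print(y.shape)  # [4, 50, 256]: the bidirectional GRU doubles gru_features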