add linear in decoder prenet

This commit is contained in:
chenfeiyu 2020-12-05 22:09:44 +08:00
parent 0287f46532
commit 62959759f9
1 changed files with 3 additions and 1 deletions

View File

@ -273,12 +273,14 @@ class MLPPreNet(nn.Layer):
super(MLPPreNet, self).__init__()
self.lin1 = nn.Linear(d_input, d_hidden)
self.lin2 = nn.Linear(d_hidden, d_hidden)
self.lin3 = nn.Linear(d_hidden, d_hidden)
self.dropout = dropout
def forward(self, x, dropout):
l1 = F.dropout(F.relu(self.lin1(x)), self.dropout, training=self.training)
l2 = F.dropout(F.relu(self.lin2(l1)), self.dropout, training=self.training)
return l2
l3 = self.lin3(l2)
return l3
# NOTE: not used in
class CNNPreNet(nn.Layer):