remove the last layer from decoder prenet

This commit is contained in:
chenfeiyu 2020-12-03 15:51:09 +08:00
parent 810f979dba
commit 4df5ad42f6
1 changed files with 1 additions and 3 deletions

View File

@ -273,14 +273,12 @@ class MLPPreNet(nn.Layer):
super(MLPPreNet, self).__init__()
self.lin1 = nn.Linear(d_input, d_hidden)
self.lin2 = nn.Linear(d_hidden, d_hidden)
self.lin3 = nn.Linear(d_output, d_output)
self.dropout = dropout
def forward(self, x, dropout):
l1 = F.dropout(F.relu(self.lin1(x)), self.dropout, training=self.training)
l2 = F.dropout(F.relu(self.lin2(l1)), self.dropout, training=self.training)
l3 = self.lin3(l2)
return l3
return l2
# NOTE: not used in
class CNNPreNet(nn.Layer):