add linear in decoder prenet
This commit is contained in:
parent
0287f46532
commit
62959759f9
|
@ -273,12 +273,14 @@ class MLPPreNet(nn.Layer):
|
|||
super(MLPPreNet, self).__init__()
|
||||
self.lin1 = nn.Linear(d_input, d_hidden)
|
||||
self.lin2 = nn.Linear(d_hidden, d_hidden)
|
||||
self.lin3 = nn.Linear(d_hidden, d_hidden)
|
||||
self.dropout = dropout
|
||||
|
||||
def forward(self, x, dropout):
|
||||
l1 = F.dropout(F.relu(self.lin1(x)), self.dropout, training=self.training)
|
||||
l2 = F.dropout(F.relu(self.lin2(l1)), self.dropout, training=self.training)
|
||||
return l2
|
||||
l3 = self.lin3(l2)
|
||||
return l3
|
||||
|
||||
# NOTE: not used in
|
||||
class CNNPreNet(nn.Layer):
|
||||
|
|
Loading…
Reference in New Issue