add linear in decoder prenet
This commit is contained in:
parent
0287f46532
commit
62959759f9
|
@ -273,12 +273,14 @@ class MLPPreNet(nn.Layer):
|
||||||
super(MLPPreNet, self).__init__()
|
super(MLPPreNet, self).__init__()
|
||||||
self.lin1 = nn.Linear(d_input, d_hidden)
|
self.lin1 = nn.Linear(d_input, d_hidden)
|
||||||
self.lin2 = nn.Linear(d_hidden, d_hidden)
|
self.lin2 = nn.Linear(d_hidden, d_hidden)
|
||||||
|
self.lin3 = nn.Linear(d_hidden, d_hidden)
|
||||||
self.dropout = dropout
|
self.dropout = dropout
|
||||||
|
|
||||||
def forward(self, x, dropout):
|
def forward(self, x, dropout):
|
||||||
l1 = F.dropout(F.relu(self.lin1(x)), self.dropout, training=self.training)
|
l1 = F.dropout(F.relu(self.lin1(x)), self.dropout, training=self.training)
|
||||||
l2 = F.dropout(F.relu(self.lin2(l1)), self.dropout, training=self.training)
|
l2 = F.dropout(F.relu(self.lin2(l1)), self.dropout, training=self.training)
|
||||||
return l2
|
l3 = self.lin3(l2)
|
||||||
|
return l3
|
||||||
|
|
||||||
# NOTE: not used in
|
# NOTE: not used in
|
||||||
class CNNPreNet(nn.Layer):
|
class CNNPreNet(nn.Layer):
|
||||||
|
|
Loading…
Reference in New Issue