remove the last layer from decoder prenet
This commit is contained in:
parent
810f979dba
commit
4df5ad42f6
|
@ -273,14 +273,12 @@ class MLPPreNet(nn.Layer):
|
||||||
super(MLPPreNet, self).__init__()
|
super(MLPPreNet, self).__init__()
|
||||||
self.lin1 = nn.Linear(d_input, d_hidden)
|
self.lin1 = nn.Linear(d_input, d_hidden)
|
||||||
self.lin2 = nn.Linear(d_hidden, d_hidden)
|
self.lin2 = nn.Linear(d_hidden, d_hidden)
|
||||||
self.lin3 = nn.Linear(d_output, d_output)
|
|
||||||
self.dropout = dropout
|
self.dropout = dropout
|
||||||
|
|
||||||
def forward(self, x, dropout):
|
def forward(self, x, dropout):
|
||||||
l1 = F.dropout(F.relu(self.lin1(x)), self.dropout, training=self.training)
|
l1 = F.dropout(F.relu(self.lin1(x)), self.dropout, training=self.training)
|
||||||
l2 = F.dropout(F.relu(self.lin2(l1)), self.dropout, training=self.training)
|
l2 = F.dropout(F.relu(self.lin2(l1)), self.dropout, training=self.training)
|
||||||
l3 = self.lin3(l2)
|
return l2
|
||||||
return l3
|
|
||||||
|
|
||||||
# NOTE: not used in
|
# NOTE: not used in
|
||||||
class CNNPreNet(nn.Layer):
|
class CNNPreNet(nn.Layer):
|
||||||
|
|
Loading…
Reference in New Issue