From 62959759f9547831a3b5d9b012aa25c7efe95372 Mon Sep 17 00:00:00 2001 From: chenfeiyu Date: Sat, 5 Dec 2020 22:09:44 +0800 Subject: [PATCH] add linear in decoder prenet --- parakeet/models/transformer_tts.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/parakeet/models/transformer_tts.py b/parakeet/models/transformer_tts.py index d49a199..4cc3df3 100644 --- a/parakeet/models/transformer_tts.py +++ b/parakeet/models/transformer_tts.py @@ -273,12 +273,14 @@ class MLPPreNet(nn.Layer): super(MLPPreNet, self).__init__() self.lin1 = nn.Linear(d_input, d_hidden) self.lin2 = nn.Linear(d_hidden, d_hidden) + self.lin3 = nn.Linear(d_hidden, d_hidden) self.dropout = dropout def forward(self, x, dropout): l1 = F.dropout(F.relu(self.lin1(x)), self.dropout, training=self.training) l2 = F.dropout(F.relu(self.lin2(l1)), self.dropout, training=self.training) - return l2 + l3 = self.lin3(l2) + return l3 # NOTE: not used in class CNNPreNet(nn.Layer):