From 62959759f9547831a3b5d9b012aa25c7efe95372 Mon Sep 17 00:00:00 2001
From: chenfeiyu <chenfeiyu@baidu.com>
Date: Sat, 5 Dec 2020 22:09:44 +0800
Subject: [PATCH] add linear in decoder prenet

---
 parakeet/models/transformer_tts.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/parakeet/models/transformer_tts.py b/parakeet/models/transformer_tts.py
index d49a199..4cc3df3 100644
--- a/parakeet/models/transformer_tts.py
+++ b/parakeet/models/transformer_tts.py
@@ -273,12 +273,14 @@ class MLPPreNet(nn.Layer):
         super(MLPPreNet, self).__init__()
         self.lin1 = nn.Linear(d_input, d_hidden)
         self.lin2 = nn.Linear(d_hidden, d_hidden)
+        self.lin3 = nn.Linear(d_hidden, d_hidden)
         self.dropout = dropout
         
     def forward(self, x, dropout):
         l1 = F.dropout(F.relu(self.lin1(x)), self.dropout, training=self.training)
         l2 = F.dropout(F.relu(self.lin2(l1)), self.dropout, training=self.training)
-        return l2
+        l3 = self.lin3(l2)
+        return l3
 
 # NOTE: not used in 
 class CNNPreNet(nn.Layer):