diff --git a/parakeet/models/transformer_tts.py b/parakeet/models/transformer_tts.py index 05ce008..c19feff 100644 --- a/parakeet/models/transformer_tts.py +++ b/parakeet/models/transformer_tts.py @@ -317,14 +317,14 @@ class MLPPreNet(nn.Layer): super(MLPPreNet, self).__init__() self.lin1 = nn.Linear(d_input, d_hidden) self.lin2 = nn.Linear(d_hidden, d_hidden) - self.lin3 = nn.Linear(d_hidden, d_hidden) + self.lin3 = nn.Linear(d_hidden, d_output) self.dropout = dropout def forward(self, x, dropout): l1 = F.dropout( - F.relu(self.lin1(x)), self.dropout, training=self.training) + F.relu(self.lin1(x)), self.dropout, training=True) l2 = F.dropout( - F.relu(self.lin2(l1)), self.dropout, training=self.training) + F.relu(self.lin2(l1)), self.dropout, training=True) l3 = self.lin3(l2) return l3 @@ -473,7 +473,7 @@ class TransformerTTS(nn.Layer): # twice its length if needed if x.shape[1] * self.r > self.decoder_pe.shape[0]: new_T = max(x.shape[1] * self.r, self.decoder_pe.shape[0] * 2) - self.decoder_pe = pe.positional_encoding(0, new_T, self.d_decoder) + self.decoder_pe = pe.sinusoid_positional_encoding(0, new_T, self.d_decoder) pos_enc = self.decoder_pe[:T_dec * self.r:self.r, :] x = x.scale(math.sqrt( self.d_decoder)) + pos_enc * self.decoder_pe_scalar