diff --git a/parakeet/models/transformer_tts.py b/parakeet/models/transformer_tts.py index c05e197..a0029ed 100644 --- a/parakeet/models/transformer_tts.py +++ b/parakeet/models/transformer_tts.py @@ -362,7 +362,7 @@ class TransformerTTS(nn.Layer): postnet_kernel_size, max_reduction_factor, dropout): super(TransformerTTS, self).__init__() # encoder - self.embedding = nn.Embedding(vocab_size, d_encoder, padding_idx) + self.embedding = nn.Embedding(vocab_size, d_encoder, padding_idx, weight_attr=I.Uniform(-0.05, 0.05)) self.encoder_prenet = CNNPreNet(d_encoder, d_encoder, d_encoder, 5, 3, dropout) self.encoder_pe = pe.positional_encoding(0, 1000, d_encoder) # it may be extended later self.encoder_pe_scalar = self.create_parameter([1], attr=I.Constant(1.))