Merge pull request #67 from iclementine/reborn

fix positional encoding naming conflict
Feiyu Chan 2020-12-21 17:42:37 +08:00 committed by GitHub
commit 9d06ec2d91
2 changed files with 4 additions and 4 deletions


@@ -391,7 +391,7 @@ class TransformerTTS(nn.Layer):
             padding_idx=frontend.vocab.padding_index,
             weight_attr=I.Uniform(-0.05, 0.05))
         # position encoding matrix may be extended later
-        self.encoder_pe = pe.positional_encoding(0, 1000, d_encoder)
+        self.encoder_pe = pe.sinusoid_positional_encoding(0, 1000, d_encoder)
         self.encoder_pe_scalar = self.create_parameter(
             [1], attr=I.Constant(1.))
         self.encoder = TransformerEncoder(d_encoder, n_heads, d_ffn,
@@ -399,7 +399,7 @@ class TransformerTTS(nn.Layer):
         # decoder
         self.decoder_prenet = MLPPreNet(d_mel, d_prenet, d_decoder, dropout)
-        self.decoder_pe = pe.positional_encoding(0, 1000, d_decoder)
+        self.decoder_pe = pe.sinusoid_positional_encoding(0, 1000, d_decoder)
         self.decoder_pe_scalar = self.create_parameter(
             [1], attr=I.Constant(1.))
         self.decoder = TransformerDecoder(
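
For context, the tables created above are later scaled by the learned scalars and added to the input embeddings. The sketch below only illustrates that call pattern; it is not Parakeet's actual forward code. The helper name `_apply_positional_encoding` is hypothetical, and the on-demand regeneration is an assumption based on the "may be extended later" comment.

def _apply_positional_encoding(self, embeddings):
    # embeddings: [batch, T, d_encoder]; hypothetical helper, for illustration only
    T = embeddings.shape[1]
    if T > self.encoder_pe.shape[0]:
        # the precomputed table covers 1000 positions; regenerate it for longer inputs
        self.encoder_pe = pe.sinusoid_positional_encoding(0, T, embeddings.shape[-1])
    pos_enc = self.encoder_pe[:T]  # [T, d_encoder]
    return embeddings + self.encoder_pe_scalar * pos_enc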


@@ -17,10 +17,10 @@ import numpy as np
 import paddle
 from paddle.nn import functional as F
-__all__ = ["positional_encoding"]
+__all__ = ["sinusoid_positional_encoding"]
-def positional_encoding(start_index, length, size, dtype=None):
+def sinusoid_positional_encoding(start_index, length, size, dtype=None):
     r"""Generate standard positional encoding matrix.
     .. math::
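
The signature above is taken from the diff; the body below is a minimal sketch of what `sinusoid_positional_encoding` plausibly computes, assuming the interleaved sin/cos layout from "Attention Is All You Need" (sin on even channels, cos on odd channels). The exact channel layout and dtype handling in the library may differ.

import numpy as np
import paddle

def sinusoid_positional_encoding(start_index, length, size, dtype=None):
    # Sketch only: standard sinusoid table of shape [length, size] for
    # positions start_index .. start_index + length - 1.
    if size % 2 != 0:
        raise ValueError("size should be divisible by 2")
    dtype = dtype or paddle.get_default_dtype()
    channel = np.arange(0, size, 2)                          # even channel indices 2i
    position = np.arange(start_index, start_index + length)
    # angle[p, i] = p / 10000^(2i / size)
    angle = np.expand_dims(position, -1) / (10000.0 ** (channel / float(size)))
    encodings = np.zeros([length, size])
    encodings[:, 0::2] = np.sin(angle)                       # even channels: sin
    encodings[:, 1::2] = np.cos(angle)                       # odd channels: cos
    return paddle.to_tensor(encodings, dtype=dtype)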