Merge pull request #67 from iclementine/reborn
fix positional encoding naming conflict
This commit is contained in:
commit
9d06ec2d91
|
@ -391,7 +391,7 @@ class TransformerTTS(nn.Layer):
|
||||||
padding_idx=frontend.vocab.padding_index,
|
padding_idx=frontend.vocab.padding_index,
|
||||||
weight_attr=I.Uniform(-0.05, 0.05))
|
weight_attr=I.Uniform(-0.05, 0.05))
|
||||||
# position encoding matrix may be extended later
|
# position encoding matrix may be extended later
|
||||||
self.encoder_pe = pe.positional_encoding(0, 1000, d_encoder)
|
self.encoder_pe = pe.sinusoid_positional_encoding(0, 1000, d_encoder)
|
||||||
self.encoder_pe_scalar = self.create_parameter(
|
self.encoder_pe_scalar = self.create_parameter(
|
||||||
[1], attr=I.Constant(1.))
|
[1], attr=I.Constant(1.))
|
||||||
self.encoder = TransformerEncoder(d_encoder, n_heads, d_ffn,
|
self.encoder = TransformerEncoder(d_encoder, n_heads, d_ffn,
|
||||||
|
@ -399,7 +399,7 @@ class TransformerTTS(nn.Layer):
|
||||||
|
|
||||||
# decoder
|
# decoder
|
||||||
self.decoder_prenet = MLPPreNet(d_mel, d_prenet, d_decoder, dropout)
|
self.decoder_prenet = MLPPreNet(d_mel, d_prenet, d_decoder, dropout)
|
||||||
self.decoder_pe = pe.positional_encoding(0, 1000, d_decoder)
|
self.decoder_pe = pe.sinusoid_positional_encoding(0, 1000, d_decoder)
|
||||||
self.decoder_pe_scalar = self.create_parameter(
|
self.decoder_pe_scalar = self.create_parameter(
|
||||||
[1], attr=I.Constant(1.))
|
[1], attr=I.Constant(1.))
|
||||||
self.decoder = TransformerDecoder(
|
self.decoder = TransformerDecoder(
|
||||||
|
|
|
@ -17,10 +17,10 @@ import numpy as np
|
||||||
import paddle
|
import paddle
|
||||||
from paddle.nn import functional as F
|
from paddle.nn import functional as F
|
||||||
|
|
||||||
__all__ = ["positional_encoding"]
|
__all__ = ["sinusoid_positional_encoding"]
|
||||||
|
|
||||||
|
|
||||||
def positional_encoding(start_index, length, size, dtype=None):
|
def sinusoid_positional_encoding(start_index, length, size, dtype=None):
|
||||||
r"""Generate standard positional encoding matrix.
|
r"""Generate standard positional encoding matrix.
|
||||||
|
|
||||||
.. math::
|
.. math::
|
||||||
|
|
Loading…
Reference in New Issue