add docstring for transformer_tts and fastspeech
This commit is contained in:
parent
a302bf21f4
commit
f7ec215b9a
|
@ -6,6 +6,7 @@ python -u synthesis.py \
|
||||||
--checkpoint_path='checkpoint/' \
|
--checkpoint_path='checkpoint/' \
|
||||||
--fastspeech_step=71000 \
|
--fastspeech_step=71000 \
|
||||||
--log_dir='./log' \
|
--log_dir='./log' \
|
||||||
|
--config_path='configs/synthesis.yaml' \
|
||||||
|
|
||||||
if [ $? -ne 0 ]; then
|
if [ $? -ne 0 ]; then
|
||||||
echo "Failed in synthesis!"
|
echo "Failed in synthesis!"
|
||||||
|
|
|
@ -13,7 +13,7 @@ python -u train.py \
|
||||||
--transformer_step=160000 \
|
--transformer_step=160000 \
|
||||||
--save_path='./checkpoint' \
|
--save_path='./checkpoint' \
|
||||||
--log_dir='./log' \
|
--log_dir='./log' \
|
||||||
--config_path='config/fastspeech.yaml' \
|
--config_path='configs/fastspeech.yaml' \
|
||||||
#--checkpoint_path='./checkpoint' \
|
#--checkpoint_path='./checkpoint' \
|
||||||
#--fastspeech_step=97000 \
|
#--fastspeech_step=97000 \
|
||||||
|
|
||||||
|
|
|
@ -84,7 +84,7 @@ def synthesis(text_input, args):
|
||||||
dec_slf_mask = get_triu_tensor(
|
dec_slf_mask = get_triu_tensor(
|
||||||
mel_input.numpy(), mel_input.numpy()).astype(np.float32)
|
mel_input.numpy(), mel_input.numpy()).astype(np.float32)
|
||||||
dec_slf_mask = fluid.layers.cast(
|
dec_slf_mask = fluid.layers.cast(
|
||||||
dg.to_variable(dec_slf_mask == 0), np.float32)
|
dg.to_variable(dec_slf_mask != 0), np.float32) * (-2**32 + 1)
|
||||||
pos_mel = np.arange(1, mel_input.shape[1] + 1)
|
pos_mel = np.arange(1, mel_input.shape[1] + 1)
|
||||||
pos_mel = fluid.layers.unsqueeze(dg.to_variable(pos_mel), [0])
|
pos_mel = fluid.layers.unsqueeze(dg.to_variable(pos_mel), [0])
|
||||||
mel_pred, postnet_pred, attn_probs, stop_preds, attn_enc, attn_dec = model(
|
mel_pred, postnet_pred, attn_probs, stop_preds, attn_enc, attn_dec = model(
|
||||||
|
@ -157,6 +157,5 @@ if __name__ == '__main__':
|
||||||
parser = argparse.ArgumentParser(description="Synthesis model")
|
parser = argparse.ArgumentParser(description="Synthesis model")
|
||||||
add_config_options_to_parser(parser)
|
add_config_options_to_parser(parser)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
synthesis(
|
synthesis("Parakeet stands for Paddle PARAllel text-to-speech toolkit.",
|
||||||
"They emphasized the necessity that the information now being furnished be handled with judgment and care.",
|
|
||||||
args)
|
args)
|
||||||
|
|
|
@ -2,14 +2,14 @@
|
||||||
# train model
|
# train model
|
||||||
CUDA_VISIBLE_DEVICES=0 \
|
CUDA_VISIBLE_DEVICES=0 \
|
||||||
python -u synthesis.py \
|
python -u synthesis.py \
|
||||||
--max_len=600 \
|
--max_len=300 \
|
||||||
--transformer_step=160000 \
|
--transformer_step=120000 \
|
||||||
--vocoder_step=90000 \
|
--vocoder_step=100000 \
|
||||||
--use_gpu=1 \
|
--use_gpu=1 \
|
||||||
--checkpoint_path='./checkpoint' \
|
--checkpoint_path='./checkpoint' \
|
||||||
--log_dir='./log' \
|
--log_dir='./log' \
|
||||||
--sample_path='./sample' \
|
--sample_path='./sample' \
|
||||||
--config_path='config/synthesis.yaml' \
|
--config_path='configs/synthesis.yaml' \
|
||||||
|
|
||||||
if [ $? -ne 0 ]; then
|
if [ $? -ne 0 ]; then
|
||||||
echo "Failed in training!"
|
echo "Failed in training!"
|
||||||
|
|
|
@ -14,7 +14,7 @@ python -u train_transformer.py \
|
||||||
--data_path='../../dataset/LJSpeech-1.1' \
|
--data_path='../../dataset/LJSpeech-1.1' \
|
||||||
--save_path='./checkpoint' \
|
--save_path='./checkpoint' \
|
||||||
--log_dir='./log' \
|
--log_dir='./log' \
|
||||||
--config_path='config/train_transformer.yaml' \
|
--config_path='configs/train_transformer.yaml' \
|
||||||
#--checkpoint_path='./checkpoint' \
|
#--checkpoint_path='./checkpoint' \
|
||||||
#--transformer_step=160000 \
|
#--transformer_step=160000 \
|
||||||
|
|
||||||
|
|
|
@ -12,7 +12,7 @@ python -u train_vocoder.py \
|
||||||
--data_path='../../dataset/LJSpeech-1.1' \
|
--data_path='../../dataset/LJSpeech-1.1' \
|
||||||
--save_path='./checkpoint' \
|
--save_path='./checkpoint' \
|
||||||
--log_dir='./log' \
|
--log_dir='./log' \
|
||||||
--config_path='config/train_vocoder.yaml' \
|
--config_path='configs/train_vocoder.yaml' \
|
||||||
#--checkpoint_path='./checkpoint' \
|
#--checkpoint_path='./checkpoint' \
|
||||||
#--vocoder_step=27000 \
|
#--vocoder_step=27000 \
|
||||||
|
|
||||||
|
|
|
@ -59,15 +59,25 @@ class Decoder(dg.Layer):
|
||||||
def forward(self, enc_seq, enc_pos, non_pad_mask, slf_attn_mask=None):
|
def forward(self, enc_seq, enc_pos, non_pad_mask, slf_attn_mask=None):
|
||||||
"""
|
"""
|
||||||
Decoder layer of FastSpeech.
|
Decoder layer of FastSpeech.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
enc_seq (Variable), Shape(B, text_T, C), dtype: float32.
|
enc_seq (Variable): The output of length regulator.
|
||||||
The output of length regulator.
|
Shape: (B, T_text, C), T_text means the timesteps of input text,
|
||||||
enc_pos (Variable, optional): Shape(B, T_mel),
|
dtype: float32.
|
||||||
dtype: int64. The spectrum position. T_mel means the timesteps of input spectrum.
|
enc_pos (Variable): The spectrum position.
|
||||||
|
Shape: (B, T_mel), T_mel means the timesteps of input spectrum,
|
||||||
|
dtype: int64.
|
||||||
|
non_pad_mask (Variable): the mask with non pad.
|
||||||
|
Shape: (B, T_mel, 1),
|
||||||
|
dtype: int64.
|
||||||
|
slf_attn_mask (Variable, optional): the mask of mel spectrum. Defaults to None.
|
||||||
|
Shape: (B, T_mel, T_mel),
|
||||||
|
dtype: int64.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
dec_output (Variable), Shape(B, mel_T, C), the decoder output.
|
dec_output (Variable): the decoder output.
|
||||||
dec_slf_attn_list (Variable), Shape(B, mel_T, mel_T), the decoder self attention list.
|
Shape: (B, T_mel, C).
|
||||||
|
dec_slf_attn_list (list[Variable]): the decoder self attention list.
|
||||||
|
Len: n_layers.
|
||||||
"""
|
"""
|
||||||
dec_slf_attn_list = []
|
dec_slf_attn_list = []
|
||||||
slf_attn_mask = layers.expand(slf_attn_mask, [self.n_head, 1, 1])
|
slf_attn_mask = layers.expand(slf_attn_mask, [self.n_head, 1, 1])
|
||||||
|
|
|
@ -64,17 +64,24 @@ class Encoder(dg.Layer):
|
||||||
def forward(self, character, text_pos, non_pad_mask, slf_attn_mask=None):
|
def forward(self, character, text_pos, non_pad_mask, slf_attn_mask=None):
|
||||||
"""
|
"""
|
||||||
Encoder layer of FastSpeech.
|
Encoder layer of FastSpeech.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
character (Variable): Shape(B, T_text), dtype: float32. The input text
|
character (Variable): The input text characters.
|
||||||
characters. T_text means the timesteps of input characters.
|
Shape: (B, T_text), T_text means the timesteps of input characters,
|
||||||
text_pos (Variable): Shape(B, T_text), dtype: int64. The input text
|
dtype: float32.
|
||||||
position. T_text means the timesteps of input characters.
|
text_pos (Variable): The input text position.
|
||||||
|
Shape: (B, T_text), dtype: int64.
|
||||||
|
non_pad_mask (Variable): the mask with non pad.
|
||||||
|
Shape: (B, T_text, 1),
|
||||||
|
dtype: int64.
|
||||||
|
slf_attn_mask (Variable, optional): the mask of input characters. Defaults to None.
|
||||||
|
Shape: (B, T_text, T_text),
|
||||||
|
dtype: int64.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
enc_output (Variable), Shape(B, text_T, C), the encoder output.
|
enc_output (Variable), the encoder output. Shape(B, T_text, C)
|
||||||
non_pad_mask (Variable), Shape(B, T_text, 1), the mask with non pad.
|
non_pad_mask (Variable), the mask with non pad. Shape(B, T_text, 1)
|
||||||
enc_slf_attn_list (list<Variable>), Len(n_layers), Shape(B * n_head, text_T, text_T), the encoder self attention list.
|
enc_slf_attn_list (list[Variable]), the encoder self attention list.
|
||||||
|
Len: n_layers.
|
||||||
"""
|
"""
|
||||||
enc_slf_attn_list = []
|
enc_slf_attn_list = []
|
||||||
slf_attn_mask = layers.expand(slf_attn_mask, [self.n_head, 1, 1])
|
slf_attn_mask = layers.expand(slf_attn_mask, [self.n_head, 1, 1])
|
||||||
|
|
|
@ -82,34 +82,45 @@ class FastSpeech(dg.Layer):
|
||||||
text_pos,
|
text_pos,
|
||||||
enc_non_pad_mask,
|
enc_non_pad_mask,
|
||||||
dec_non_pad_mask,
|
dec_non_pad_mask,
|
||||||
|
mel_pos=None,
|
||||||
enc_slf_attn_mask=None,
|
enc_slf_attn_mask=None,
|
||||||
dec_slf_attn_mask=None,
|
dec_slf_attn_mask=None,
|
||||||
mel_pos=None,
|
|
||||||
length_target=None,
|
length_target=None,
|
||||||
alpha=1.0):
|
alpha=1.0):
|
||||||
"""
|
"""
|
||||||
FastSpeech model.
|
FastSpeech model.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
character (Variable): Shape(B, T_text), dtype: float32. The input text
|
character (Variable): The input text characters.
|
||||||
characters. T_text means the timesteps of input characters.
|
Shape: (B, T_text), T_text means the timesteps of input characters, dtype: float32.
|
||||||
text_pos (Variable): Shape(B, T_text), dtype: int64. The input text
|
text_pos (Variable): The input text position.
|
||||||
position. T_text means the timesteps of input characters.
|
Shape: (B, T_text), dtype: int64.
|
||||||
mel_pos (Variable, optional): Shape(B, T_mel),
|
mel_pos (Variable, optional): The spectrum position.
|
||||||
dtype: int64. The spectrum position. T_mel means the timesteps of input spectrum.
|
Shape: (B, T_mel), T_mel means the timesteps of input spectrum, dtype: int64.
|
||||||
length_target (Variable, optional): Shape(B, T_text),
|
enc_non_pad_mask (Variable): the mask with non pad.
|
||||||
dtype: int64. The duration of phoneme compute from pretrained transformerTTS.
|
Shape: (B, T_text, 1),
|
||||||
alpha (Constant):
|
dtype: int64.
|
||||||
dtype: float32. The hyperparameter to determine the length of the expanded sequence
|
dec_non_pad_mask (Variable): the mask with non pad.
|
||||||
mel, thereby controlling the voice speed.
|
Shape: (B, T_mel, 1),
|
||||||
|
dtype: int64.
|
||||||
|
enc_slf_attn_mask (Variable, optional): the mask of input characters. Defaults to None.
|
||||||
|
Shape: (B, T_text, T_text),
|
||||||
|
dtype: int64.
|
||||||
|
slf_attn_mask (Variable, optional): the mask of mel spectrum. Defaults to None.
|
||||||
|
Shape: (B, T_mel, T_mel),
|
||||||
|
dtype: int64.
|
||||||
|
length_target (Variable, optional): The duration of phoneme compute from pretrained transformerTTS.
|
||||||
|
Defaults to None. Shape: (B, T_text), dtype: int64.
|
||||||
|
alpha (float32, optional): The hyperparameter to determine the length of the expanded sequence
|
||||||
|
mel, thereby controlling the voice speed. Defaults to 1.0.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
mel_output (Variable), Shape(B, mel_T, C), the mel output before postnet.
|
mel_output (Variable), the mel output before postnet. Shape: (B, T_mel, C),
|
||||||
mel_output_postnet (Variable), Shape(B, mel_T, C), the mel output after postnet.
|
mel_output_postnet (Variable), the mel output after postnet. Shape: (B, T_mel, C).
|
||||||
duration_predictor_output (Variable), Shape(B, text_T), the duration of phoneme compute
|
duration_predictor_output (Variable), the duration of phoneme compute with duration predictor.
|
||||||
with duration predictor.
|
Shape: (B, T_text).
|
||||||
enc_slf_attn_list (Variable), Shape(B, text_T, text_T), the encoder self attention list.
|
enc_slf_attn_list (List[Variable]), the encoder self attention list. Len: enc_n_layers.
|
||||||
dec_slf_attn_list (Variable), Shape(B, mel_T, mel_T), the decoder self attention list.
|
dec_slf_attn_list (List[Variable]), the decoder self attention list. Len: dec_n_layers.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
encoder_output, enc_slf_attn_list = self.encoder(
|
encoder_output, enc_slf_attn_list = self.encoder(
|
||||||
|
@ -118,7 +129,6 @@ class FastSpeech(dg.Layer):
|
||||||
enc_non_pad_mask,
|
enc_non_pad_mask,
|
||||||
slf_attn_mask=enc_slf_attn_mask)
|
slf_attn_mask=enc_slf_attn_mask)
|
||||||
if fluid.framework._dygraph_tracer()._train_mode:
|
if fluid.framework._dygraph_tracer()._train_mode:
|
||||||
|
|
||||||
length_regulator_output, duration_predictor_output = self.length_regulator(
|
length_regulator_output, duration_predictor_output = self.length_regulator(
|
||||||
encoder_output, target=length_target, alpha=alpha)
|
encoder_output, target=length_target, alpha=alpha)
|
||||||
decoder_output, dec_slf_attn_list = self.decoder(
|
decoder_output, dec_slf_attn_list = self.decoder(
|
||||||
|
|
|
@ -51,15 +51,17 @@ class FFTBlock(dg.Layer):
|
||||||
Feed Forward Transformer block in FastSpeech.
|
Feed Forward Transformer block in FastSpeech.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
enc_input (Variable): Shape(B, T, C), dtype: float32. The embedding characters input.
|
enc_input (Variable): The embedding characters input.
|
||||||
T means the timesteps of input.
|
Shape: (B, T, C), T means the timesteps of input, dtype: float32.
|
||||||
non_pad_mask (Variable): Shape(B, T, 1), dtype: int64. The mask of sequence.
|
non_pad_mask (Variable): The mask of sequence.
|
||||||
slf_attn_mask (Variable): Shape(B, len_q, len_k), dtype: int64. The mask of self attention.
|
Shape: (B, T, 1), dtype: int64.
|
||||||
len_q means the sequence length of query, len_k means the sequence length of key.
|
slf_attn_mask (Variable, optional): The mask of self attention. Defaults to None.
|
||||||
|
Shape(B, len_q, len_k), len_q means the sequence length of query,
|
||||||
|
len_k means the sequence length of key, dtype: int64.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
output (Variable), Shape(B, T, C), the output after self-attention & ffn.
|
output (Variable), the output after self-attention & ffn. Shape: (B, T, C).
|
||||||
slf_attn (Variable), Shape(B * n_head, T, T), the self attention.
|
slf_attn (Variable), the self attention. Shape: (B * n_head, T, T),
|
||||||
"""
|
"""
|
||||||
output, slf_attn = self.slf_attn(
|
output, slf_attn = self.slf_attn(
|
||||||
enc_input, enc_input, enc_input, mask=slf_attn_mask)
|
enc_input, enc_input, enc_input, mask=slf_attn_mask)
|
||||||
|
|
|
@ -69,15 +69,17 @@ class LengthRegulator(dg.Layer):
|
||||||
Length Regulator block in FastSpeech.
|
Length Regulator block in FastSpeech.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
x (Variable): Shape(B, T, C), dtype: float32. The encoder output.
|
x (Variable): The encoder output.
|
||||||
alpha (Constant): dtype: float32. The hyperparameter to determine the length of
|
Shape: (B, T, C), dtype: float32.
|
||||||
the expanded sequence mel, thereby controlling the voice speed.
|
alpha (float32, optional): The hyperparameter to determine the length of
|
||||||
target (Variable): (Variable, optional): Shape(B, T_text),
|
the expanded sequence mel, thereby controlling the voice speed. Defaults to 1.0.
|
||||||
dtype: int64. The duration of phoneme compute from pretrained transformerTTS.
|
target (Variable, optional): The duration of phoneme compute from pretrained transformerTTS.
|
||||||
|
Defaults to None. Shape: (B, T_text), dtype: int64.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
output (Variable), Shape(B, T, C), the output after exppand.
|
output (Variable), the output after exppand. Shape: (B, T, C),
|
||||||
duration_predictor_output (Variable), Shape(B, T, C), the output of duration predictor.
|
duration_predictor_output (Variable), the output of duration predictor.
|
||||||
|
Shape: (B, T, C).
|
||||||
"""
|
"""
|
||||||
duration_predictor_output = self.duration_predictor(x)
|
duration_predictor_output = self.duration_predictor(x)
|
||||||
if fluid.framework._dygraph_tracer()._train_mode:
|
if fluid.framework._dygraph_tracer()._train_mode:
|
||||||
|
|
|
@ -31,15 +31,7 @@ class CBHG(dg.Layer):
|
||||||
max_pool_kernel_size=2,
|
max_pool_kernel_size=2,
|
||||||
is_post=False):
|
is_post=False):
|
||||||
super(CBHG, self).__init__()
|
super(CBHG, self).__init__()
|
||||||
"""
|
|
||||||
:param hidden_size: dimension of hidden unit
|
|
||||||
:param batch_size: batch size
|
|
||||||
:param K: # of convolution banks
|
|
||||||
:param projection_size: dimension of projection unit
|
|
||||||
:param num_gru_layers: # of layers of GRUcell
|
|
||||||
:param max_pool_kernel_size: max pooling kernel size
|
|
||||||
:param is_post: whether post processing or not
|
|
||||||
"""
|
|
||||||
self.hidden_size = hidden_size
|
self.hidden_size = hidden_size
|
||||||
self.projection_size = projection_size
|
self.projection_size = projection_size
|
||||||
self.conv_list = []
|
self.conv_list = []
|
||||||
|
@ -176,7 +168,15 @@ class CBHG(dg.Layer):
|
||||||
return x
|
return x
|
||||||
|
|
||||||
def forward(self, input_):
|
def forward(self, input_):
|
||||||
# input_.shape = [N, C, T]
|
"""
|
||||||
|
CBHG Module
|
||||||
|
Args:
|
||||||
|
input_(Variable): The sequentially input.
|
||||||
|
Shape: (B, C, T), dtype: float32.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
(Variable): the CBHG output.
|
||||||
|
"""
|
||||||
|
|
||||||
conv_list = []
|
conv_list = []
|
||||||
conv_input = input_
|
conv_input = input_
|
||||||
|
@ -249,6 +249,15 @@ class Highwaynet(dg.Layer):
|
||||||
self.add_sublayer("gates_{}".format(i), gate)
|
self.add_sublayer("gates_{}".format(i), gate)
|
||||||
|
|
||||||
def forward(self, input_):
|
def forward(self, input_):
|
||||||
|
"""
|
||||||
|
Highway network
|
||||||
|
Args:
|
||||||
|
input_(Variable): The sequentially input.
|
||||||
|
Shape: (B, T, C), dtype: float32.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
(Variable): the Highway output.
|
||||||
|
"""
|
||||||
out = input_
|
out = input_
|
||||||
|
|
||||||
for linear, gate in zip(self.linears, self.gates):
|
for linear, gate in zip(self.linears, self.gates):
|
||||||
|
|
|
@ -22,7 +22,7 @@ from parakeet.models.transformer_tts.post_convnet import PostConvNet
|
||||||
|
|
||||||
|
|
||||||
class Decoder(dg.Layer):
|
class Decoder(dg.Layer):
|
||||||
def __init__(self, num_hidden, config, num_head=4):
|
def __init__(self, num_hidden, config, num_head=4, n_layers=3):
|
||||||
super(Decoder, self).__init__()
|
super(Decoder, self).__init__()
|
||||||
self.num_hidden = num_hidden
|
self.num_hidden = num_hidden
|
||||||
self.num_head = num_head
|
self.num_head = num_head
|
||||||
|
@ -58,20 +58,20 @@ class Decoder(dg.Layer):
|
||||||
|
|
||||||
self.selfattn_layers = [
|
self.selfattn_layers = [
|
||||||
MultiheadAttention(num_hidden, num_hidden // num_head,
|
MultiheadAttention(num_hidden, num_hidden // num_head,
|
||||||
num_hidden // num_head) for _ in range(3)
|
num_hidden // num_head) for _ in range(n_layers)
|
||||||
]
|
]
|
||||||
for i, layer in enumerate(self.selfattn_layers):
|
for i, layer in enumerate(self.selfattn_layers):
|
||||||
self.add_sublayer("self_attn_{}".format(i), layer)
|
self.add_sublayer("self_attn_{}".format(i), layer)
|
||||||
self.attn_layers = [
|
self.attn_layers = [
|
||||||
MultiheadAttention(num_hidden, num_hidden // num_head,
|
MultiheadAttention(num_hidden, num_hidden // num_head,
|
||||||
num_hidden // num_head) for _ in range(3)
|
num_hidden // num_head) for _ in range(n_layers)
|
||||||
]
|
]
|
||||||
for i, layer in enumerate(self.attn_layers):
|
for i, layer in enumerate(self.attn_layers):
|
||||||
self.add_sublayer("attn_{}".format(i), layer)
|
self.add_sublayer("attn_{}".format(i), layer)
|
||||||
self.ffns = [
|
self.ffns = [
|
||||||
PositionwiseFeedForward(
|
PositionwiseFeedForward(
|
||||||
num_hidden, num_hidden * num_head, filter_size=1)
|
num_hidden, num_hidden * num_head, filter_size=1)
|
||||||
for _ in range(3)
|
for _ in range(n_layers)
|
||||||
]
|
]
|
||||||
for i, layer in enumerate(self.ffns):
|
for i, layer in enumerate(self.ffns):
|
||||||
self.add_sublayer("ffns_{}".format(i), layer)
|
self.add_sublayer("ffns_{}".format(i), layer)
|
||||||
|
@ -108,6 +108,40 @@ class Decoder(dg.Layer):
|
||||||
m_mask=None,
|
m_mask=None,
|
||||||
m_self_mask=None,
|
m_self_mask=None,
|
||||||
zero_mask=None):
|
zero_mask=None):
|
||||||
|
"""
|
||||||
|
Decoder layer of TransformerTTS.
|
||||||
|
Args:
|
||||||
|
key (Variable): The input key of decoder.
|
||||||
|
Shape: (B, T_text, C), T_text means the timesteps of input text,
|
||||||
|
dtype: float32.
|
||||||
|
value (Variable): The . input value of decoder.
|
||||||
|
Shape: (B, T_text, C), dtype: float32.
|
||||||
|
query (Variable): The input query of decoder.
|
||||||
|
Shape: (B, T_mel, C), T_mel means the timesteps of input spectrum,
|
||||||
|
dtype: float32.
|
||||||
|
positional (Variable): The spectrum position.
|
||||||
|
Shape: (B, T_mel), dtype: int64.
|
||||||
|
mask (Variable): the mask of decoder self attention.
|
||||||
|
Shape: (B, T_mel, T_mel), dtype: int64.
|
||||||
|
m_mask (Variable, optional): the query mask of encoder-decoder attention. Defaults to None.
|
||||||
|
Shape: (B, T_mel, 1), dtype: int64.
|
||||||
|
m_self_mask (Variable, optional): the query mask of decoder self attention. Defaults to None.
|
||||||
|
Shape: (B, T_mel, 1), dtype: int64.
|
||||||
|
zero_mask (Variable, optional): query mask of encoder-decoder attention. Defaults to None.
|
||||||
|
Shape: (B, T_mel, T_text), dtype: int64.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
mel_out (Variable): the decoder output after mel linear projection.
|
||||||
|
Shape: (B, T_mel, C).
|
||||||
|
out (Variable): the decoder output after post mel network.
|
||||||
|
Shape: (B, T_mel, C).
|
||||||
|
stop_tokens (Variable): the stop tokens of output.
|
||||||
|
Shape: (B, T_mel, 1)
|
||||||
|
attn_list (list[Variable]): the encoder-decoder attention list.
|
||||||
|
Len: n_layers.
|
||||||
|
selfattn_list (list[Variable]): the decoder self attention list.
|
||||||
|
Len: n_layers.
|
||||||
|
"""
|
||||||
|
|
||||||
# get decoder mask with triangular matrix
|
# get decoder mask with triangular matrix
|
||||||
|
|
||||||
|
|
|
@ -20,7 +20,7 @@ from parakeet.models.transformer_tts.encoderprenet import EncoderPrenet
|
||||||
|
|
||||||
|
|
||||||
class Encoder(dg.Layer):
|
class Encoder(dg.Layer):
|
||||||
def __init__(self, embedding_size, num_hidden, num_head=4):
|
def __init__(self, embedding_size, num_hidden, num_head=4, n_layers=3):
|
||||||
super(Encoder, self).__init__()
|
super(Encoder, self).__init__()
|
||||||
self.num_hidden = num_hidden
|
self.num_hidden = num_hidden
|
||||||
self.num_head = num_head
|
self.num_head = num_head
|
||||||
|
@ -42,7 +42,7 @@ class Encoder(dg.Layer):
|
||||||
use_cudnn=True)
|
use_cudnn=True)
|
||||||
self.layers = [
|
self.layers = [
|
||||||
MultiheadAttention(num_hidden, num_hidden // num_head,
|
MultiheadAttention(num_hidden, num_hidden // num_head,
|
||||||
num_hidden // num_head) for _ in range(3)
|
num_hidden // num_head) for _ in range(n_layers)
|
||||||
]
|
]
|
||||||
for i, layer in enumerate(self.layers):
|
for i, layer in enumerate(self.layers):
|
||||||
self.add_sublayer("self_attn_{}".format(i), layer)
|
self.add_sublayer("self_attn_{}".format(i), layer)
|
||||||
|
@ -51,12 +51,31 @@ class Encoder(dg.Layer):
|
||||||
num_hidden,
|
num_hidden,
|
||||||
num_hidden * num_head,
|
num_hidden * num_head,
|
||||||
filter_size=1,
|
filter_size=1,
|
||||||
use_cudnn=True) for _ in range(3)
|
use_cudnn=True) for _ in range(n_layers)
|
||||||
]
|
]
|
||||||
for i, layer in enumerate(self.ffns):
|
for i, layer in enumerate(self.ffns):
|
||||||
self.add_sublayer("ffns_{}".format(i), layer)
|
self.add_sublayer("ffns_{}".format(i), layer)
|
||||||
|
|
||||||
def forward(self, x, positional, mask=None, query_mask=None):
|
def forward(self, x, positional, mask=None, query_mask=None):
|
||||||
|
"""
|
||||||
|
Encoder layer of TransformerTTS.
|
||||||
|
Args:
|
||||||
|
x (Variable): The input character.
|
||||||
|
Shape: (B, T_text), T_text means the timesteps of input text,
|
||||||
|
dtype: float32.
|
||||||
|
positional (Variable): The characters position.
|
||||||
|
Shape: (B, T_text), dtype: int64.
|
||||||
|
mask (Variable, optional): the mask of encoder self attention. Defaults to None.
|
||||||
|
Shape: (B, T_text, T_text), dtype: int64.
|
||||||
|
query_mask (Variable, optional): the query mask of encoder self attention. Defaults to None.
|
||||||
|
Shape: (B, T_text, 1), dtype: int64.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
x (Variable): the encoder output.
|
||||||
|
Shape: (B, T_text, C).
|
||||||
|
attentions (list[Variable]): the encoder self attention list.
|
||||||
|
Len: n_layers.
|
||||||
|
"""
|
||||||
|
|
||||||
if fluid.framework._dygraph_tracer()._train_mode:
|
if fluid.framework._dygraph_tracer()._train_mode:
|
||||||
seq_len_key = x.shape[1]
|
seq_len_key = x.shape[1]
|
||||||
|
@ -66,12 +85,12 @@ class Encoder(dg.Layer):
|
||||||
else:
|
else:
|
||||||
query_mask, mask = None, None
|
query_mask, mask = None, None
|
||||||
# Encoder pre_network
|
# Encoder pre_network
|
||||||
x = self.encoder_prenet(x) #(N,T,C)
|
x = self.encoder_prenet(x)
|
||||||
|
|
||||||
# Get positional encoding
|
# Get positional encoding
|
||||||
positional = self.pos_emb(positional)
|
positional = self.pos_emb(positional)
|
||||||
|
|
||||||
x = positional * self.alpha + x #(N, T, C)
|
x = positional * self.alpha + x
|
||||||
|
|
||||||
# Positional dropout
|
# Positional dropout
|
||||||
x = layers.dropout(x, 0.1, dropout_implementation='upscale_in_train')
|
x = layers.dropout(x, 0.1, dropout_implementation='upscale_in_train')
|
||||||
|
|
|
@ -81,8 +81,18 @@ class EncoderPrenet(dg.Layer):
|
||||||
low=-k, high=k)))
|
low=-k, high=k)))
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
|
"""
|
||||||
|
Encoder prenet layer of TransformerTTS.
|
||||||
|
Args:
|
||||||
|
x (Variable): The input character.
|
||||||
|
Shape: (B, T_text), T_text means the timesteps of input text,
|
||||||
|
dtype: float32.
|
||||||
|
|
||||||
x = self.embedding(x) #(batch_size, seq_len, embending_size)
|
Returns:
|
||||||
|
(Variable): the encoder prenet output. Shape: (B, T_text, C).
|
||||||
|
"""
|
||||||
|
|
||||||
|
x = self.embedding(x)
|
||||||
x = layers.transpose(x, [0, 2, 1])
|
x = layers.transpose(x, [0, 2, 1])
|
||||||
for batch_norm, conv in zip(self.batch_norm_list, self.conv_list):
|
for batch_norm, conv in zip(self.batch_norm_list, self.conv_list):
|
||||||
x = layers.dropout(
|
x = layers.dropout(
|
||||||
|
|
|
@ -93,12 +93,13 @@ class PostConvNet(dg.Layer):
|
||||||
|
|
||||||
def forward(self, input):
|
def forward(self, input):
|
||||||
"""
|
"""
|
||||||
Post Conv Net.
|
Decocder Post Conv Net of TransformerTTS.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
input (Variable): Shape(B, T, C), dtype: float32. The input value.
|
input (Variable): The result of mel linear projection.
|
||||||
|
Shape: (B, T, C), dtype: float32.
|
||||||
Returns:
|
Returns:
|
||||||
output (Variable), Shape(B, T, C), the result after postconvnet.
|
(Variable): the result after postconvnet. Shape: (B, T, C),
|
||||||
"""
|
"""
|
||||||
|
|
||||||
input = layers.transpose(input, [0, 2, 1])
|
input = layers.transpose(input, [0, 2, 1])
|
||||||
|
|
|
@ -19,11 +19,6 @@ import paddle.fluid.layers as layers
|
||||||
|
|
||||||
class PreNet(dg.Layer):
|
class PreNet(dg.Layer):
|
||||||
def __init__(self, input_size, hidden_size, output_size, dropout_rate=0.2):
|
def __init__(self, input_size, hidden_size, output_size, dropout_rate=0.2):
|
||||||
"""
|
|
||||||
:param input_size: dimension of input
|
|
||||||
:param hidden_size: dimension of hidden unit
|
|
||||||
:param output_size: dimension of output
|
|
||||||
"""
|
|
||||||
super(PreNet, self).__init__()
|
super(PreNet, self).__init__()
|
||||||
self.input_size = input_size
|
self.input_size = input_size
|
||||||
self.hidden_size = hidden_size
|
self.hidden_size = hidden_size
|
||||||
|
@ -52,9 +47,10 @@ class PreNet(dg.Layer):
|
||||||
Pre Net before passing through the network.
|
Pre Net before passing through the network.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
x (Variable): Shape(B, T, C), dtype: float32. The input value.
|
x (Variable): The input value.
|
||||||
|
Shape: (B, T, C), dtype: float32.
|
||||||
Returns:
|
Returns:
|
||||||
x (Variable), Shape(B, T, C), the result after pernet.
|
(Variable), the result after pernet. Shape: (B, T, C),
|
||||||
"""
|
"""
|
||||||
x = layers.dropout(
|
x = layers.dropout(
|
||||||
layers.relu(self.linear1(x)),
|
layers.relu(self.linear1(x)),
|
||||||
|
|
|
@ -35,6 +35,46 @@ class TransformerTTS(dg.Layer):
|
||||||
enc_dec_mask=None,
|
enc_dec_mask=None,
|
||||||
dec_query_slf_mask=None,
|
dec_query_slf_mask=None,
|
||||||
dec_query_mask=None):
|
dec_query_mask=None):
|
||||||
|
"""
|
||||||
|
TransformerTTS network.
|
||||||
|
Args:
|
||||||
|
characters (Variable): The input character.
|
||||||
|
Shape: (B, T_text), T_text means the timesteps of input text,
|
||||||
|
dtype: float32.
|
||||||
|
mel_input (Variable): The input query of decoder.
|
||||||
|
Shape: (B, T_mel, C), T_mel means the timesteps of input spectrum,
|
||||||
|
dtype: float32.
|
||||||
|
pos_text (Variable): The characters position.
|
||||||
|
Shape: (B, T_text), dtype: int64.
|
||||||
|
dec_slf_mask (Variable): The spectrum position.
|
||||||
|
Shape: (B, T_mel), dtype: int64.
|
||||||
|
mask (Variable): the mask of decoder self attention.
|
||||||
|
Shape: (B, T_mel, T_mel), dtype: int64.
|
||||||
|
enc_slf_mask (Variable, optional): the mask of encoder self attention. Defaults to None.
|
||||||
|
Shape: (B, T_text, T_text), dtype: int64.
|
||||||
|
enc_query_mask (Variable, optional): the query mask of encoder self attention. Defaults to None.
|
||||||
|
Shape: (B, T_text, 1), dtype: int64.
|
||||||
|
dec_query_mask (Variable, optional): the query mask of encoder-decoder attention. Defaults to None.
|
||||||
|
Shape: (B, T_mel, 1), dtype: int64.
|
||||||
|
dec_query_slf_mask (Variable, optional): the query mask of decoder self attention. Defaults to None.
|
||||||
|
Shape: (B, T_mel, 1), dtype: int64.
|
||||||
|
enc_dec_mask (Variable, optional): query mask of encoder-decoder attention. Defaults to None.
|
||||||
|
Shape: (B, T_mel, T_text), dtype: int64.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
mel_output (Variable): the decoder output after mel linear projection.
|
||||||
|
Shape: (B, T_mel, C).
|
||||||
|
postnet_output (Variable): the decoder output after post mel network.
|
||||||
|
Shape: (B, T_mel, C).
|
||||||
|
stop_preds (Variable): the stop tokens of output.
|
||||||
|
Shape: (B, T_mel, 1)
|
||||||
|
attn_probs (list[Variable]): the encoder-decoder attention list.
|
||||||
|
Len: n_layers.
|
||||||
|
attns_enc (list[Variable]): the encoder self attention list.
|
||||||
|
Len: n_layers.
|
||||||
|
attns_dec (list[Variable]): the decoder self attention list.
|
||||||
|
Len: n_layers.
|
||||||
|
"""
|
||||||
key, attns_enc = self.encoder(
|
key, attns_enc = self.encoder(
|
||||||
characters, pos_text, mask=enc_slf_mask, query_mask=enc_query_mask)
|
characters, pos_text, mask=enc_slf_mask, query_mask=enc_query_mask)
|
||||||
|
|
||||||
|
@ -48,5 +88,3 @@ class TransformerTTS(dg.Layer):
|
||||||
m_self_mask=dec_query_slf_mask,
|
m_self_mask=dec_query_slf_mask,
|
||||||
m_mask=dec_query_mask)
|
m_mask=dec_query_mask)
|
||||||
return mel_output, postnet_output, attn_probs, stop_preds, attns_enc, attns_dec
|
return mel_output, postnet_output, attn_probs, stop_preds, attns_enc, attns_dec
|
||||||
|
|
||||||
return mel_output, postnet_output, attn_probs, stop_preds, attns_enc, attns_dec
|
|
||||||
|
|
|
@ -19,10 +19,6 @@ from parakeet.models.transformer_tts.cbhg import CBHG
|
||||||
|
|
||||||
|
|
||||||
class Vocoder(dg.Layer):
|
class Vocoder(dg.Layer):
|
||||||
"""
|
|
||||||
CBHG Network (mel -> linear)
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, config, batch_size):
|
def __init__(self, config, batch_size):
|
||||||
super(Vocoder, self).__init__()
|
super(Vocoder, self).__init__()
|
||||||
self.pre_proj = Conv1D(
|
self.pre_proj = Conv1D(
|
||||||
|
@ -36,6 +32,16 @@ class Vocoder(dg.Layer):
|
||||||
filter_size=1)
|
filter_size=1)
|
||||||
|
|
||||||
def forward(self, mel):
|
def forward(self, mel):
|
||||||
|
"""
|
||||||
|
CBHG Network (mel -> linear)
|
||||||
|
Args:
|
||||||
|
mel (Variable): The input mel spectrum.
|
||||||
|
Shape: (B, C, T), dtype: float32.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
(Variable): the linear output.
|
||||||
|
Shape: (B, T, C).
|
||||||
|
"""
|
||||||
mel = layers.transpose(mel, [0, 2, 1])
|
mel = layers.transpose(mel, [0, 2, 1])
|
||||||
mel = self.pre_proj(mel)
|
mel = self.pre_proj(mel)
|
||||||
mel = self.cbhg(mel)
|
mel = self.cbhg(mel)
|
||||||
|
|
|
@ -43,9 +43,10 @@ class DynamicGRU(dg.Layer):
|
||||||
Dynamic GRU block.
|
Dynamic GRU block.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
input (Variable): Shape(B, T, C), dtype: float32. The input value.
|
input (Variable): The input value.
|
||||||
|
Shape: (B, T, C), dtype: float32.
|
||||||
Returns:
|
Returns:
|
||||||
output (Variable), Shape(B, T, C), the result compute by GRU.
|
output (Variable), the result compute by GRU. Shape: (B, T, C).
|
||||||
"""
|
"""
|
||||||
hidden = self.h_0
|
hidden = self.h_0
|
||||||
res = []
|
res = []
|
||||||
|
|
|
@ -62,9 +62,10 @@ class PositionwiseFeedForward(dg.Layer):
|
||||||
Feed Forward Network.
|
Feed Forward Network.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
input (Variable): Shape(B, T, C), dtype: float32. The input value.
|
input (Variable): The input value.
|
||||||
|
Shape: (B, T, C), dtype: float32.
|
||||||
Returns:
|
Returns:
|
||||||
output (Variable), Shape(B, T, C), the result after FFN.
|
output (Variable), the result after FFN. Shape: (B, T, C).
|
||||||
"""
|
"""
|
||||||
x = layers.transpose(input, [0, 2, 1])
|
x = layers.transpose(input, [0, 2, 1])
|
||||||
#FFN Networt
|
#FFN Networt
|
||||||
|
|
|
@ -1,610 +0,0 @@
|
||||||
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
|
|
||||||
import paddle
|
|
||||||
from paddle import fluid
|
|
||||||
import paddle.fluid.dygraph as dg
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
from . import conv
|
|
||||||
from . import weight_norm
|
|
||||||
|
|
||||||
|
|
||||||
def FC(name_scope,
|
|
||||||
in_features,
|
|
||||||
size,
|
|
||||||
num_flatten_dims=1,
|
|
||||||
relu=False,
|
|
||||||
dropout=0.0,
|
|
||||||
epsilon=1e-30,
|
|
||||||
act=None,
|
|
||||||
is_test=False,
|
|
||||||
dtype="float32"):
|
|
||||||
"""
|
|
||||||
A special Linear Layer, when it is used with dropout, the weight is
|
|
||||||
initialized as normal(0, std=np.sqrt((1-dropout) / in_features))
|
|
||||||
"""
|
|
||||||
|
|
||||||
# stds
|
|
||||||
if isinstance(in_features, int):
|
|
||||||
in_features = [in_features]
|
|
||||||
|
|
||||||
stds = [np.sqrt((1 - dropout) / in_feature) for in_feature in in_features]
|
|
||||||
if relu:
|
|
||||||
stds = [std * np.sqrt(2.0) for std in stds]
|
|
||||||
|
|
||||||
weight_inits = [
|
|
||||||
fluid.initializer.NormalInitializer(scale=std) for std in stds
|
|
||||||
]
|
|
||||||
bias_init = fluid.initializer.ConstantInitializer(0.0)
|
|
||||||
|
|
||||||
# param attrs
|
|
||||||
weight_attrs = [fluid.ParamAttr(initializer=init) for init in weight_inits]
|
|
||||||
bias_attr = fluid.ParamAttr(initializer=bias_init)
|
|
||||||
|
|
||||||
layer = weight_norm.FC(name_scope,
|
|
||||||
size,
|
|
||||||
num_flatten_dims=num_flatten_dims,
|
|
||||||
param_attr=weight_attrs,
|
|
||||||
bias_attr=bias_attr,
|
|
||||||
act=act,
|
|
||||||
dtype=dtype)
|
|
||||||
return layer
|
|
||||||
|
|
||||||
|
|
||||||
def Conv1D(name_scope,
|
|
||||||
in_channels,
|
|
||||||
num_filters,
|
|
||||||
filter_size=3,
|
|
||||||
dilation=1,
|
|
||||||
groups=None,
|
|
||||||
causal=False,
|
|
||||||
std_mul=1.0,
|
|
||||||
dropout=0.0,
|
|
||||||
use_cudnn=True,
|
|
||||||
act=None,
|
|
||||||
dtype="float32"):
|
|
||||||
"""
|
|
||||||
A special Conv1D Layer, when it is used with dropout, the weight is
|
|
||||||
initialized as
|
|
||||||
normal(0, std=np.sqrt(std_mul * (1-dropout) / (filter_size * in_features)))
|
|
||||||
"""
|
|
||||||
# std
|
|
||||||
std = np.sqrt((std_mul * (1 - dropout)) / (filter_size * in_channels))
|
|
||||||
weight_init = fluid.initializer.NormalInitializer(loc=0.0, scale=std)
|
|
||||||
bias_init = fluid.initializer.ConstantInitializer(0.0)
|
|
||||||
|
|
||||||
# param attrs
|
|
||||||
weight_attr = fluid.ParamAttr(initializer=weight_init)
|
|
||||||
bias_attr = fluid.ParamAttr(initializer=bias_init)
|
|
||||||
|
|
||||||
layer = conv.Conv1D(
|
|
||||||
name_scope,
|
|
||||||
in_channels,
|
|
||||||
num_filters,
|
|
||||||
filter_size,
|
|
||||||
dilation,
|
|
||||||
groups=groups,
|
|
||||||
causal=causal,
|
|
||||||
param_attr=weight_attr,
|
|
||||||
bias_attr=bias_attr,
|
|
||||||
use_cudnn=use_cudnn,
|
|
||||||
act=act,
|
|
||||||
dtype=dtype)
|
|
||||||
return layer
|
|
||||||
|
|
||||||
|
|
||||||
def Embedding(name_scope,
|
|
||||||
num_embeddings,
|
|
||||||
embed_dim,
|
|
||||||
is_sparse=False,
|
|
||||||
is_distributed=False,
|
|
||||||
padding_idx=None,
|
|
||||||
std=0.01,
|
|
||||||
dtype="float32"):
|
|
||||||
# param attrs
|
|
||||||
weight_attr = fluid.ParamAttr(initializer=fluid.initializer.Normal(
|
|
||||||
scale=std))
|
|
||||||
layer = dg.Embedding(
|
|
||||||
name_scope, (num_embeddings, embed_dim),
|
|
||||||
padding_idx=padding_idx,
|
|
||||||
param_attr=weight_attr,
|
|
||||||
dtype=dtype)
|
|
||||||
return layer
|
|
||||||
|
|
||||||
|
|
||||||
class Conv1DGLU(dg.Layer):
|
|
||||||
"""
|
|
||||||
A Convolution 1D block with GLU activation. It also applys dropout for the
|
|
||||||
input x. It fuses speaker embeddings through a FC activated by softsign. It
|
|
||||||
has residual connection from the input x, and scale the output by
|
|
||||||
np.sqrt(0.5).
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self,
|
|
||||||
name_scope,
|
|
||||||
n_speakers,
|
|
||||||
speaker_dim,
|
|
||||||
in_channels,
|
|
||||||
num_filters,
|
|
||||||
filter_size,
|
|
||||||
dilation,
|
|
||||||
std_mul=4.0,
|
|
||||||
dropout=0.0,
|
|
||||||
causal=False,
|
|
||||||
residual=True,
|
|
||||||
dtype="float32"):
|
|
||||||
super(Conv1DGLU, self).__init__(name_scope, dtype=dtype)
|
|
||||||
|
|
||||||
# conv spec
|
|
||||||
self.in_channels = in_channels
|
|
||||||
self.n_speakers = n_speakers
|
|
||||||
self.speaker_dim = speaker_dim
|
|
||||||
self.num_filters = num_filters
|
|
||||||
self.filter_size = filter_size
|
|
||||||
self.dilation = dilation
|
|
||||||
self.causal = causal
|
|
||||||
self.residual = residual
|
|
||||||
|
|
||||||
# weight init and dropout
|
|
||||||
self.std_mul = std_mul
|
|
||||||
self.dropout = dropout
|
|
||||||
|
|
||||||
if residual:
|
|
||||||
assert (
|
|
||||||
in_channels == num_filters
|
|
||||||
), "this block uses residual connection"\
|
|
||||||
"the input_channes should equals num_filters"
|
|
||||||
|
|
||||||
self.conv = Conv1D(
|
|
||||||
self.full_name(),
|
|
||||||
in_channels,
|
|
||||||
2 * num_filters,
|
|
||||||
filter_size,
|
|
||||||
dilation,
|
|
||||||
causal=causal,
|
|
||||||
std_mul=std_mul,
|
|
||||||
dropout=dropout,
|
|
||||||
dtype=dtype)
|
|
||||||
|
|
||||||
if n_speakers > 1:
|
|
||||||
assert (speaker_dim is not None
|
|
||||||
), "speaker embed should not be null in multi-speaker case"
|
|
||||||
self.fc = Conv1D(
|
|
||||||
self.full_name(),
|
|
||||||
speaker_dim,
|
|
||||||
num_filters,
|
|
||||||
filter_size=1,
|
|
||||||
dilation=1,
|
|
||||||
causal=False,
|
|
||||||
act="softsign",
|
|
||||||
dtype=dtype)
|
|
||||||
|
|
||||||
def forward(self, x, speaker_embed_bc1t=None):
|
|
||||||
"""
|
|
||||||
Args:
|
|
||||||
x (Variable): Shape(B, C_in, 1, T), the input of Conv1DGLU
|
|
||||||
layer, where B means batch_size, C_in means the input channels
|
|
||||||
T means input time steps.
|
|
||||||
speaker_embed_bct1 (Variable): Shape(B, C_sp, 1, T), expanded
|
|
||||||
speaker embed, where C_sp means speaker embedding size. Note
|
|
||||||
that when using residual connection, the Conv1DGLU does not
|
|
||||||
change the number of channels, so out channels equals input
|
|
||||||
channels.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
x (Variable): Shape(B, C_out, 1, T), the output of Conv1DGLU, where
|
|
||||||
C_out means the output channels of Conv1DGLU.
|
|
||||||
"""
|
|
||||||
|
|
||||||
residual = x
|
|
||||||
x = fluid.layers.dropout(x, self.dropout)
|
|
||||||
x = self.conv(x)
|
|
||||||
|
|
||||||
content, gate = fluid.layers.split(x, num_or_sections=2, dim=1)
|
|
||||||
|
|
||||||
if speaker_embed_bc1t is not None:
|
|
||||||
sp = self.fc(speaker_embed_bc1t)
|
|
||||||
content = content + sp
|
|
||||||
|
|
||||||
# glu
|
|
||||||
x = fluid.layers.elementwise_mul(fluid.layers.sigmoid(gate), content)
|
|
||||||
|
|
||||||
if self.residual:
|
|
||||||
x = fluid.layers.scale(x + residual, np.sqrt(0.5))
|
|
||||||
return x
|
|
||||||
|
|
||||||
def add_input(self, x, speaker_embed_bc11=None):
|
|
||||||
"""
|
|
||||||
Inputs:
|
|
||||||
x: shape(B, num_filters, 1, time_steps)
|
|
||||||
speaker_embed_bc11: shape(B, speaker_dim, 1, time_steps)
|
|
||||||
|
|
||||||
Outputs:
|
|
||||||
out: shape(B, num_filters, 1, time_steps), where time_steps = 1
|
|
||||||
"""
|
|
||||||
|
|
||||||
residual = x
|
|
||||||
|
|
||||||
# add step input and produce step output
|
|
||||||
x = fluid.layers.dropout(x, self.dropout)
|
|
||||||
x = self.conv.add_input(x)
|
|
||||||
|
|
||||||
content, gate = fluid.layers.split(x, num_or_sections=2, dim=1)
|
|
||||||
|
|
||||||
if speaker_embed_bc11 is not None:
|
|
||||||
sp = self.fc(speaker_embed_bc11)
|
|
||||||
content = content + sp
|
|
||||||
|
|
||||||
x = fluid.layers.elementwise_mul(fluid.layers.sigmoid(gate), content)
|
|
||||||
|
|
||||||
if self.residual:
|
|
||||||
x = fluid.layers.scale(x + residual, np.sqrt(0.5))
|
|
||||||
return x
|
|
||||||
|
|
||||||
|
|
||||||
def Conv1DTranspose(name_scope,
|
|
||||||
in_channels,
|
|
||||||
num_filters,
|
|
||||||
filter_size,
|
|
||||||
padding=0,
|
|
||||||
stride=1,
|
|
||||||
dilation=1,
|
|
||||||
groups=None,
|
|
||||||
std_mul=1.0,
|
|
||||||
dropout=0.0,
|
|
||||||
use_cudnn=True,
|
|
||||||
act=None,
|
|
||||||
dtype="float32"):
|
|
||||||
std = np.sqrt(std_mul * (1 - dropout) / (in_channels * filter_size))
|
|
||||||
weight_init = fluid.initializer.NormalInitializer(scale=std)
|
|
||||||
weight_attr = fluid.ParamAttr(initializer=weight_init)
|
|
||||||
bias_init = fluid.initializer.ConstantInitializer(0.0)
|
|
||||||
bias_attr = fluid.ParamAttr(initializer=bias_init)
|
|
||||||
layer = conv.Conv1DTranspose(
|
|
||||||
name_scope,
|
|
||||||
in_channels,
|
|
||||||
num_filters,
|
|
||||||
filter_size,
|
|
||||||
padding=padding,
|
|
||||||
stride=stride,
|
|
||||||
dilation=dilation,
|
|
||||||
groups=groups,
|
|
||||||
param_attr=weight_attr,
|
|
||||||
bias_attr=bias_attr,
|
|
||||||
use_cudnn=use_cudnn,
|
|
||||||
act=act,
|
|
||||||
dtype=dtype)
|
|
||||||
return layer
|
|
||||||
|
|
||||||
|
|
||||||
def compute_position_embedding(rad):
|
|
||||||
# rad is a transposed radius, shape(embed_dim, n_vocab)
|
|
||||||
embed_dim, n_vocab = rad.shape
|
|
||||||
|
|
||||||
even_dims = dg.to_variable(np.arange(0, embed_dim, 2).astype("int32"))
|
|
||||||
odd_dims = dg.to_variable(np.arange(1, embed_dim, 2).astype("int32"))
|
|
||||||
|
|
||||||
even_rads = fluid.layers.gather(rad, even_dims)
|
|
||||||
odd_rads = fluid.layers.gather(rad, odd_dims)
|
|
||||||
|
|
||||||
sines = fluid.layers.sin(even_rads)
|
|
||||||
cosines = fluid.layers.cos(odd_rads)
|
|
||||||
|
|
||||||
temp = fluid.layers.scatter(rad, even_dims, sines)
|
|
||||||
out = fluid.layers.scatter(temp, odd_dims, cosines)
|
|
||||||
out = fluid.layers.transpose(out, perm=[1, 0])
|
|
||||||
return out
|
|
||||||
|
|
||||||
|
|
||||||
def position_encoding_init(n_position,
|
|
||||||
d_pos_vec,
|
|
||||||
position_rate=1.0,
|
|
||||||
sinusoidal=True):
|
|
||||||
""" Init the sinusoid position encoding table """
|
|
||||||
|
|
||||||
# keep idx 0 for padding token position encoding zero vector
|
|
||||||
position_enc = np.array([[
|
|
||||||
position_rate * pos / np.power(10000, 2 * (i // 2) / d_pos_vec)
|
|
||||||
for i in range(d_pos_vec)
|
|
||||||
] if pos != 0 else np.zeros(d_pos_vec) for pos in range(n_position)])
|
|
||||||
|
|
||||||
if sinusoidal:
|
|
||||||
position_enc[1:, 0::2] = np.sin(position_enc[1:, 0::2]) # dim 2i
|
|
||||||
position_enc[1:, 1::2] = np.cos(position_enc[1:, 1::2]) # dim 2i+1
|
|
||||||
|
|
||||||
return position_enc
|
|
||||||
|
|
||||||
|
|
||||||
class PositionEmbedding(dg.Layer):
|
|
||||||
def __init__(self,
|
|
||||||
name_scope,
|
|
||||||
n_position,
|
|
||||||
d_pos_vec,
|
|
||||||
position_rate=1.0,
|
|
||||||
is_sparse=False,
|
|
||||||
is_distributed=False,
|
|
||||||
param_attr=None,
|
|
||||||
max_norm=None,
|
|
||||||
padding_idx=None,
|
|
||||||
dtype="float32"):
|
|
||||||
super(PositionEmbedding, self).__init__(name_scope, dtype=dtype)
|
|
||||||
self.embed = dg.Embedding(
|
|
||||||
self.full_name(),
|
|
||||||
size=(n_position, d_pos_vec),
|
|
||||||
is_sparse=is_sparse,
|
|
||||||
is_distributed=is_distributed,
|
|
||||||
padding_idx=None,
|
|
||||||
param_attr=param_attr,
|
|
||||||
dtype=dtype)
|
|
||||||
self.set_weight(
|
|
||||||
position_encoding_init(
|
|
||||||
n_position,
|
|
||||||
d_pos_vec,
|
|
||||||
position_rate=position_rate,
|
|
||||||
sinusoidal=False).astype(dtype))
|
|
||||||
|
|
||||||
self._is_sparse = is_sparse
|
|
||||||
self._is_distributed = is_distributed
|
|
||||||
self._remote_prefetch = self._is_sparse and (not self._is_distributed)
|
|
||||||
if self._remote_prefetch:
|
|
||||||
assert self._is_sparse is True and self._is_distributed is False
|
|
||||||
|
|
||||||
self._padding_idx = (-1 if padding_idx is None else padding_idx if
|
|
||||||
padding_idx >= 0 else (n_position + padding_idx))
|
|
||||||
self._position_rate = position_rate
|
|
||||||
self._max_norm = max_norm
|
|
||||||
self._dtype = dtype
|
|
||||||
|
|
||||||
def set_weight(self, array):
|
|
||||||
assert self.embed._w.shape == list(array.shape), "shape does not match"
|
|
||||||
self.embed._w._ivar.value().get_tensor().set(
|
|
||||||
array, fluid.framework._current_expected_place())
|
|
||||||
|
|
||||||
def forward(self, indices, speaker_position_rate=None):
|
|
||||||
"""
|
|
||||||
Args:
|
|
||||||
indices (Variable): Shape (B, T, 1), dtype: int64, position
|
|
||||||
indices, where B means the batch size, T means the time steps.
|
|
||||||
speaker_position_rate (Variable | float, optional), position
|
|
||||||
rate. It can be a float point number or a Variable with
|
|
||||||
shape (1,), then this speaker_position_rate is used for every
|
|
||||||
example. It can also be a Variable with shape (B, 1), which
|
|
||||||
contains a speaker position rate for each speaker.
|
|
||||||
Returns:
|
|
||||||
out (Variable): Shape(B, C_pos), position embedding, where C_pos
|
|
||||||
means position embedding size.
|
|
||||||
"""
|
|
||||||
rad = fluid.layers.transpose(self.embed._w, perm=[1, 0])
|
|
||||||
batch_size = indices.shape[0]
|
|
||||||
|
|
||||||
if speaker_position_rate is None:
|
|
||||||
weight = compute_position_embedding(rad)
|
|
||||||
out = self._helper.create_variable_for_type_inference(self._dtype)
|
|
||||||
self._helper.append_op(
|
|
||||||
type="lookup_table",
|
|
||||||
inputs={"Ids": indices,
|
|
||||||
"W": weight},
|
|
||||||
outputs={"Out": out},
|
|
||||||
attrs={
|
|
||||||
"is_sparse": self._is_sparse,
|
|
||||||
"is_distributed": self._is_distributed,
|
|
||||||
"remote_prefetch": self._remote_prefetch,
|
|
||||||
"padding_idx":
|
|
||||||
self._padding_idx, # special value for lookup table op
|
|
||||||
})
|
|
||||||
return out
|
|
||||||
|
|
||||||
elif (np.isscalar(speaker_position_rate) or
|
|
||||||
isinstance(speaker_position_rate, fluid.framework.Variable) and
|
|
||||||
speaker_position_rate.shape == [1, 1]):
|
|
||||||
# # make a weight
|
|
||||||
# scale the weight (the operand for sin & cos)
|
|
||||||
if np.isscalar(speaker_position_rate):
|
|
||||||
scaled_rad = fluid.layers.scale(rad, speaker_position_rate)
|
|
||||||
else:
|
|
||||||
scaled_rad = fluid.layers.elementwise_mul(
|
|
||||||
rad, speaker_position_rate[0])
|
|
||||||
weight = compute_position_embedding(scaled_rad)
|
|
||||||
out = self._helper.create_variable_for_type_inference(self._dtype)
|
|
||||||
self._helper.append_op(
|
|
||||||
type="lookup_table",
|
|
||||||
inputs={"Ids": indices,
|
|
||||||
"W": weight},
|
|
||||||
outputs={"Out": out},
|
|
||||||
attrs={
|
|
||||||
"is_sparse": self._is_sparse,
|
|
||||||
"is_distributed": self._is_distributed,
|
|
||||||
"remote_prefetch": self._remote_prefetch,
|
|
||||||
"padding_idx":
|
|
||||||
self._padding_idx, # special value for lookup table op
|
|
||||||
})
|
|
||||||
return out
|
|
||||||
|
|
||||||
elif np.prod(speaker_position_rate.shape) > 1:
|
|
||||||
assert speaker_position_rate.shape == [batch_size, 1]
|
|
||||||
outputs = []
|
|
||||||
for i in range(batch_size):
|
|
||||||
rate = speaker_position_rate[i] # rate has shape [1]
|
|
||||||
scaled_rad = fluid.layers.elementwise_mul(rad, rate)
|
|
||||||
weight = compute_position_embedding(scaled_rad)
|
|
||||||
out = self._helper.create_variable_for_type_inference(
|
|
||||||
self._dtype)
|
|
||||||
sequence = indices[i]
|
|
||||||
self._helper.append_op(
|
|
||||||
type="lookup_table",
|
|
||||||
inputs={"Ids": sequence,
|
|
||||||
"W": weight},
|
|
||||||
outputs={"Out": out},
|
|
||||||
attrs={
|
|
||||||
"is_sparse": self._is_sparse,
|
|
||||||
"is_distributed": self._is_distributed,
|
|
||||||
"remote_prefetch": self._remote_prefetch,
|
|
||||||
"padding_idx": -1,
|
|
||||||
})
|
|
||||||
outputs.append(out)
|
|
||||||
out = fluid.layers.stack(outputs)
|
|
||||||
return out
|
|
||||||
else:
|
|
||||||
raise Exception("Then you can just use position rate at init")
|
|
||||||
|
|
||||||
|
|
||||||
class Conv1D_GU(dg.Layer):
|
|
||||||
def __init__(self,
|
|
||||||
name_scope,
|
|
||||||
conditioner_dim,
|
|
||||||
in_channels,
|
|
||||||
num_filters,
|
|
||||||
filter_size,
|
|
||||||
dilation,
|
|
||||||
causal=False,
|
|
||||||
residual=True,
|
|
||||||
dtype="float32"):
|
|
||||||
super(Conv1D_GU, self).__init__(name_scope, dtype=dtype)
|
|
||||||
|
|
||||||
self.conditioner_dim = conditioner_dim
|
|
||||||
self.in_channels = in_channels
|
|
||||||
self.num_filters = num_filters
|
|
||||||
self.filter_size = filter_size
|
|
||||||
self.dilation = dilation
|
|
||||||
self.causal = causal
|
|
||||||
self.residual = residual
|
|
||||||
|
|
||||||
if residual:
|
|
||||||
assert (
|
|
||||||
in_channels == num_filters
|
|
||||||
), "this block uses residual connection"\
|
|
||||||
"the input_channels should equals num_filters"
|
|
||||||
|
|
||||||
self.conv = Conv1D(
|
|
||||||
self.full_name(),
|
|
||||||
in_channels,
|
|
||||||
2 * num_filters,
|
|
||||||
filter_size,
|
|
||||||
dilation,
|
|
||||||
causal=causal,
|
|
||||||
dtype=dtype)
|
|
||||||
|
|
||||||
self.fc = Conv1D(
|
|
||||||
self.full_name(),
|
|
||||||
conditioner_dim,
|
|
||||||
2 * num_filters,
|
|
||||||
filter_size=1,
|
|
||||||
dilation=1,
|
|
||||||
causal=False,
|
|
||||||
dtype=dtype)
|
|
||||||
|
|
||||||
def forward(self, x, skip=None, conditioner=None):
|
|
||||||
"""
|
|
||||||
Args:
|
|
||||||
x (Variable): Shape(B, C_in, 1, T), the input of Conv1D_GU
|
|
||||||
layer, where B means batch_size, C_in means the input channels
|
|
||||||
T means input time steps.
|
|
||||||
skip (Variable): Shape(B, C_in, 1, T), skip connection.
|
|
||||||
conditioner (Variable): Shape(B, C_con, 1, T), expanded mel
|
|
||||||
conditioner, where C_con is conditioner hidden dim which
|
|
||||||
equals the num of mel bands. Note that when using residual
|
|
||||||
connection, the Conv1D_GU does not change the number of
|
|
||||||
channels, so out channels equals input channels.
|
|
||||||
Returns:
|
|
||||||
x (Variable): Shape(B, C_out, 1, T), the output of Conv1D_GU, where
|
|
||||||
C_out means the output channels of Conv1D_GU.
|
|
||||||
skip (Variable): Shape(B, C_out, 1, T), skip connection.
|
|
||||||
"""
|
|
||||||
residual = x
|
|
||||||
x = self.conv(x)
|
|
||||||
|
|
||||||
if conditioner is not None:
|
|
||||||
cond_bias = self.fc(conditioner)
|
|
||||||
x += cond_bias
|
|
||||||
|
|
||||||
content, gate = fluid.layers.split(x, num_or_sections=2, dim=1)
|
|
||||||
|
|
||||||
# Gated Unit.
|
|
||||||
x = fluid.layers.elementwise_mul(
|
|
||||||
fluid.layers.sigmoid(gate), fluid.layers.tanh(content))
|
|
||||||
|
|
||||||
if skip is None:
|
|
||||||
skip = x
|
|
||||||
else:
|
|
||||||
skip = fluid.layers.scale(skip + x, np.sqrt(0.5))
|
|
||||||
|
|
||||||
if self.residual:
|
|
||||||
x = fluid.layers.scale(residual + x, np.sqrt(0.5))
|
|
||||||
|
|
||||||
return x, skip
|
|
||||||
|
|
||||||
def add_input(self, x, skip=None, conditioner=None):
|
|
||||||
"""
|
|
||||||
Inputs:
|
|
||||||
x: shape(B, num_filters, 1, time_steps)
|
|
||||||
skip: shape(B, num_filters, 1, time_steps), skip connection
|
|
||||||
conditioner: shape(B, conditioner_dim, 1, time_steps)
|
|
||||||
Outputs:
|
|
||||||
x: shape(B, num_filters, 1, time_steps), where time_steps = 1
|
|
||||||
skip: skip connection, same shape as x
|
|
||||||
"""
|
|
||||||
residual = x
|
|
||||||
|
|
||||||
# add step input and produce step output
|
|
||||||
x = self.conv.add_input(x)
|
|
||||||
|
|
||||||
if conditioner is not None:
|
|
||||||
cond_bias = self.fc(conditioner)
|
|
||||||
x += cond_bias
|
|
||||||
|
|
||||||
content, gate = fluid.layers.split(x, num_or_sections=2, dim=1)
|
|
||||||
|
|
||||||
# Gated Unit.
|
|
||||||
x = fluid.layers.elementwise_mul(
|
|
||||||
fluid.layers.sigmoid(gate), fluid.layers.tanh(content))
|
|
||||||
|
|
||||||
if skip is None:
|
|
||||||
skip = x
|
|
||||||
else:
|
|
||||||
skip = fluid.layers.scale(skip + x, np.sqrt(0.5))
|
|
||||||
|
|
||||||
if self.residual:
|
|
||||||
x = fluid.layers.scale(residual + x, np.sqrt(0.5))
|
|
||||||
|
|
||||||
return x, skip
|
|
||||||
|
|
||||||
|
|
||||||
def Conv2DTranspose(name_scope,
|
|
||||||
num_filters,
|
|
||||||
filter_size,
|
|
||||||
padding=0,
|
|
||||||
stride=1,
|
|
||||||
dilation=1,
|
|
||||||
use_cudnn=True,
|
|
||||||
act=None,
|
|
||||||
dtype="float32"):
|
|
||||||
val = 1.0 / (filter_size[0] * filter_size[1])
|
|
||||||
weight_init = fluid.initializer.ConstantInitializer(val)
|
|
||||||
weight_attr = fluid.ParamAttr(initializer=weight_init)
|
|
||||||
|
|
||||||
layer = weight_norm.Conv2DTranspose(
|
|
||||||
name_scope,
|
|
||||||
num_filters,
|
|
||||||
filter_size=filter_size,
|
|
||||||
padding=padding,
|
|
||||||
stride=stride,
|
|
||||||
dilation=dilation,
|
|
||||||
param_attr=weight_attr,
|
|
||||||
use_cudnn=use_cudnn,
|
|
||||||
act=act,
|
|
||||||
dtype=dtype)
|
|
||||||
|
|
||||||
return layer
|
|
|
@ -66,12 +66,17 @@ class ScaledDotProductAttention(dg.Layer):
|
||||||
Scaled Dot Product Attention.
|
Scaled Dot Product Attention.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
key (Variable): Shape(B, T, C), dtype: float32. The input key of attention.
|
key (Variable): The input key of scaled dot product attention.
|
||||||
value (Variable): Shape(B, T, C), dtype: float32. The input value of attention.
|
Shape: (B, T, C), dtype: float32.
|
||||||
query (Variable): Shape(B, T, C), dtype: float32. The input query of attention.
|
value (Variable): The input value of scaled dot product attention.
|
||||||
mask (Variable): Shape(B, len_q, len_k), dtype: float32. The mask of key.
|
Shape: (B, T, C), dtype: float32.
|
||||||
query_mask (Variable): Shape(B, len_q, 1), dtype: float32. The mask of query.
|
query (Variable): The input query of scaled dot product attention.
|
||||||
dropout (Constant): dtype: float32. The probability of dropout.
|
Shape: (B, T, C), dtype: float32.
|
||||||
|
mask (Variable, optional): The mask of key. Defaults to None.
|
||||||
|
Shape(B, T_q, T_k), dtype: float32.
|
||||||
|
query_mask (Variable, optional): The mask of query. Defaults to None.
|
||||||
|
Shape(B, T_q, T_q), dtype: float32.
|
||||||
|
dropout (float32, optional): The probability of dropout. Defaults to 0.1.
|
||||||
Returns:
|
Returns:
|
||||||
result (Variable), Shape(B, T, C), the result of mutihead attention.
|
result (Variable), Shape(B, T, C), the result of mutihead attention.
|
||||||
attention (Variable), Shape(n_head * B, T, C), the attention of key.
|
attention (Variable), Shape(n_head * B, T, C), the attention of key.
|
||||||
|
@ -131,14 +136,19 @@ class MultiheadAttention(dg.Layer):
|
||||||
Multihead Attention.
|
Multihead Attention.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
key (Variable): Shape(B, T, C), dtype: float32. The input key of attention.
|
key (Variable): The input key of attention.
|
||||||
value (Variable): Shape(B, T, C), dtype: float32. The input value of attention.
|
Shape: (B, T, C), dtype: float32.
|
||||||
query_input (Variable): Shape(B, T, C), dtype: float32. The input query of attention.
|
value (Variable): The input value of attention.
|
||||||
mask (Variable): Shape(B, len_q, len_k), dtype: float32. The mask of key.
|
Shape: (B, T, C), dtype: float32.
|
||||||
query_mask (Variable): Shape(B, len_q, 1), dtype: float32. The mask of query.
|
query_input (Variable): The input query of attention.
|
||||||
|
Shape: (B, T, C), dtype: float32.
|
||||||
|
mask (Variable, optional): The mask of key. Defaults to None.
|
||||||
|
Shape: (B, T_query, T_key), dtype: float32.
|
||||||
|
query_mask (Variable, optional): The mask of query. Defaults to None.
|
||||||
|
Shape: (B, T_query, T_key), dtype: float32.
|
||||||
Returns:
|
Returns:
|
||||||
result (Variable), Shape(B, T, C), the result of mutihead attention.
|
result (Variable), the result of mutihead attention. Shape: (B, T, C).
|
||||||
attention (Variable), Shape(n_head * B, T, C), the attention of key.
|
attention (Variable), the attention of key and query. Shape: (num_head * B, T, C)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
batch_size = key.shape[0]
|
batch_size = key.shape[0]
|
||||||
|
@ -146,7 +156,6 @@ class MultiheadAttention(dg.Layer):
|
||||||
seq_len_query = query_input.shape[1]
|
seq_len_query = query_input.shape[1]
|
||||||
|
|
||||||
# Make multihead attention
|
# Make multihead attention
|
||||||
# key & value.shape = (batch_size, seq_len, feature)(feature = num_head * num_hidden_per_attn)
|
|
||||||
key = layers.reshape(
|
key = layers.reshape(
|
||||||
self.key(key), [batch_size, seq_len_key, self.num_head, self.d_k])
|
self.key(key), [batch_size, seq_len_key, self.num_head, self.d_k])
|
||||||
value = layers.reshape(
|
value = layers.reshape(
|
||||||
|
@ -168,18 +177,6 @@ class MultiheadAttention(dg.Layer):
|
||||||
result, attention = self.scal_attn(
|
result, attention = self.scal_attn(
|
||||||
key, value, query, mask=mask, query_mask=query_mask)
|
key, value, query, mask=mask, query_mask=query_mask)
|
||||||
|
|
||||||
key = layers.reshape(
|
|
||||||
layers.transpose(key, [2, 0, 1, 3]), [-1, seq_len_key, self.d_k])
|
|
||||||
value = layers.reshape(
|
|
||||||
layers.transpose(value, [2, 0, 1, 3]),
|
|
||||||
[-1, seq_len_key, self.d_k])
|
|
||||||
query = layers.reshape(
|
|
||||||
layers.transpose(query, [2, 0, 1, 3]),
|
|
||||||
[-1, seq_len_query, self.d_q])
|
|
||||||
|
|
||||||
result, attention = self.scal_attn(
|
|
||||||
key, value, query, mask=mask, query_mask=query_mask)
|
|
||||||
|
|
||||||
# concat all multihead result
|
# concat all multihead result
|
||||||
result = layers.reshape(
|
result = layers.reshape(
|
||||||
result, [self.num_head, batch_size, seq_len_query, self.d_q])
|
result, [self.num_head, batch_size, seq_len_query, self.d_q])
|
||||||
|
|
Loading…
Reference in New Issue