add docstring for transformer_tts and fastspeech

This commit is contained in:
lifuchen 2020-03-09 07:16:02 +00:00
parent a302bf21f4
commit f7ec215b9a
23 changed files with 271 additions and 738 deletions

View File

@ -6,6 +6,7 @@ python -u synthesis.py \
--checkpoint_path='checkpoint/' \ --checkpoint_path='checkpoint/' \
--fastspeech_step=71000 \ --fastspeech_step=71000 \
--log_dir='./log' \ --log_dir='./log' \
--config_path='configs/synthesis.yaml' \
if [ $? -ne 0 ]; then if [ $? -ne 0 ]; then
echo "Failed in synthesis!" echo "Failed in synthesis!"

View File

@ -13,7 +13,7 @@ python -u train.py \
--transformer_step=160000 \ --transformer_step=160000 \
--save_path='./checkpoint' \ --save_path='./checkpoint' \
--log_dir='./log' \ --log_dir='./log' \
--config_path='config/fastspeech.yaml' \ --config_path='configs/fastspeech.yaml' \
#--checkpoint_path='./checkpoint' \ #--checkpoint_path='./checkpoint' \
#--fastspeech_step=97000 \ #--fastspeech_step=97000 \

View File

@ -84,7 +84,7 @@ def synthesis(text_input, args):
dec_slf_mask = get_triu_tensor( dec_slf_mask = get_triu_tensor(
mel_input.numpy(), mel_input.numpy()).astype(np.float32) mel_input.numpy(), mel_input.numpy()).astype(np.float32)
dec_slf_mask = fluid.layers.cast( dec_slf_mask = fluid.layers.cast(
dg.to_variable(dec_slf_mask == 0), np.float32) dg.to_variable(dec_slf_mask != 0), np.float32) * (-2**32 + 1)
pos_mel = np.arange(1, mel_input.shape[1] + 1) pos_mel = np.arange(1, mel_input.shape[1] + 1)
pos_mel = fluid.layers.unsqueeze(dg.to_variable(pos_mel), [0]) pos_mel = fluid.layers.unsqueeze(dg.to_variable(pos_mel), [0])
mel_pred, postnet_pred, attn_probs, stop_preds, attn_enc, attn_dec = model( mel_pred, postnet_pred, attn_probs, stop_preds, attn_enc, attn_dec = model(
@ -157,6 +157,5 @@ if __name__ == '__main__':
parser = argparse.ArgumentParser(description="Synthesis model") parser = argparse.ArgumentParser(description="Synthesis model")
add_config_options_to_parser(parser) add_config_options_to_parser(parser)
args = parser.parse_args() args = parser.parse_args()
synthesis( synthesis("Parakeet stands for Paddle PARAllel text-to-speech toolkit.",
"They emphasized the necessity that the information now being furnished be handled with judgment and care.", args)
args)

View File

@ -2,14 +2,14 @@
# train model # train model
CUDA_VISIBLE_DEVICES=0 \ CUDA_VISIBLE_DEVICES=0 \
python -u synthesis.py \ python -u synthesis.py \
--max_len=600 \ --max_len=300 \
--transformer_step=160000 \ --transformer_step=120000 \
--vocoder_step=90000 \ --vocoder_step=100000 \
--use_gpu=1 \ --use_gpu=1 \
--checkpoint_path='./checkpoint' \ --checkpoint_path='./checkpoint' \
--log_dir='./log' \ --log_dir='./log' \
--sample_path='./sample' \ --sample_path='./sample' \
--config_path='config/synthesis.yaml' \ --config_path='configs/synthesis.yaml' \
if [ $? -ne 0 ]; then if [ $? -ne 0 ]; then
echo "Failed in training!" echo "Failed in training!"

View File

@ -14,7 +14,7 @@ python -u train_transformer.py \
--data_path='../../dataset/LJSpeech-1.1' \ --data_path='../../dataset/LJSpeech-1.1' \
--save_path='./checkpoint' \ --save_path='./checkpoint' \
--log_dir='./log' \ --log_dir='./log' \
--config_path='config/train_transformer.yaml' \ --config_path='configs/train_transformer.yaml' \
#--checkpoint_path='./checkpoint' \ #--checkpoint_path='./checkpoint' \
#--transformer_step=160000 \ #--transformer_step=160000 \

View File

@ -12,7 +12,7 @@ python -u train_vocoder.py \
--data_path='../../dataset/LJSpeech-1.1' \ --data_path='../../dataset/LJSpeech-1.1' \
--save_path='./checkpoint' \ --save_path='./checkpoint' \
--log_dir='./log' \ --log_dir='./log' \
--config_path='config/train_vocoder.yaml' \ --config_path='configs/train_vocoder.yaml' \
#--checkpoint_path='./checkpoint' \ #--checkpoint_path='./checkpoint' \
#--vocoder_step=27000 \ #--vocoder_step=27000 \

View File

@ -59,15 +59,25 @@ class Decoder(dg.Layer):
def forward(self, enc_seq, enc_pos, non_pad_mask, slf_attn_mask=None): def forward(self, enc_seq, enc_pos, non_pad_mask, slf_attn_mask=None):
""" """
Decoder layer of FastSpeech. Decoder layer of FastSpeech.
Args: Args:
enc_seq (Variable), Shape(B, text_T, C), dtype: float32. enc_seq (Variable): The output of length regulator.
The output of length regulator. Shape: (B, T_text, C), T_text means the timesteps of input text,
enc_pos (Variable, optional): Shape(B, T_mel), dtype: float32.
dtype: int64. The spectrum position. T_mel means the timesteps of input spectrum. enc_pos (Variable): The spectrum position.
Shape: (B, T_mel), T_mel means the timesteps of input spectrum,
dtype: int64.
non_pad_mask (Variable): the mask with non pad.
Shape: (B, T_mel, 1),
dtype: int64.
slf_attn_mask (Variable, optional): the mask of mel spectrum. Defaults to None.
Shape: (B, T_mel, T_mel),
dtype: int64.
Returns: Returns:
dec_output (Variable), Shape(B, mel_T, C), the decoder output. dec_output (Variable): the decoder output.
dec_slf_attn_list (Variable), Shape(B, mel_T, mel_T), the decoder self attention list. Shape: (B, T_mel, C).
dec_slf_attn_list (list[Variable]): the decoder self attention list.
Len: n_layers.
""" """
dec_slf_attn_list = [] dec_slf_attn_list = []
slf_attn_mask = layers.expand(slf_attn_mask, [self.n_head, 1, 1]) slf_attn_mask = layers.expand(slf_attn_mask, [self.n_head, 1, 1])

View File

@ -64,17 +64,24 @@ class Encoder(dg.Layer):
def forward(self, character, text_pos, non_pad_mask, slf_attn_mask=None): def forward(self, character, text_pos, non_pad_mask, slf_attn_mask=None):
""" """
Encoder layer of FastSpeech. Encoder layer of FastSpeech.
Args: Args:
character (Variable): Shape(B, T_text), dtype: float32. The input text character (Variable): The input text characters.
characters. T_text means the timesteps of input characters. Shape: (B, T_text), T_text means the timesteps of input characters,
text_pos (Variable): Shape(B, T_text), dtype: int64. The input text dtype: float32.
position. T_text means the timesteps of input characters. text_pos (Variable): The input text position.
Shape: (B, T_text), dtype: int64.
non_pad_mask (Variable): the mask with non pad.
Shape: (B, T_text, 1),
dtype: int64.
slf_attn_mask (Variable, optional): the mask of input characters. Defaults to None.
Shape: (B, T_text, T_text),
dtype: int64.
Returns: Returns:
enc_output (Variable), Shape(B, text_T, C), the encoder output. enc_output (Variable), the encoder output. Shape(B, T_text, C)
non_pad_mask (Variable), Shape(B, T_text, 1), the mask with non pad. non_pad_mask (Variable), the mask with non pad. Shape(B, T_text, 1)
enc_slf_attn_list (list<Variable>), Len(n_layers), Shape(B * n_head, text_T, text_T), the encoder self attention list. enc_slf_attn_list (list[Variable]), the encoder self attention list.
Len: n_layers.
""" """
enc_slf_attn_list = [] enc_slf_attn_list = []
slf_attn_mask = layers.expand(slf_attn_mask, [self.n_head, 1, 1]) slf_attn_mask = layers.expand(slf_attn_mask, [self.n_head, 1, 1])

View File

@ -82,34 +82,45 @@ class FastSpeech(dg.Layer):
text_pos, text_pos,
enc_non_pad_mask, enc_non_pad_mask,
dec_non_pad_mask, dec_non_pad_mask,
mel_pos=None,
enc_slf_attn_mask=None, enc_slf_attn_mask=None,
dec_slf_attn_mask=None, dec_slf_attn_mask=None,
mel_pos=None,
length_target=None, length_target=None,
alpha=1.0): alpha=1.0):
""" """
FastSpeech model. FastSpeech model.
Args: Args:
character (Variable): Shape(B, T_text), dtype: float32. The input text character (Variable): The input text characters.
characters. T_text means the timesteps of input characters. Shape: (B, T_text), T_text means the timesteps of input characters, dtype: float32.
text_pos (Variable): Shape(B, T_text), dtype: int64. The input text text_pos (Variable): The input text position.
position. T_text means the timesteps of input characters. Shape: (B, T_text), dtype: int64.
mel_pos (Variable, optional): Shape(B, T_mel), mel_pos (Variable, optional): The spectrum position.
dtype: int64. The spectrum position. T_mel means the timesteps of input spectrum. Shape: (B, T_mel), T_mel means the timesteps of input spectrum, dtype: int64.
length_target (Variable, optional): Shape(B, T_text), enc_non_pad_mask (Variable): the mask with non pad.
dtype: int64. The duration of phoneme compute from pretrained transformerTTS. Shape: (B, T_text, 1),
alpha (Constant): dtype: int64.
dtype: float32. The hyperparameter to determine the length of the expanded sequence dec_non_pad_mask (Variable): the mask with non pad.
mel, thereby controlling the voice speed. Shape: (B, T_mel, 1),
dtype: int64.
enc_slf_attn_mask (Variable, optional): the mask of input characters. Defaults to None.
Shape: (B, T_text, T_text),
dtype: int64.
slf_attn_mask (Variable, optional): the mask of mel spectrum. Defaults to None.
Shape: (B, T_mel, T_mel),
dtype: int64.
length_target (Variable, optional): The duration of phoneme compute from pretrained transformerTTS.
Defaults to None. Shape: (B, T_text), dtype: int64.
alpha (float32, optional): The hyperparameter to determine the length of the expanded sequence
mel, thereby controlling the voice speed. Defaults to 1.0.
Returns: Returns:
mel_output (Variable), Shape(B, mel_T, C), the mel output before postnet. mel_output (Variable), the mel output before postnet. Shape: (B, T_mel, C),
mel_output_postnet (Variable), Shape(B, mel_T, C), the mel output after postnet. mel_output_postnet (Variable), the mel output after postnet. Shape: (B, T_mel, C).
duration_predictor_output (Variable), Shape(B, text_T), the duration of phoneme compute duration_predictor_output (Variable), the duration of phoneme compute with duration predictor.
with duration predictor. Shape: (B, T_text).
enc_slf_attn_list (Variable), Shape(B, text_T, text_T), the encoder self attention list. enc_slf_attn_list (List[Variable]), the encoder self attention list. Len: enc_n_layers.
dec_slf_attn_list (Variable), Shape(B, mel_T, mel_T), the decoder self attention list. dec_slf_attn_list (List[Variable]), the decoder self attention list. Len: dec_n_layers.
""" """
encoder_output, enc_slf_attn_list = self.encoder( encoder_output, enc_slf_attn_list = self.encoder(
@ -118,7 +129,6 @@ class FastSpeech(dg.Layer):
enc_non_pad_mask, enc_non_pad_mask,
slf_attn_mask=enc_slf_attn_mask) slf_attn_mask=enc_slf_attn_mask)
if fluid.framework._dygraph_tracer()._train_mode: if fluid.framework._dygraph_tracer()._train_mode:
length_regulator_output, duration_predictor_output = self.length_regulator( length_regulator_output, duration_predictor_output = self.length_regulator(
encoder_output, target=length_target, alpha=alpha) encoder_output, target=length_target, alpha=alpha)
decoder_output, dec_slf_attn_list = self.decoder( decoder_output, dec_slf_attn_list = self.decoder(

View File

@ -51,15 +51,17 @@ class FFTBlock(dg.Layer):
Feed Forward Transformer block in FastSpeech. Feed Forward Transformer block in FastSpeech.
Args: Args:
enc_input (Variable): Shape(B, T, C), dtype: float32. The embedding characters input. enc_input (Variable): The embedding characters input.
T means the timesteps of input. Shape: (B, T, C), T means the timesteps of input, dtype: float32.
non_pad_mask (Variable): Shape(B, T, 1), dtype: int64. The mask of sequence. non_pad_mask (Variable): The mask of sequence.
slf_attn_mask (Variable): Shape(B, len_q, len_k), dtype: int64. The mask of self attention. Shape: (B, T, 1), dtype: int64.
len_q means the sequence length of query, len_k means the sequence length of key. slf_attn_mask (Variable, optional): The mask of self attention. Defaults to None.
Shape(B, len_q, len_k), len_q means the sequence length of query,
len_k means the sequence length of key, dtype: int64.
Returns: Returns:
output (Variable), Shape(B, T, C), the output after self-attention & ffn. output (Variable), the output after self-attention & ffn. Shape: (B, T, C).
slf_attn (Variable), Shape(B * n_head, T, T), the self attention. slf_attn (Variable), the self attention. Shape: (B * n_head, T, T),
""" """
output, slf_attn = self.slf_attn( output, slf_attn = self.slf_attn(
enc_input, enc_input, enc_input, mask=slf_attn_mask) enc_input, enc_input, enc_input, mask=slf_attn_mask)

View File

@ -69,15 +69,17 @@ class LengthRegulator(dg.Layer):
Length Regulator block in FastSpeech. Length Regulator block in FastSpeech.
Args: Args:
x (Variable): Shape(B, T, C), dtype: float32. The encoder output. x (Variable): The encoder output.
alpha (Constant): dtype: float32. The hyperparameter to determine the length of Shape: (B, T, C), dtype: float32.
the expanded sequence mel, thereby controlling the voice speed. alpha (float32, optional): The hyperparameter to determine the length of
target (Variable): (Variable, optional): Shape(B, T_text), the expanded sequence mel, thereby controlling the voice speed. Defaults to 1.0.
dtype: int64. The duration of phoneme compute from pretrained transformerTTS. target (Variable, optional): The duration of phoneme compute from pretrained transformerTTS.
Defaults to None. Shape: (B, T_text), dtype: int64.
Returns: Returns:
output (Variable), Shape(B, T, C), the output after exppand. output (Variable), the output after exppand. Shape: (B, T, C),
duration_predictor_output (Variable), Shape(B, T, C), the output of duration predictor. duration_predictor_output (Variable), the output of duration predictor.
Shape: (B, T, C).
""" """
duration_predictor_output = self.duration_predictor(x) duration_predictor_output = self.duration_predictor(x)
if fluid.framework._dygraph_tracer()._train_mode: if fluid.framework._dygraph_tracer()._train_mode:

View File

@ -31,15 +31,7 @@ class CBHG(dg.Layer):
max_pool_kernel_size=2, max_pool_kernel_size=2,
is_post=False): is_post=False):
super(CBHG, self).__init__() super(CBHG, self).__init__()
"""
:param hidden_size: dimension of hidden unit
:param batch_size: batch size
:param K: # of convolution banks
:param projection_size: dimension of projection unit
:param num_gru_layers: # of layers of GRUcell
:param max_pool_kernel_size: max pooling kernel size
:param is_post: whether post processing or not
"""
self.hidden_size = hidden_size self.hidden_size = hidden_size
self.projection_size = projection_size self.projection_size = projection_size
self.conv_list = [] self.conv_list = []
@ -176,7 +168,15 @@ class CBHG(dg.Layer):
return x return x
def forward(self, input_): def forward(self, input_):
# input_.shape = [N, C, T] """
CBHG Module
Args:
input_(Variable): The sequentially input.
Shape: (B, C, T), dtype: float32.
Returns:
(Variable): the CBHG output.
"""
conv_list = [] conv_list = []
conv_input = input_ conv_input = input_
@ -249,6 +249,15 @@ class Highwaynet(dg.Layer):
self.add_sublayer("gates_{}".format(i), gate) self.add_sublayer("gates_{}".format(i), gate)
def forward(self, input_): def forward(self, input_):
"""
Highway network
Args:
input_(Variable): The sequentially input.
Shape: (B, T, C), dtype: float32.
Returns:
(Variable): the Highway output.
"""
out = input_ out = input_
for linear, gate in zip(self.linears, self.gates): for linear, gate in zip(self.linears, self.gates):

View File

@ -22,7 +22,7 @@ from parakeet.models.transformer_tts.post_convnet import PostConvNet
class Decoder(dg.Layer): class Decoder(dg.Layer):
def __init__(self, num_hidden, config, num_head=4): def __init__(self, num_hidden, config, num_head=4, n_layers=3):
super(Decoder, self).__init__() super(Decoder, self).__init__()
self.num_hidden = num_hidden self.num_hidden = num_hidden
self.num_head = num_head self.num_head = num_head
@ -58,20 +58,20 @@ class Decoder(dg.Layer):
self.selfattn_layers = [ self.selfattn_layers = [
MultiheadAttention(num_hidden, num_hidden // num_head, MultiheadAttention(num_hidden, num_hidden // num_head,
num_hidden // num_head) for _ in range(3) num_hidden // num_head) for _ in range(n_layers)
] ]
for i, layer in enumerate(self.selfattn_layers): for i, layer in enumerate(self.selfattn_layers):
self.add_sublayer("self_attn_{}".format(i), layer) self.add_sublayer("self_attn_{}".format(i), layer)
self.attn_layers = [ self.attn_layers = [
MultiheadAttention(num_hidden, num_hidden // num_head, MultiheadAttention(num_hidden, num_hidden // num_head,
num_hidden // num_head) for _ in range(3) num_hidden // num_head) for _ in range(n_layers)
] ]
for i, layer in enumerate(self.attn_layers): for i, layer in enumerate(self.attn_layers):
self.add_sublayer("attn_{}".format(i), layer) self.add_sublayer("attn_{}".format(i), layer)
self.ffns = [ self.ffns = [
PositionwiseFeedForward( PositionwiseFeedForward(
num_hidden, num_hidden * num_head, filter_size=1) num_hidden, num_hidden * num_head, filter_size=1)
for _ in range(3) for _ in range(n_layers)
] ]
for i, layer in enumerate(self.ffns): for i, layer in enumerate(self.ffns):
self.add_sublayer("ffns_{}".format(i), layer) self.add_sublayer("ffns_{}".format(i), layer)
@ -108,6 +108,40 @@ class Decoder(dg.Layer):
m_mask=None, m_mask=None,
m_self_mask=None, m_self_mask=None,
zero_mask=None): zero_mask=None):
"""
Decoder layer of TransformerTTS.
Args:
key (Variable): The input key of decoder.
Shape: (B, T_text, C), T_text means the timesteps of input text,
dtype: float32.
value (Variable): The . input value of decoder.
Shape: (B, T_text, C), dtype: float32.
query (Variable): The input query of decoder.
Shape: (B, T_mel, C), T_mel means the timesteps of input spectrum,
dtype: float32.
positional (Variable): The spectrum position.
Shape: (B, T_mel), dtype: int64.
mask (Variable): the mask of decoder self attention.
Shape: (B, T_mel, T_mel), dtype: int64.
m_mask (Variable, optional): the query mask of encoder-decoder attention. Defaults to None.
Shape: (B, T_mel, 1), dtype: int64.
m_self_mask (Variable, optional): the query mask of decoder self attention. Defaults to None.
Shape: (B, T_mel, 1), dtype: int64.
zero_mask (Variable, optional): query mask of encoder-decoder attention. Defaults to None.
Shape: (B, T_mel, T_text), dtype: int64.
Returns:
mel_out (Variable): the decoder output after mel linear projection.
Shape: (B, T_mel, C).
out (Variable): the decoder output after post mel network.
Shape: (B, T_mel, C).
stop_tokens (Variable): the stop tokens of output.
Shape: (B, T_mel, 1)
attn_list (list[Variable]): the encoder-decoder attention list.
Len: n_layers.
selfattn_list (list[Variable]): the decoder self attention list.
Len: n_layers.
"""
# get decoder mask with triangular matrix # get decoder mask with triangular matrix
@ -121,7 +155,7 @@ class Decoder(dg.Layer):
else: else:
m_mask, m_self_mask, zero_mask = None, None, None m_mask, m_self_mask, zero_mask = None, None, None
# Decoder pre-network # Decoder pre-network
query = self.decoder_prenet(query) query = self.decoder_prenet(query)
# Centered position # Centered position

View File

@ -20,7 +20,7 @@ from parakeet.models.transformer_tts.encoderprenet import EncoderPrenet
class Encoder(dg.Layer): class Encoder(dg.Layer):
def __init__(self, embedding_size, num_hidden, num_head=4): def __init__(self, embedding_size, num_hidden, num_head=4, n_layers=3):
super(Encoder, self).__init__() super(Encoder, self).__init__()
self.num_hidden = num_hidden self.num_hidden = num_hidden
self.num_head = num_head self.num_head = num_head
@ -42,7 +42,7 @@ class Encoder(dg.Layer):
use_cudnn=True) use_cudnn=True)
self.layers = [ self.layers = [
MultiheadAttention(num_hidden, num_hidden // num_head, MultiheadAttention(num_hidden, num_hidden // num_head,
num_hidden // num_head) for _ in range(3) num_hidden // num_head) for _ in range(n_layers)
] ]
for i, layer in enumerate(self.layers): for i, layer in enumerate(self.layers):
self.add_sublayer("self_attn_{}".format(i), layer) self.add_sublayer("self_attn_{}".format(i), layer)
@ -51,12 +51,31 @@ class Encoder(dg.Layer):
num_hidden, num_hidden,
num_hidden * num_head, num_hidden * num_head,
filter_size=1, filter_size=1,
use_cudnn=True) for _ in range(3) use_cudnn=True) for _ in range(n_layers)
] ]
for i, layer in enumerate(self.ffns): for i, layer in enumerate(self.ffns):
self.add_sublayer("ffns_{}".format(i), layer) self.add_sublayer("ffns_{}".format(i), layer)
def forward(self, x, positional, mask=None, query_mask=None): def forward(self, x, positional, mask=None, query_mask=None):
"""
Encoder layer of TransformerTTS.
Args:
x (Variable): The input character.
Shape: (B, T_text), T_text means the timesteps of input text,
dtype: float32.
positional (Variable): The characters position.
Shape: (B, T_text), dtype: int64.
mask (Variable, optional): the mask of encoder self attention. Defaults to None.
Shape: (B, T_text, T_text), dtype: int64.
query_mask (Variable, optional): the query mask of encoder self attention. Defaults to None.
Shape: (B, T_text, 1), dtype: int64.
Returns:
x (Variable): the encoder output.
Shape: (B, T_text, C).
attentions (list[Variable]): the encoder self attention list.
Len: n_layers.
"""
if fluid.framework._dygraph_tracer()._train_mode: if fluid.framework._dygraph_tracer()._train_mode:
seq_len_key = x.shape[1] seq_len_key = x.shape[1]
@ -66,12 +85,12 @@ class Encoder(dg.Layer):
else: else:
query_mask, mask = None, None query_mask, mask = None, None
# Encoder pre_network # Encoder pre_network
x = self.encoder_prenet(x) #(N,T,C) x = self.encoder_prenet(x)
# Get positional encoding # Get positional encoding
positional = self.pos_emb(positional) positional = self.pos_emb(positional)
x = positional * self.alpha + x #(N, T, C) x = positional * self.alpha + x
# Positional dropout # Positional dropout
x = layers.dropout(x, 0.1, dropout_implementation='upscale_in_train') x = layers.dropout(x, 0.1, dropout_implementation='upscale_in_train')

View File

@ -81,8 +81,18 @@ class EncoderPrenet(dg.Layer):
low=-k, high=k))) low=-k, high=k)))
def forward(self, x): def forward(self, x):
"""
Encoder prenet layer of TransformerTTS.
Args:
x (Variable): The input character.
Shape: (B, T_text), T_text means the timesteps of input text,
dtype: float32.
Returns:
(Variable): the encoder prenet output. Shape: (B, T_text, C).
"""
x = self.embedding(x) #(batch_size, seq_len, embending_size) x = self.embedding(x)
x = layers.transpose(x, [0, 2, 1]) x = layers.transpose(x, [0, 2, 1])
for batch_norm, conv in zip(self.batch_norm_list, self.conv_list): for batch_norm, conv in zip(self.batch_norm_list, self.conv_list):
x = layers.dropout( x = layers.dropout(

View File

@ -93,12 +93,13 @@ class PostConvNet(dg.Layer):
def forward(self, input): def forward(self, input):
""" """
Post Conv Net. Decocder Post Conv Net of TransformerTTS.
Args: Args:
input (Variable): Shape(B, T, C), dtype: float32. The input value. input (Variable): The result of mel linear projection.
Shape: (B, T, C), dtype: float32.
Returns: Returns:
output (Variable), Shape(B, T, C), the result after postconvnet. (Variable): the result after postconvnet. Shape: (B, T, C),
""" """
input = layers.transpose(input, [0, 2, 1]) input = layers.transpose(input, [0, 2, 1])

View File

@ -19,11 +19,6 @@ import paddle.fluid.layers as layers
class PreNet(dg.Layer): class PreNet(dg.Layer):
def __init__(self, input_size, hidden_size, output_size, dropout_rate=0.2): def __init__(self, input_size, hidden_size, output_size, dropout_rate=0.2):
"""
:param input_size: dimension of input
:param hidden_size: dimension of hidden unit
:param output_size: dimension of output
"""
super(PreNet, self).__init__() super(PreNet, self).__init__()
self.input_size = input_size self.input_size = input_size
self.hidden_size = hidden_size self.hidden_size = hidden_size
@ -52,9 +47,10 @@ class PreNet(dg.Layer):
Pre Net before passing through the network. Pre Net before passing through the network.
Args: Args:
x (Variable): Shape(B, T, C), dtype: float32. The input value. x (Variable): The input value.
Shape: (B, T, C), dtype: float32.
Returns: Returns:
x (Variable), Shape(B, T, C), the result after pernet. (Variable), the result after pernet. Shape: (B, T, C),
""" """
x = layers.dropout( x = layers.dropout(
layers.relu(self.linear1(x)), layers.relu(self.linear1(x)),

View File

@ -35,6 +35,46 @@ class TransformerTTS(dg.Layer):
enc_dec_mask=None, enc_dec_mask=None,
dec_query_slf_mask=None, dec_query_slf_mask=None,
dec_query_mask=None): dec_query_mask=None):
"""
TransformerTTS network.
Args:
characters (Variable): The input character.
Shape: (B, T_text), T_text means the timesteps of input text,
dtype: float32.
mel_input (Variable): The input query of decoder.
Shape: (B, T_mel, C), T_mel means the timesteps of input spectrum,
dtype: float32.
pos_text (Variable): The characters position.
Shape: (B, T_text), dtype: int64.
dec_slf_mask (Variable): The spectrum position.
Shape: (B, T_mel), dtype: int64.
mask (Variable): the mask of decoder self attention.
Shape: (B, T_mel, T_mel), dtype: int64.
enc_slf_mask (Variable, optional): the mask of encoder self attention. Defaults to None.
Shape: (B, T_text, T_text), dtype: int64.
enc_query_mask (Variable, optional): the query mask of encoder self attention. Defaults to None.
Shape: (B, T_text, 1), dtype: int64.
dec_query_mask (Variable, optional): the query mask of encoder-decoder attention. Defaults to None.
Shape: (B, T_mel, 1), dtype: int64.
dec_query_slf_mask (Variable, optional): the query mask of decoder self attention. Defaults to None.
Shape: (B, T_mel, 1), dtype: int64.
enc_dec_mask (Variable, optional): query mask of encoder-decoder attention. Defaults to None.
Shape: (B, T_mel, T_text), dtype: int64.
Returns:
mel_output (Variable): the decoder output after mel linear projection.
Shape: (B, T_mel, C).
postnet_output (Variable): the decoder output after post mel network.
Shape: (B, T_mel, C).
stop_preds (Variable): the stop tokens of output.
Shape: (B, T_mel, 1)
attn_probs (list[Variable]): the encoder-decoder attention list.
Len: n_layers.
attns_enc (list[Variable]): the encoder self attention list.
Len: n_layers.
attns_dec (list[Variable]): the decoder self attention list.
Len: n_layers.
"""
key, attns_enc = self.encoder( key, attns_enc = self.encoder(
characters, pos_text, mask=enc_slf_mask, query_mask=enc_query_mask) characters, pos_text, mask=enc_slf_mask, query_mask=enc_query_mask)
@ -48,5 +88,3 @@ class TransformerTTS(dg.Layer):
m_self_mask=dec_query_slf_mask, m_self_mask=dec_query_slf_mask,
m_mask=dec_query_mask) m_mask=dec_query_mask)
return mel_output, postnet_output, attn_probs, stop_preds, attns_enc, attns_dec return mel_output, postnet_output, attn_probs, stop_preds, attns_enc, attns_dec
return mel_output, postnet_output, attn_probs, stop_preds, attns_enc, attns_dec

View File

@ -19,10 +19,6 @@ from parakeet.models.transformer_tts.cbhg import CBHG
class Vocoder(dg.Layer): class Vocoder(dg.Layer):
"""
CBHG Network (mel -> linear)
"""
def __init__(self, config, batch_size): def __init__(self, config, batch_size):
super(Vocoder, self).__init__() super(Vocoder, self).__init__()
self.pre_proj = Conv1D( self.pre_proj = Conv1D(
@ -36,6 +32,16 @@ class Vocoder(dg.Layer):
filter_size=1) filter_size=1)
def forward(self, mel): def forward(self, mel):
"""
CBHG Network (mel -> linear)
Args:
mel (Variable): The input mel spectrum.
Shape: (B, C, T), dtype: float32.
Returns:
(Variable): the linear output.
Shape: (B, T, C).
"""
mel = layers.transpose(mel, [0, 2, 1]) mel = layers.transpose(mel, [0, 2, 1])
mel = self.pre_proj(mel) mel = self.pre_proj(mel)
mel = self.cbhg(mel) mel = self.cbhg(mel)

View File

@ -43,9 +43,10 @@ class DynamicGRU(dg.Layer):
Dynamic GRU block. Dynamic GRU block.
Args: Args:
input (Variable): Shape(B, T, C), dtype: float32. The input value. input (Variable): The input value.
Shape: (B, T, C), dtype: float32.
Returns: Returns:
output (Variable), Shape(B, T, C), the result compute by GRU. output (Variable), the result compute by GRU. Shape: (B, T, C).
""" """
hidden = self.h_0 hidden = self.h_0
res = [] res = []

View File

@ -62,9 +62,10 @@ class PositionwiseFeedForward(dg.Layer):
Feed Forward Network. Feed Forward Network.
Args: Args:
input (Variable): Shape(B, T, C), dtype: float32. The input value. input (Variable): The input value.
Shape: (B, T, C), dtype: float32.
Returns: Returns:
output (Variable), Shape(B, T, C), the result after FFN. output (Variable), the result after FFN. Shape: (B, T, C).
""" """
x = layers.transpose(input, [0, 2, 1]) x = layers.transpose(input, [0, 2, 1])
#FFN Networt #FFN Networt

View File

@ -1,610 +0,0 @@
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
from paddle import fluid
import paddle.fluid.dygraph as dg
import numpy as np
from . import conv
from . import weight_norm
def FC(name_scope,
in_features,
size,
num_flatten_dims=1,
relu=False,
dropout=0.0,
epsilon=1e-30,
act=None,
is_test=False,
dtype="float32"):
"""
A special Linear Layer, when it is used with dropout, the weight is
initialized as normal(0, std=np.sqrt((1-dropout) / in_features))
"""
# stds
if isinstance(in_features, int):
in_features = [in_features]
stds = [np.sqrt((1 - dropout) / in_feature) for in_feature in in_features]
if relu:
stds = [std * np.sqrt(2.0) for std in stds]
weight_inits = [
fluid.initializer.NormalInitializer(scale=std) for std in stds
]
bias_init = fluid.initializer.ConstantInitializer(0.0)
# param attrs
weight_attrs = [fluid.ParamAttr(initializer=init) for init in weight_inits]
bias_attr = fluid.ParamAttr(initializer=bias_init)
layer = weight_norm.FC(name_scope,
size,
num_flatten_dims=num_flatten_dims,
param_attr=weight_attrs,
bias_attr=bias_attr,
act=act,
dtype=dtype)
return layer
def Conv1D(name_scope,
in_channels,
num_filters,
filter_size=3,
dilation=1,
groups=None,
causal=False,
std_mul=1.0,
dropout=0.0,
use_cudnn=True,
act=None,
dtype="float32"):
"""
A special Conv1D Layer, when it is used with dropout, the weight is
initialized as
normal(0, std=np.sqrt(std_mul * (1-dropout) / (filter_size * in_features)))
"""
# std
std = np.sqrt((std_mul * (1 - dropout)) / (filter_size * in_channels))
weight_init = fluid.initializer.NormalInitializer(loc=0.0, scale=std)
bias_init = fluid.initializer.ConstantInitializer(0.0)
# param attrs
weight_attr = fluid.ParamAttr(initializer=weight_init)
bias_attr = fluid.ParamAttr(initializer=bias_init)
layer = conv.Conv1D(
name_scope,
in_channels,
num_filters,
filter_size,
dilation,
groups=groups,
causal=causal,
param_attr=weight_attr,
bias_attr=bias_attr,
use_cudnn=use_cudnn,
act=act,
dtype=dtype)
return layer
def Embedding(name_scope,
num_embeddings,
embed_dim,
is_sparse=False,
is_distributed=False,
padding_idx=None,
std=0.01,
dtype="float32"):
# param attrs
weight_attr = fluid.ParamAttr(initializer=fluid.initializer.Normal(
scale=std))
layer = dg.Embedding(
name_scope, (num_embeddings, embed_dim),
padding_idx=padding_idx,
param_attr=weight_attr,
dtype=dtype)
return layer
class Conv1DGLU(dg.Layer):
"""
A Convolution 1D block with GLU activation. It also applys dropout for the
input x. It fuses speaker embeddings through a FC activated by softsign. It
has residual connection from the input x, and scale the output by
np.sqrt(0.5).
"""
def __init__(self,
name_scope,
n_speakers,
speaker_dim,
in_channels,
num_filters,
filter_size,
dilation,
std_mul=4.0,
dropout=0.0,
causal=False,
residual=True,
dtype="float32"):
super(Conv1DGLU, self).__init__(name_scope, dtype=dtype)
# conv spec
self.in_channels = in_channels
self.n_speakers = n_speakers
self.speaker_dim = speaker_dim
self.num_filters = num_filters
self.filter_size = filter_size
self.dilation = dilation
self.causal = causal
self.residual = residual
# weight init and dropout
self.std_mul = std_mul
self.dropout = dropout
if residual:
assert (
in_channels == num_filters
), "this block uses residual connection"\
"the input_channes should equals num_filters"
self.conv = Conv1D(
self.full_name(),
in_channels,
2 * num_filters,
filter_size,
dilation,
causal=causal,
std_mul=std_mul,
dropout=dropout,
dtype=dtype)
if n_speakers > 1:
assert (speaker_dim is not None
), "speaker embed should not be null in multi-speaker case"
self.fc = Conv1D(
self.full_name(),
speaker_dim,
num_filters,
filter_size=1,
dilation=1,
causal=False,
act="softsign",
dtype=dtype)
def forward(self, x, speaker_embed_bc1t=None):
"""
Args:
x (Variable): Shape(B, C_in, 1, T), the input of Conv1DGLU
layer, where B means batch_size, C_in means the input channels
T means input time steps.
speaker_embed_bct1 (Variable): Shape(B, C_sp, 1, T), expanded
speaker embed, where C_sp means speaker embedding size. Note
that when using residual connection, the Conv1DGLU does not
change the number of channels, so out channels equals input
channels.
Returns:
x (Variable): Shape(B, C_out, 1, T), the output of Conv1DGLU, where
C_out means the output channels of Conv1DGLU.
"""
residual = x
x = fluid.layers.dropout(x, self.dropout)
x = self.conv(x)
content, gate = fluid.layers.split(x, num_or_sections=2, dim=1)
if speaker_embed_bc1t is not None:
sp = self.fc(speaker_embed_bc1t)
content = content + sp
# glu
x = fluid.layers.elementwise_mul(fluid.layers.sigmoid(gate), content)
if self.residual:
x = fluid.layers.scale(x + residual, np.sqrt(0.5))
return x
def add_input(self, x, speaker_embed_bc11=None):
"""
Inputs:
x: shape(B, num_filters, 1, time_steps)
speaker_embed_bc11: shape(B, speaker_dim, 1, time_steps)
Outputs:
out: shape(B, num_filters, 1, time_steps), where time_steps = 1
"""
residual = x
# add step input and produce step output
x = fluid.layers.dropout(x, self.dropout)
x = self.conv.add_input(x)
content, gate = fluid.layers.split(x, num_or_sections=2, dim=1)
if speaker_embed_bc11 is not None:
sp = self.fc(speaker_embed_bc11)
content = content + sp
x = fluid.layers.elementwise_mul(fluid.layers.sigmoid(gate), content)
if self.residual:
x = fluid.layers.scale(x + residual, np.sqrt(0.5))
return x
def Conv1DTranspose(name_scope,
in_channels,
num_filters,
filter_size,
padding=0,
stride=1,
dilation=1,
groups=None,
std_mul=1.0,
dropout=0.0,
use_cudnn=True,
act=None,
dtype="float32"):
std = np.sqrt(std_mul * (1 - dropout) / (in_channels * filter_size))
weight_init = fluid.initializer.NormalInitializer(scale=std)
weight_attr = fluid.ParamAttr(initializer=weight_init)
bias_init = fluid.initializer.ConstantInitializer(0.0)
bias_attr = fluid.ParamAttr(initializer=bias_init)
layer = conv.Conv1DTranspose(
name_scope,
in_channels,
num_filters,
filter_size,
padding=padding,
stride=stride,
dilation=dilation,
groups=groups,
param_attr=weight_attr,
bias_attr=bias_attr,
use_cudnn=use_cudnn,
act=act,
dtype=dtype)
return layer
def compute_position_embedding(rad):
# rad is a transposed radius, shape(embed_dim, n_vocab)
embed_dim, n_vocab = rad.shape
even_dims = dg.to_variable(np.arange(0, embed_dim, 2).astype("int32"))
odd_dims = dg.to_variable(np.arange(1, embed_dim, 2).astype("int32"))
even_rads = fluid.layers.gather(rad, even_dims)
odd_rads = fluid.layers.gather(rad, odd_dims)
sines = fluid.layers.sin(even_rads)
cosines = fluid.layers.cos(odd_rads)
temp = fluid.layers.scatter(rad, even_dims, sines)
out = fluid.layers.scatter(temp, odd_dims, cosines)
out = fluid.layers.transpose(out, perm=[1, 0])
return out
def position_encoding_init(n_position,
d_pos_vec,
position_rate=1.0,
sinusoidal=True):
""" Init the sinusoid position encoding table """
# keep idx 0 for padding token position encoding zero vector
position_enc = np.array([[
position_rate * pos / np.power(10000, 2 * (i // 2) / d_pos_vec)
for i in range(d_pos_vec)
] if pos != 0 else np.zeros(d_pos_vec) for pos in range(n_position)])
if sinusoidal:
position_enc[1:, 0::2] = np.sin(position_enc[1:, 0::2]) # dim 2i
position_enc[1:, 1::2] = np.cos(position_enc[1:, 1::2]) # dim 2i+1
return position_enc
class PositionEmbedding(dg.Layer):
def __init__(self,
name_scope,
n_position,
d_pos_vec,
position_rate=1.0,
is_sparse=False,
is_distributed=False,
param_attr=None,
max_norm=None,
padding_idx=None,
dtype="float32"):
super(PositionEmbedding, self).__init__(name_scope, dtype=dtype)
self.embed = dg.Embedding(
self.full_name(),
size=(n_position, d_pos_vec),
is_sparse=is_sparse,
is_distributed=is_distributed,
padding_idx=None,
param_attr=param_attr,
dtype=dtype)
self.set_weight(
position_encoding_init(
n_position,
d_pos_vec,
position_rate=position_rate,
sinusoidal=False).astype(dtype))
self._is_sparse = is_sparse
self._is_distributed = is_distributed
self._remote_prefetch = self._is_sparse and (not self._is_distributed)
if self._remote_prefetch:
assert self._is_sparse is True and self._is_distributed is False
self._padding_idx = (-1 if padding_idx is None else padding_idx if
padding_idx >= 0 else (n_position + padding_idx))
self._position_rate = position_rate
self._max_norm = max_norm
self._dtype = dtype
def set_weight(self, array):
assert self.embed._w.shape == list(array.shape), "shape does not match"
self.embed._w._ivar.value().get_tensor().set(
array, fluid.framework._current_expected_place())
def forward(self, indices, speaker_position_rate=None):
"""
Args:
indices (Variable): Shape (B, T, 1), dtype: int64, position
indices, where B means the batch size, T means the time steps.
speaker_position_rate (Variable | float, optional), position
rate. It can be a float point number or a Variable with
shape (1,), then this speaker_position_rate is used for every
example. It can also be a Variable with shape (B, 1), which
contains a speaker position rate for each speaker.
Returns:
out (Variable): Shape(B, C_pos), position embedding, where C_pos
means position embedding size.
"""
rad = fluid.layers.transpose(self.embed._w, perm=[1, 0])
batch_size = indices.shape[0]
if speaker_position_rate is None:
weight = compute_position_embedding(rad)
out = self._helper.create_variable_for_type_inference(self._dtype)
self._helper.append_op(
type="lookup_table",
inputs={"Ids": indices,
"W": weight},
outputs={"Out": out},
attrs={
"is_sparse": self._is_sparse,
"is_distributed": self._is_distributed,
"remote_prefetch": self._remote_prefetch,
"padding_idx":
self._padding_idx, # special value for lookup table op
})
return out
elif (np.isscalar(speaker_position_rate) or
isinstance(speaker_position_rate, fluid.framework.Variable) and
speaker_position_rate.shape == [1, 1]):
# # make a weight
# scale the weight (the operand for sin & cos)
if np.isscalar(speaker_position_rate):
scaled_rad = fluid.layers.scale(rad, speaker_position_rate)
else:
scaled_rad = fluid.layers.elementwise_mul(
rad, speaker_position_rate[0])
weight = compute_position_embedding(scaled_rad)
out = self._helper.create_variable_for_type_inference(self._dtype)
self._helper.append_op(
type="lookup_table",
inputs={"Ids": indices,
"W": weight},
outputs={"Out": out},
attrs={
"is_sparse": self._is_sparse,
"is_distributed": self._is_distributed,
"remote_prefetch": self._remote_prefetch,
"padding_idx":
self._padding_idx, # special value for lookup table op
})
return out
elif np.prod(speaker_position_rate.shape) > 1:
assert speaker_position_rate.shape == [batch_size, 1]
outputs = []
for i in range(batch_size):
rate = speaker_position_rate[i] # rate has shape [1]
scaled_rad = fluid.layers.elementwise_mul(rad, rate)
weight = compute_position_embedding(scaled_rad)
out = self._helper.create_variable_for_type_inference(
self._dtype)
sequence = indices[i]
self._helper.append_op(
type="lookup_table",
inputs={"Ids": sequence,
"W": weight},
outputs={"Out": out},
attrs={
"is_sparse": self._is_sparse,
"is_distributed": self._is_distributed,
"remote_prefetch": self._remote_prefetch,
"padding_idx": -1,
})
outputs.append(out)
out = fluid.layers.stack(outputs)
return out
else:
raise Exception("Then you can just use position rate at init")
class Conv1D_GU(dg.Layer):
def __init__(self,
name_scope,
conditioner_dim,
in_channels,
num_filters,
filter_size,
dilation,
causal=False,
residual=True,
dtype="float32"):
super(Conv1D_GU, self).__init__(name_scope, dtype=dtype)
self.conditioner_dim = conditioner_dim
self.in_channels = in_channels
self.num_filters = num_filters
self.filter_size = filter_size
self.dilation = dilation
self.causal = causal
self.residual = residual
if residual:
assert (
in_channels == num_filters
), "this block uses residual connection"\
"the input_channels should equals num_filters"
self.conv = Conv1D(
self.full_name(),
in_channels,
2 * num_filters,
filter_size,
dilation,
causal=causal,
dtype=dtype)
self.fc = Conv1D(
self.full_name(),
conditioner_dim,
2 * num_filters,
filter_size=1,
dilation=1,
causal=False,
dtype=dtype)
def forward(self, x, skip=None, conditioner=None):
"""
Args:
x (Variable): Shape(B, C_in, 1, T), the input of Conv1D_GU
layer, where B means batch_size, C_in means the input channels
T means input time steps.
skip (Variable): Shape(B, C_in, 1, T), skip connection.
conditioner (Variable): Shape(B, C_con, 1, T), expanded mel
conditioner, where C_con is conditioner hidden dim which
equals the num of mel bands. Note that when using residual
connection, the Conv1D_GU does not change the number of
channels, so out channels equals input channels.
Returns:
x (Variable): Shape(B, C_out, 1, T), the output of Conv1D_GU, where
C_out means the output channels of Conv1D_GU.
skip (Variable): Shape(B, C_out, 1, T), skip connection.
"""
residual = x
x = self.conv(x)
if conditioner is not None:
cond_bias = self.fc(conditioner)
x += cond_bias
content, gate = fluid.layers.split(x, num_or_sections=2, dim=1)
# Gated Unit.
x = fluid.layers.elementwise_mul(
fluid.layers.sigmoid(gate), fluid.layers.tanh(content))
if skip is None:
skip = x
else:
skip = fluid.layers.scale(skip + x, np.sqrt(0.5))
if self.residual:
x = fluid.layers.scale(residual + x, np.sqrt(0.5))
return x, skip
def add_input(self, x, skip=None, conditioner=None):
"""
Inputs:
x: shape(B, num_filters, 1, time_steps)
skip: shape(B, num_filters, 1, time_steps), skip connection
conditioner: shape(B, conditioner_dim, 1, time_steps)
Outputs:
x: shape(B, num_filters, 1, time_steps), where time_steps = 1
skip: skip connection, same shape as x
"""
residual = x
# add step input and produce step output
x = self.conv.add_input(x)
if conditioner is not None:
cond_bias = self.fc(conditioner)
x += cond_bias
content, gate = fluid.layers.split(x, num_or_sections=2, dim=1)
# Gated Unit.
x = fluid.layers.elementwise_mul(
fluid.layers.sigmoid(gate), fluid.layers.tanh(content))
if skip is None:
skip = x
else:
skip = fluid.layers.scale(skip + x, np.sqrt(0.5))
if self.residual:
x = fluid.layers.scale(residual + x, np.sqrt(0.5))
return x, skip
def Conv2DTranspose(name_scope,
num_filters,
filter_size,
padding=0,
stride=1,
dilation=1,
use_cudnn=True,
act=None,
dtype="float32"):
val = 1.0 / (filter_size[0] * filter_size[1])
weight_init = fluid.initializer.ConstantInitializer(val)
weight_attr = fluid.ParamAttr(initializer=weight_init)
layer = weight_norm.Conv2DTranspose(
name_scope,
num_filters,
filter_size=filter_size,
padding=padding,
stride=stride,
dilation=dilation,
param_attr=weight_attr,
use_cudnn=use_cudnn,
act=act,
dtype=dtype)
return layer

View File

@ -66,12 +66,17 @@ class ScaledDotProductAttention(dg.Layer):
Scaled Dot Product Attention. Scaled Dot Product Attention.
Args: Args:
key (Variable): Shape(B, T, C), dtype: float32. The input key of attention. key (Variable): The input key of scaled dot product attention.
value (Variable): Shape(B, T, C), dtype: float32. The input value of attention. Shape: (B, T, C), dtype: float32.
query (Variable): Shape(B, T, C), dtype: float32. The input query of attention. value (Variable): The input value of scaled dot product attention.
mask (Variable): Shape(B, len_q, len_k), dtype: float32. The mask of key. Shape: (B, T, C), dtype: float32.
query_mask (Variable): Shape(B, len_q, 1), dtype: float32. The mask of query. query (Variable): The input query of scaled dot product attention.
dropout (Constant): dtype: float32. The probability of dropout. Shape: (B, T, C), dtype: float32.
mask (Variable, optional): The mask of key. Defaults to None.
Shape(B, T_q, T_k), dtype: float32.
query_mask (Variable, optional): The mask of query. Defaults to None.
Shape(B, T_q, T_q), dtype: float32.
dropout (float32, optional): The probability of dropout. Defaults to 0.1.
Returns: Returns:
result (Variable), Shape(B, T, C), the result of mutihead attention. result (Variable), Shape(B, T, C), the result of mutihead attention.
attention (Variable), Shape(n_head * B, T, C), the attention of key. attention (Variable), Shape(n_head * B, T, C), the attention of key.
@ -131,14 +136,19 @@ class MultiheadAttention(dg.Layer):
Multihead Attention. Multihead Attention.
Args: Args:
key (Variable): Shape(B, T, C), dtype: float32. The input key of attention. key (Variable): The input key of attention.
value (Variable): Shape(B, T, C), dtype: float32. The input value of attention. Shape: (B, T, C), dtype: float32.
query_input (Variable): Shape(B, T, C), dtype: float32. The input query of attention. value (Variable): The input value of attention.
mask (Variable): Shape(B, len_q, len_k), dtype: float32. The mask of key. Shape: (B, T, C), dtype: float32.
query_mask (Variable): Shape(B, len_q, 1), dtype: float32. The mask of query. query_input (Variable): The input query of attention.
Shape: (B, T, C), dtype: float32.
mask (Variable, optional): The mask of key. Defaults to None.
Shape: (B, T_query, T_key), dtype: float32.
query_mask (Variable, optional): The mask of query. Defaults to None.
Shape: (B, T_query, T_key), dtype: float32.
Returns: Returns:
result (Variable), Shape(B, T, C), the result of mutihead attention. result (Variable), the result of mutihead attention. Shape: (B, T, C).
attention (Variable), Shape(n_head * B, T, C), the attention of key. attention (Variable), the attention of key and query. Shape: (num_head * B, T, C)
""" """
batch_size = key.shape[0] batch_size = key.shape[0]
@ -146,7 +156,6 @@ class MultiheadAttention(dg.Layer):
seq_len_query = query_input.shape[1] seq_len_query = query_input.shape[1]
# Make multihead attention # Make multihead attention
# key & value.shape = (batch_size, seq_len, feature)(feature = num_head * num_hidden_per_attn)
key = layers.reshape( key = layers.reshape(
self.key(key), [batch_size, seq_len_key, self.num_head, self.d_k]) self.key(key), [batch_size, seq_len_key, self.num_head, self.d_k])
value = layers.reshape( value = layers.reshape(
@ -168,18 +177,6 @@ class MultiheadAttention(dg.Layer):
result, attention = self.scal_attn( result, attention = self.scal_attn(
key, value, query, mask=mask, query_mask=query_mask) key, value, query, mask=mask, query_mask=query_mask)
key = layers.reshape(
layers.transpose(key, [2, 0, 1, 3]), [-1, seq_len_key, self.d_k])
value = layers.reshape(
layers.transpose(value, [2, 0, 1, 3]),
[-1, seq_len_key, self.d_k])
query = layers.reshape(
layers.transpose(query, [2, 0, 1, 3]),
[-1, seq_len_query, self.d_q])
result, attention = self.scal_attn(
key, value, query, mask=mask, query_mask=query_mask)
# concat all multihead result # concat all multihead result
result = layers.reshape( result = layers.reshape(
result, [self.num_head, batch_size, seq_len_query, self.d_q]) result, [self.num_head, batch_size, seq_len_query, self.d_q])