Modified data.py to generate masks as models inputs
This commit is contained in:
parent
25883dcd3e
commit
078d22e51c
|
@ -55,7 +55,7 @@ python -m paddle.distributed.launch --selected_gpus=0,1,2,3 --log_dir ./mylog tr
|
|||
--config_path='config/fastspeech.yaml' \
|
||||
```
|
||||
|
||||
if you wish to resume from an exists model, please set ``--checkpoint_path`` and ``--fastspeech_step``
|
||||
If you wish to resume from an exists model, please set ``--checkpoint_path`` and ``--fastspeech_step``
|
||||
|
||||
For more help on arguments:
|
||||
``python train.py --help``.
|
||||
|
|
|
@ -3,8 +3,8 @@ audio:
|
|||
n_fft: 2048
|
||||
sr: 22050
|
||||
preemphasis: 0.97
|
||||
hop_length: 275
|
||||
win_length: 1102
|
||||
hop_length: 256
|
||||
win_length: 1024
|
||||
power: 1.2
|
||||
min_level_db: -100
|
||||
ref_level_db: 20
|
|
@ -17,6 +17,9 @@ def add_config_options_to_parser(parser):
|
|||
help="use gpu or not during training.")
|
||||
parser.add_argument('--use_data_parallel', type=int, default=0,
|
||||
help="use data parallel or not during training.")
|
||||
parser.add_argument('--alpha', type=float, default=1.0,
|
||||
help="The hyperparameter to determine the length of the expanded sequence \
|
||||
mel, thereby controlling the voice speed.")
|
||||
|
||||
parser.add_argument('--data_path', type=str, default='./dataset/LJSpeech-1.1',
|
||||
help="the path of dataset.")
|
||||
|
|
|
@ -11,6 +11,7 @@ import paddle.fluid.dygraph as dg
|
|||
from parakeet.g2p.en import text_to_sequence
|
||||
from parakeet import audio
|
||||
from parakeet.models.fastspeech.fastspeech import FastSpeech
|
||||
from parakeet.models.transformer_tts.utils import *
|
||||
|
||||
def load_checkpoint(step, model_path):
|
||||
model_dict, _ = fluid.dygraph.load_dygraph(os.path.join(model_path, step))
|
||||
|
@ -41,11 +42,22 @@ def synthesis(text_input, args):
|
|||
model.eval()
|
||||
|
||||
text = np.asarray(text_to_sequence(text_input))
|
||||
text = fluid.layers.unsqueeze(dg.to_variable(text),[0])
|
||||
text = np.expand_dims(text, axis=0)
|
||||
pos_text = np.arange(1, text.shape[1]+1)
|
||||
pos_text = fluid.layers.unsqueeze(dg.to_variable(pos_text),[0])
|
||||
pos_text = np.expand_dims(pos_text, axis=0)
|
||||
enc_non_pad_mask = get_non_pad_mask(pos_text).astype(np.float32)
|
||||
enc_slf_attn_mask = get_attn_key_pad_mask(pos_text, text).astype(np.float32)
|
||||
|
||||
text = dg.to_variable(text)
|
||||
pos_text = dg.to_variable(pos_text)
|
||||
enc_non_pad_mask = dg.to_variable(enc_non_pad_mask)
|
||||
enc_slf_attn_mask = dg.to_variable(enc_slf_attn_mask)
|
||||
|
||||
mel_output, mel_output_postnet = model(text, pos_text, alpha=args.alpha)
|
||||
mel_output, mel_output_postnet = model(text, pos_text, alpha=args.alpha,
|
||||
enc_non_pad_mask=enc_non_pad_mask,
|
||||
enc_slf_attn_mask=enc_slf_attn_mask,
|
||||
dec_non_pad_mask=None,
|
||||
dec_slf_attn_mask=None)
|
||||
|
||||
_ljspeech_processor = audio.AudioProcessor(
|
||||
sample_rate=cfg['audio']['sr'],
|
||||
|
|
|
@ -8,6 +8,7 @@ from parse import add_config_options_to_parser
|
|||
from pprint import pprint
|
||||
from ruamel import yaml
|
||||
from tqdm import tqdm
|
||||
from matplotlib import cm
|
||||
from collections import OrderedDict
|
||||
from tensorboardX import SummaryWriter
|
||||
import paddle.fluid.dygraph as dg
|
||||
|
@ -77,18 +78,32 @@ def main(args):
|
|||
|
||||
for i, data in enumerate(pbar):
|
||||
pbar.set_description('Processing at epoch %d'%epoch)
|
||||
character, mel, mel_input, pos_text, pos_mel, text_length, mel_lens = data
|
||||
(character, mel, mel_input, pos_text, pos_mel, text_length, mel_lens,
|
||||
enc_slf_mask, enc_query_mask, dec_slf_mask, enc_dec_mask, dec_query_slf_mask, dec_query_mask) = data
|
||||
|
||||
_, _, attn_probs, _, _, _ = transformerTTS(character, mel_input, pos_text, pos_mel)
|
||||
alignment = dg.to_variable(get_alignment(attn_probs, mel_lens, cfg['transformer_head'])).astype(np.float32)
|
||||
_, _, attn_probs, _, _, _ = transformerTTS(character, mel_input, pos_text, pos_mel,
|
||||
dec_slf_mask=dec_slf_mask,
|
||||
enc_slf_mask=enc_slf_mask, enc_query_mask=enc_query_mask,
|
||||
enc_dec_mask=enc_dec_mask, dec_query_slf_mask=dec_query_slf_mask,
|
||||
dec_query_mask=dec_query_mask)
|
||||
alignment, max_attn = get_alignment(attn_probs, mel_lens, cfg['transformer_head'])
|
||||
alignment = dg.to_variable(alignment).astype(np.float32)
|
||||
|
||||
if local_rank==0 and global_step % 5 == 1:
|
||||
x = np.uint8(cm.viridis(max_attn[8,:mel_lens.numpy()[8]]) * 255)
|
||||
writer.add_image('Attention_%d_0'%global_step, x, 0, dataformats="HWC")
|
||||
|
||||
global_step += 1
|
||||
|
||||
#Forward
|
||||
result= model(character,
|
||||
pos_text,
|
||||
mel_pos=pos_mel,
|
||||
length_target=alignment)
|
||||
length_target=alignment,
|
||||
enc_non_pad_mask=enc_query_mask,
|
||||
enc_slf_attn_mask=enc_slf_mask,
|
||||
dec_non_pad_mask=dec_query_slf_mask,
|
||||
dec_slf_attn_mask=dec_slf_mask)
|
||||
mel_output, mel_output_postnet, duration_predictor_output, _, _ = result
|
||||
mel_loss = layers.mse_loss(mel_output, mel)
|
||||
mel_postnet_loss = layers.mse_loss(mel_output_postnet, mel)
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
# train model
|
||||
# if you wish to resume from an exists model, uncomment --checkpoint_path and --fastspeech_step
|
||||
CUDA_VISIBLE_DEVICES=0\
|
||||
export CUDA_VISIBLE_DEVICES=0
|
||||
python -u train.py \
|
||||
--batch_size=32 \
|
||||
--epochs=10000 \
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
# TransformerTTS
|
||||
Paddle fluid implementation of TransformerTTS, a neural TTS with Transformer. The implementation is based on [Neural Speech Synthesis with Transformer Network](https://arxiv.org/abs/1809.08895).
|
||||
PaddlePaddle fluid implementation of TransformerTTS, a neural TTS with Transformer. The implementation is based on [Neural Speech Synthesis with Transformer Network](https://arxiv.org/abs/1809.08895).
|
||||
|
||||
## Dataset
|
||||
|
||||
|
@ -48,7 +48,7 @@ python -m paddle.distributed.launch --selected_gpus=0,1,2,3 --log_dir ./mylog tr
|
|||
--config_path='config/train_transformer.yaml' \
|
||||
```
|
||||
|
||||
if you wish to resume from an exists model, please set ``--checkpoint_path`` and ``--transformer_step``
|
||||
If you wish to resume from an exists model, please set ``--checkpoint_path`` and ``--transformer_step``
|
||||
|
||||
For more help on arguments:
|
||||
``python train_transformer.py --help``.
|
||||
|
@ -76,7 +76,7 @@ python -m paddle.distributed.launch --selected_gpus=0,1,2,3 --log_dir ./mylog tr
|
|||
--data_path=${DATAPATH} \
|
||||
--config_path='config/train_vocoder.yaml' \
|
||||
```
|
||||
if you wish to resume from an exists model, please set ``--checkpoint_path`` and ``--vocoder_step``
|
||||
If you wish to resume from an exists model, please set ``--checkpoint_path`` and ``--vocoder_step``
|
||||
|
||||
For more help on arguments:
|
||||
``python train_vocoder.py --help``.
|
||||
|
|
|
@ -8,4 +8,7 @@ audio:
|
|||
power: 1.2
|
||||
min_level_db: -100
|
||||
ref_level_db: 20
|
||||
outputs_per_step: 1
|
||||
outputs_per_step: 1
|
||||
|
||||
hidden_size: 256
|
||||
embedding_size: 512
|
|
@ -10,7 +10,8 @@ from parakeet import audio
|
|||
from parakeet.data.sampler import *
|
||||
from parakeet.data.datacargo import DataCargo
|
||||
from parakeet.data.batch import TextIDBatcher, SpecBatcher
|
||||
from parakeet.data.dataset import DatasetMixin, TransformDataset
|
||||
from parakeet.data.dataset import DatasetMixin, TransformDataset, CacheDataset
|
||||
from parakeet.models.transformer_tts.utils import *
|
||||
|
||||
class LJSpeechLoader:
|
||||
def __init__(self, config, args, nranks, rank, is_vocoder=False, shuffle=True):
|
||||
|
@ -20,6 +21,8 @@ class LJSpeechLoader:
|
|||
metadata = LJSpeechMetaData(LJSPEECH_ROOT)
|
||||
transformer = LJSpeech(config)
|
||||
dataset = TransformDataset(metadata, transformer)
|
||||
dataset = CacheDataset(dataset)
|
||||
|
||||
sampler = DistributedSampler(len(metadata), nranks, rank, shuffle=shuffle)
|
||||
|
||||
assert args.batch_size % nranks == 0
|
||||
|
@ -132,7 +135,15 @@ def batch_examples(batch):
|
|||
pos_mels = TextIDBatcher(pad_id=0)(pos_mels) #(B,T)
|
||||
mels = np.transpose(SpecBatcher(pad_value=0.)(mels), axes=(0,2,1)) #(B,T,num_mels)
|
||||
mel_inputs = np.transpose(SpecBatcher(pad_value=0.)(mel_inputs), axes=(0,2,1))#(B,T,num_mels)
|
||||
return (texts, mels, mel_inputs, pos_texts, pos_mels, np.array(text_lens), np.array(mel_lens))
|
||||
enc_slf_mask = get_attn_key_pad_mask(pos_texts, texts).astype(np.float32)
|
||||
enc_query_mask = get_non_pad_mask(pos_texts).astype(np.float32)
|
||||
dec_slf_mask = get_dec_attn_key_pad_mask(pos_mels,mel_inputs).astype(np.float32)
|
||||
enc_dec_mask = get_attn_key_pad_mask(enc_query_mask[:,:,0], mel_inputs).astype(np.float32)
|
||||
dec_query_slf_mask = get_non_pad_mask(pos_mels).astype(np.float32)
|
||||
dec_query_mask = get_non_pad_mask(pos_mels).astype(np.float32)
|
||||
|
||||
return (texts, mels, mel_inputs, pos_texts, pos_mels, np.array(text_lens), np.array(mel_lens),
|
||||
enc_slf_mask, enc_query_mask, dec_slf_mask, enc_dec_mask, dec_query_slf_mask, dec_query_mask)
|
||||
|
||||
def batch_examples_vocoder(batch):
|
||||
mels=[]
|
||||
|
|
|
@ -3,6 +3,7 @@ from scipy.io.wavfile import write
|
|||
from parakeet.g2p.en import text_to_sequence
|
||||
import numpy as np
|
||||
from tqdm import tqdm
|
||||
from matplotlib import cm
|
||||
from tensorboardX import SummaryWriter
|
||||
from ruamel import yaml
|
||||
import paddle.fluid as fluid
|
||||
|
@ -12,6 +13,7 @@ import argparse
|
|||
from parse import add_config_options_to_parser
|
||||
from pprint import pprint
|
||||
from collections import OrderedDict
|
||||
from parakeet.models.transformer_tts.utils import *
|
||||
from parakeet import audio
|
||||
from parakeet.models.transformer_tts.vocoder import Vocoder
|
||||
from parakeet.models.transformer_tts.transformer_tts import TransformerTTS
|
||||
|
@ -55,15 +57,17 @@ def synthesis(text_input, args):
|
|||
mel_input = dg.to_variable(np.zeros([1,1,80])).astype(np.float32)
|
||||
pos_text = np.arange(1, text.shape[1]+1)
|
||||
pos_text = fluid.layers.unsqueeze(dg.to_variable(pos_text),[0])
|
||||
|
||||
|
||||
|
||||
pbar = tqdm(range(args.max_len))
|
||||
|
||||
for i in pbar:
|
||||
dec_slf_mask = get_triu_tensor(mel_input.numpy(), mel_input.numpy()).astype(np.float32)
|
||||
dec_slf_mask = fluid.layers.cast(dg.to_variable(dec_slf_mask == 0), np.float32)
|
||||
pos_mel = np.arange(1, mel_input.shape[1]+1)
|
||||
pos_mel = fluid.layers.unsqueeze(dg.to_variable(pos_mel),[0])
|
||||
mel_pred, postnet_pred, attn_probs, stop_preds, attn_enc, attn_dec = model(text, mel_input, pos_text, pos_mel)
|
||||
mel_pred, postnet_pred, attn_probs, stop_preds, attn_enc, attn_dec = model(text, mel_input, pos_text, pos_mel, dec_slf_mask)
|
||||
mel_input = fluid.layers.concat([mel_input, postnet_pred[:,-1:,:]], axis=1)
|
||||
|
||||
mag_pred = model_vocoder(postnet_pred)
|
||||
|
||||
_ljspeech_processor = audio.AudioProcessor(
|
||||
|
@ -87,6 +91,21 @@ def synthesis(text_input, args):
|
|||
sound_norm=False)
|
||||
|
||||
wav = _ljspeech_processor.inv_spectrogram(fluid.layers.transpose(fluid.layers.squeeze(mag_pred,[0]), [1,0]).numpy())
|
||||
global_step = 0
|
||||
for i, prob in enumerate(attn_probs):
|
||||
for j in range(4):
|
||||
x = np.uint8(cm.viridis(prob.numpy()[j]) * 255)
|
||||
writer.add_image('Attention_%d_0'%global_step, x, i*4+j, dataformats="HWC")
|
||||
|
||||
for i, prob in enumerate(attn_enc):
|
||||
for j in range(4):
|
||||
x = np.uint8(cm.viridis(prob.numpy()[j]) * 255)
|
||||
writer.add_image('Attention_enc_%d_0'%global_step, x, i*4+j, dataformats="HWC")
|
||||
|
||||
for i, prob in enumerate(attn_dec):
|
||||
for j in range(4):
|
||||
x = np.uint8(cm.viridis(prob.numpy()[j]) * 255)
|
||||
writer.add_image('Attention_dec_%d_0'%global_step, x, i*4+j, dataformats="HWC")
|
||||
writer.add_audio(text_input, wav, 0, cfg['audio']['sr'])
|
||||
if not os.path.exists(args.sample_path):
|
||||
os.mkdir(args.sample_path)
|
||||
|
@ -97,4 +116,4 @@ if __name__ == '__main__':
|
|||
parser = argparse.ArgumentParser(description="Synthesis model")
|
||||
add_config_options_to_parser(parser)
|
||||
args = parser.parse_args()
|
||||
synthesis("Transformer model is so fast!", args)
|
||||
synthesis("They emphasized the necessity that the information now being furnished be handled with judgment and care.", args)
|
||||
|
|
|
@ -2,10 +2,10 @@
|
|||
# train model
|
||||
CUDA_VISIBLE_DEVICES=0 \
|
||||
python -u synthesis.py \
|
||||
--max_len=50 \
|
||||
--max_len=600 \
|
||||
--transformer_step=160000 \
|
||||
--vocoder_step=70000 \
|
||||
--use_gpu=1
|
||||
--vocoder_step=90000 \
|
||||
--use_gpu=1 \
|
||||
--checkpoint_path='./checkpoint' \
|
||||
--log_dir='./log' \
|
||||
--sample_path='./sample' \
|
||||
|
|
|
@ -69,19 +69,24 @@ def main(args):
|
|||
pbar = tqdm(reader)
|
||||
for i, data in enumerate(pbar):
|
||||
pbar.set_description('Processing at epoch %d'%epoch)
|
||||
character, mel, mel_input, pos_text, pos_mel, text_length, _ = data
|
||||
character, mel, mel_input, pos_text, pos_mel, text_length, _, enc_slf_mask, enc_query_mask, dec_slf_mask, enc_dec_mask, dec_query_slf_mask, dec_query_mask= data
|
||||
|
||||
global_step += 1
|
||||
mel_pred, postnet_pred, attn_probs, stop_preds, attn_enc, attn_dec = model(character, mel_input, pos_text, pos_mel)
|
||||
|
||||
|
||||
label = (pos_mel == 0).astype(np.float32)
|
||||
|
||||
mel_pred, postnet_pred, attn_probs, stop_preds, attn_enc, attn_dec = model(character, mel_input, pos_text, pos_mel, dec_slf_mask=dec_slf_mask,
|
||||
enc_slf_mask=enc_slf_mask, enc_query_mask=enc_query_mask,
|
||||
enc_dec_mask=enc_dec_mask, dec_query_slf_mask=dec_query_slf_mask,
|
||||
dec_query_mask=dec_query_mask)
|
||||
|
||||
|
||||
mel_loss = layers.mean(layers.abs(layers.elementwise_sub(mel_pred, mel)))
|
||||
post_mel_loss = layers.mean(layers.abs(layers.elementwise_sub(postnet_pred, mel)))
|
||||
loss = mel_loss + post_mel_loss
|
||||
|
||||
|
||||
# Note: When used stop token loss the learning did not work.
|
||||
if args.stop_token:
|
||||
label = (pos_mel == 0).astype(np.float32)
|
||||
stop_loss = cross_entropy(stop_preds, label)
|
||||
loss = loss + stop_loss
|
||||
|
||||
|
@ -123,6 +128,7 @@ def main(args):
|
|||
x = np.uint8(cm.viridis(prob.numpy()[j*16]) * 255)
|
||||
writer.add_image('Attention_dec_%d_0'%global_step, x, i*4+j, dataformats="HWC")
|
||||
|
||||
|
||||
if args.use_data_parallel:
|
||||
loss = model.scale_loss(loss)
|
||||
loss.backward()
|
||||
|
@ -132,6 +138,7 @@ def main(args):
|
|||
optimizer.minimize(loss, grad_clip = fluid.dygraph_grad_clip.GradClipByGlobalNorm(cfg['grad_clip_thresh']))
|
||||
model.clear_gradients()
|
||||
|
||||
|
||||
# save checkpoint
|
||||
if local_rank==0 and global_step % args.save_step == 0:
|
||||
if not os.path.exists(args.save_path):
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
|
||||
# train model
|
||||
# if you wish to resume from an exists model, uncomment --checkpoint_path and --transformer_step
|
||||
CUDA_VISIBLE_DEVICES=0 \
|
||||
export CUDA_VISIBLE_DEVICES=2
|
||||
python -u train_transformer.py \
|
||||
--batch_size=32 \
|
||||
--epochs=10000 \
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
import six
|
||||
import numpy as np
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
class DatasetMixin(object):
|
||||
|
@ -45,6 +46,17 @@ class TransformDataset(DatasetMixin):
|
|||
in_data = self._dataset[i]
|
||||
return self._transform(in_data)
|
||||
|
||||
class CacheDataset(DatasetMixin):
|
||||
def __init__(self, dataset):
|
||||
self._dataset = dataset
|
||||
pbar = tqdm(range(len(self._dataset)))
|
||||
self._cache = [self._dataset[i] for i in pbar]
|
||||
|
||||
def __len__(self):
|
||||
return len(self._dataset)
|
||||
|
||||
def get_example(self, i):
|
||||
return self._cache[i]
|
||||
|
||||
class TupleDataset(object):
|
||||
def __init__(self, *datasets):
|
||||
|
|
|
@ -18,6 +18,7 @@ class Decoder(dg.Layer):
|
|||
super(Decoder, self).__init__()
|
||||
|
||||
n_position = len_max_seq + 1
|
||||
self.n_head = n_head
|
||||
self.pos_inp = get_sinusoid_encoding_table(n_position, d_model, padding_idx=0)
|
||||
self.position_enc = dg.Embedding(size=[n_position, d_model],
|
||||
padding_idx=0,
|
||||
|
@ -28,7 +29,7 @@ class Decoder(dg.Layer):
|
|||
for i, layer in enumerate(self.layer_stack):
|
||||
self.add_sublayer('fft_{}'.format(i), layer)
|
||||
|
||||
def forward(self, enc_seq, enc_pos):
|
||||
def forward(self, enc_seq, enc_pos, non_pad_mask, slf_attn_mask=None):
|
||||
"""
|
||||
Decoder layer of FastSpeech.
|
||||
|
||||
|
@ -42,10 +43,7 @@ class Decoder(dg.Layer):
|
|||
dec_slf_attn_list (Variable), Shape(B, mel_T, mel_T), the decoder self attention list.
|
||||
"""
|
||||
dec_slf_attn_list = []
|
||||
|
||||
# -- Prepare masks
|
||||
slf_attn_mask = get_attn_key_pad_mask(seq_k=enc_pos, seq_q=enc_pos)
|
||||
non_pad_mask = get_non_pad_mask(enc_pos)
|
||||
slf_attn_mask = layers.expand(slf_attn_mask, [self.n_head, 1, 1])
|
||||
|
||||
# -- Forward
|
||||
dec_output = enc_seq + self.position_enc(enc_pos)
|
||||
|
|
|
@ -18,11 +18,12 @@ class Encoder(dg.Layer):
|
|||
dropout=0.1):
|
||||
super(Encoder, self).__init__()
|
||||
n_position = len_max_seq + 1
|
||||
self.n_head = n_head
|
||||
|
||||
self.src_word_emb = dg.Embedding(size=[n_src_vocab, d_model], padding_idx=0)
|
||||
self.src_word_emb = dg.Embedding(size=[n_src_vocab, d_model], padding_idx=0,
|
||||
param_attr=fluid.initializer.Normal(loc=0.0, scale=1.0))
|
||||
self.pos_inp = get_sinusoid_encoding_table(n_position, d_model, padding_idx=0)
|
||||
self.position_enc = dg.Embedding(size=[n_position, d_model],
|
||||
padding_idx=0,
|
||||
param_attr=fluid.ParamAttr(
|
||||
initializer=fluid.initializer.NumpyArrayInitializer(self.pos_inp),
|
||||
trainable=False))
|
||||
|
@ -30,7 +31,7 @@ class Encoder(dg.Layer):
|
|||
for i, layer in enumerate(self.layer_stack):
|
||||
self.add_sublayer('fft_{}'.format(i), layer)
|
||||
|
||||
def forward(self, character, text_pos):
|
||||
def forward(self, character, text_pos, non_pad_mask, slf_attn_mask=None):
|
||||
"""
|
||||
Encoder layer of FastSpeech.
|
||||
|
||||
|
@ -46,10 +47,7 @@ class Encoder(dg.Layer):
|
|||
enc_slf_attn_list (list<Variable>), Len(n_layers), Shape(B * n_head, text_T, text_T), the encoder self attention list.
|
||||
"""
|
||||
enc_slf_attn_list = []
|
||||
# -- prepare masks
|
||||
# shape character (N, T)
|
||||
slf_attn_mask = get_attn_key_pad_mask(seq_k=character, seq_q=character)
|
||||
non_pad_mask = get_non_pad_mask(character)
|
||||
slf_attn_mask = layers.expand(slf_attn_mask, [self.n_head, 1, 1])
|
||||
|
||||
# -- Forward
|
||||
enc_output = self.src_word_emb(character) + self.position_enc(text_pos) #(N, T, C)
|
||||
|
@ -61,4 +59,4 @@ class Encoder(dg.Layer):
|
|||
slf_attn_mask=slf_attn_mask)
|
||||
enc_slf_attn_list += [enc_slf_attn]
|
||||
|
||||
return enc_output, non_pad_mask, enc_slf_attn_list
|
||||
return enc_output, enc_slf_attn_list
|
|
@ -1,7 +1,9 @@
|
|||
import math
|
||||
import numpy as np
|
||||
import paddle.fluid.dygraph as dg
|
||||
import paddle.fluid as fluid
|
||||
from parakeet.g2p.text.symbols import symbols
|
||||
from parakeet.models.transformer_tts.utils import *
|
||||
from parakeet.models.transformer_tts.post_convnet import PostConvNet
|
||||
from parakeet.models.fastspeech.length_regulator import LengthRegulator
|
||||
from parakeet.models.fastspeech.encoder import Encoder
|
||||
|
@ -54,7 +56,9 @@ class FastSpeech(dg.Layer):
|
|||
dropout=0.1,
|
||||
batchnorm_last=True)
|
||||
|
||||
def forward(self, character, text_pos, mel_pos=None, length_target=None, alpha=1.0):
|
||||
def forward(self, character, text_pos, enc_non_pad_mask, dec_non_pad_mask,
|
||||
enc_slf_attn_mask=None, dec_slf_attn_mask=None,
|
||||
mel_pos=None, length_target=None, alpha=1.0):
|
||||
"""
|
||||
FastSpeech model.
|
||||
|
||||
|
@ -80,13 +84,15 @@ class FastSpeech(dg.Layer):
|
|||
dec_slf_attn_list (Variable), Shape(B, mel_T, mel_T), the decoder self attention list.
|
||||
"""
|
||||
|
||||
encoder_output, non_pad_mask, enc_slf_attn_list = self.encoder(character, text_pos)
|
||||
encoder_output, enc_slf_attn_list = self.encoder(character, text_pos, enc_non_pad_mask, slf_attn_mask=enc_slf_attn_mask)
|
||||
if fluid.framework._dygraph_tracer()._train_mode:
|
||||
|
||||
length_regulator_output, duration_predictor_output = self.length_regulator(encoder_output,
|
||||
target=length_target,
|
||||
alpha=alpha)
|
||||
decoder_output, dec_slf_attn_list = self.decoder(length_regulator_output, mel_pos)
|
||||
decoder_output, dec_slf_attn_list = self.decoder(length_regulator_output, mel_pos,
|
||||
dec_non_pad_mask,
|
||||
slf_attn_mask=dec_slf_attn_mask)
|
||||
|
||||
mel_output = self.mel_linear(decoder_output)
|
||||
mel_output_postnet = self.postnet(mel_output) + mel_output
|
||||
|
@ -94,7 +100,12 @@ class FastSpeech(dg.Layer):
|
|||
return mel_output, mel_output_postnet, duration_predictor_output, enc_slf_attn_list, dec_slf_attn_list
|
||||
else:
|
||||
length_regulator_output, decoder_pos = self.length_regulator(encoder_output, alpha=alpha)
|
||||
decoder_output, _ = self.decoder(length_regulator_output, decoder_pos)
|
||||
slf_attn_mask = get_triu_tensor(decoder_pos.numpy(), decoder_pos.numpy()).astype(np.float32)
|
||||
slf_attn_mask = fluid.layers.cast(dg.to_variable(slf_attn_mask == 0), np.float32)
|
||||
slf_attn_mask = dg.to_variable(slf_attn_mask)
|
||||
dec_non_pad_mask = fluid.layers.unsqueeze((decoder_pos != 0).astype(np.float32), [-1])
|
||||
decoder_output, _ = self.decoder(length_regulator_output, decoder_pos, dec_non_pad_mask,
|
||||
slf_attn_mask=slf_attn_mask)
|
||||
mel_output = self.mel_linear(decoder_output)
|
||||
mel_output_postnet = self.postnet(mel_output) + mel_output
|
||||
|
||||
|
|
|
@ -12,7 +12,7 @@ class FFTBlock(dg.Layer):
|
|||
self.slf_attn = MultiheadAttention(d_model, d_k, d_v, num_head=n_head, is_bias=True, dropout=dropout, is_concat=False)
|
||||
self.pos_ffn = PositionwiseFeedForward(d_model, d_inner, filter_size =filter_size, padding =padding, dropout=dropout)
|
||||
|
||||
def forward(self, enc_input, non_pad_mask=None, slf_attn_mask=None):
|
||||
def forward(self, enc_input, non_pad_mask, slf_attn_mask=None):
|
||||
"""
|
||||
Feed Forward Transformer block in FastSpeech.
|
||||
|
||||
|
@ -28,6 +28,7 @@ class FFTBlock(dg.Layer):
|
|||
slf_attn (Variable), Shape(B * n_head, T, T), the self attention.
|
||||
"""
|
||||
output, slf_attn = self.slf_attn(enc_input, enc_input, enc_input, mask=slf_attn_mask)
|
||||
|
||||
output *= non_pad_mask
|
||||
|
||||
output = self.pos_ffn(output)
|
||||
|
|
|
@ -121,11 +121,11 @@ class DurationPredictor(dg.Layer):
|
|||
out = layers.transpose(encoder_output, [0,2,1])
|
||||
out = self.conv1(out)
|
||||
out = layers.transpose(out, [0,2,1])
|
||||
out = layers.dropout(layers.relu(self.layer_norm1(out)), self.dropout)
|
||||
out = layers.dropout(layers.relu(self.layer_norm1(out)), self.dropout, dropout_implementation='upscale_in_train')
|
||||
out = layers.transpose(out, [0,2,1])
|
||||
out = self.conv2(out)
|
||||
out = layers.transpose(out, [0,2,1])
|
||||
out = layers.dropout(layers.relu(self.layer_norm2(out)), self.dropout)
|
||||
out = layers.dropout(layers.relu(self.layer_norm2(out)), self.dropout, dropout_implementation='upscale_in_train')
|
||||
out = layers.relu(self.linear(out))
|
||||
out = layers.squeeze(out, axes=[-1])
|
||||
|
||||
|
|
|
@ -14,7 +14,7 @@ def get_alignment(attn_probs, mel_lens, n_head):
|
|||
max_F = F
|
||||
max_attn = attn
|
||||
alignment = compute_duration(max_attn, mel_lens)
|
||||
return alignment
|
||||
return alignment, max_attn
|
||||
|
||||
def score_F(attn):
|
||||
max = np.max(attn, axis=-1)
|
||||
|
|
|
@ -124,6 +124,7 @@ class CBHG(dg.Layer):
|
|||
conv_list = []
|
||||
conv_input = input_
|
||||
|
||||
|
||||
for i, (conv, batchnorm) in enumerate(zip(self.conv_list, self.batchnorm_list)):
|
||||
conv_input = self._conv_fit_dim(conv(conv_input), i+1)
|
||||
conv_input = layers.relu(batchnorm(conv_input))
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
import math
|
||||
import paddle.fluid.dygraph as dg
|
||||
import paddle.fluid as fluid
|
||||
from parakeet.modules.utils import *
|
||||
from parakeet.models.transformer_tts.utils import *
|
||||
from parakeet.modules.multihead_attention import MultiheadAttention
|
||||
from parakeet.modules.ffn import PositionwiseFeedForward
|
||||
from parakeet.models.transformer_tts.prenet import PreNet
|
||||
|
@ -11,6 +11,7 @@ class Decoder(dg.Layer):
|
|||
def __init__(self, num_hidden, config, num_head=4):
|
||||
super(Decoder, self).__init__()
|
||||
self.num_hidden = num_hidden
|
||||
self.num_head = num_head
|
||||
param = fluid.ParamAttr()
|
||||
self.alpha = self.create_parameter(shape=(1,), attr=param, dtype='float32',
|
||||
default_initializer = fluid.initializer.ConstantInitializer(value=1.0))
|
||||
|
@ -48,25 +49,21 @@ class Decoder(dg.Layer):
|
|||
self.postconvnet = PostConvNet(config['audio']['num_mels'], config['hidden_size'],
|
||||
filter_size = 5, padding = 4, num_conv=5,
|
||||
outputs_per_step=config['audio']['outputs_per_step'],
|
||||
use_cudnn = True)
|
||||
use_cudnn=True)
|
||||
|
||||
def forward(self, key, value, query, c_mask, positional):
|
||||
def forward(self, key, value, query, positional, mask, m_mask=None, m_self_mask=None, zero_mask=None):
|
||||
|
||||
# get decoder mask with triangular matrix
|
||||
|
||||
if fluid.framework._dygraph_tracer()._train_mode:
|
||||
m_mask = get_non_pad_mask(positional)
|
||||
mask = get_attn_key_pad_mask((positional==0).astype(np.float32), query)
|
||||
triu_tensor = dg.to_variable(get_triu_tensor(query.numpy(), query.numpy())).astype(np.float32)
|
||||
mask = mask + triu_tensor
|
||||
mask = fluid.layers.cast(mask == 0, np.float32)
|
||||
|
||||
# (batch_size, decoder_len, encoder_len)
|
||||
zero_mask = get_attn_key_pad_mask(layers.squeeze(c_mask,[-1]), query)
|
||||
m_mask = layers.expand(m_mask, [self.num_head, 1, key.shape[1]])
|
||||
m_self_mask = layers.expand(m_self_mask, [self.num_head, 1, query.shape[1]])
|
||||
mask = layers.expand(mask, [self.num_head, 1, 1])
|
||||
zero_mask = layers.expand(zero_mask, [self.num_head, 1, 1])
|
||||
|
||||
else:
|
||||
mask = get_triu_tensor(query.numpy(), query.numpy()).astype(np.float32)
|
||||
mask = fluid.layers.cast(dg.to_variable(mask == 0), np.float32)
|
||||
m_mask, zero_mask = None, None
|
||||
m_mask, m_self_mask, zero_mask = None, None, None
|
||||
|
||||
|
||||
# Decoder pre-network
|
||||
query = self.decoder_prenet(query)
|
||||
|
@ -79,18 +76,22 @@ class Decoder(dg.Layer):
|
|||
query = positional * self.alpha + query
|
||||
|
||||
#positional dropout
|
||||
query = fluid.layers.dropout(query, 0.1)
|
||||
query = fluid.layers.dropout(query, 0.1, dropout_implementation='upscale_in_train')
|
||||
|
||||
|
||||
# Attention decoder-decoder, encoder-decoder
|
||||
selfattn_list = list()
|
||||
attn_list = list()
|
||||
|
||||
|
||||
for selfattn, attn, ffn in zip(self.selfattn_layers, self.attn_layers, self.ffns):
|
||||
query, attn_dec = selfattn(query, query, query, mask = mask, query_mask = m_mask)
|
||||
query, attn_dec = selfattn(query, query, query, mask = mask, query_mask = m_self_mask)
|
||||
query, attn_dot = attn(key, value, query, mask = zero_mask, query_mask = m_mask)
|
||||
query = ffn(query)
|
||||
selfattn_list.append(attn_dec)
|
||||
attn_list.append(attn_dot)
|
||||
|
||||
|
||||
# Mel linear projection
|
||||
mel_out = self.mel_linear(query)
|
||||
# Post Mel Network
|
||||
|
|
|
@ -9,11 +9,11 @@ class Encoder(dg.Layer):
|
|||
def __init__(self, embedding_size, num_hidden, num_head=4):
|
||||
super(Encoder, self).__init__()
|
||||
self.num_hidden = num_hidden
|
||||
self.num_head = num_head
|
||||
param = fluid.ParamAttr(initializer=fluid.initializer.Constant(value=1.0))
|
||||
self.alpha = self.create_parameter(shape=(1, ), attr=param, dtype='float32')
|
||||
self.pos_inp = get_sinusoid_encoding_table(1024, self.num_hidden, padding_idx=0)
|
||||
self.pos_emb = dg.Embedding(size=[1024, num_hidden],
|
||||
padding_idx=0,
|
||||
param_attr=fluid.ParamAttr(
|
||||
initializer=fluid.initializer.NumpyArrayInitializer(self.pos_inp),
|
||||
trainable=False))
|
||||
|
@ -23,17 +23,20 @@ class Encoder(dg.Layer):
|
|||
self.layers = [MultiheadAttention(num_hidden, num_hidden//num_head, num_hidden//num_head) for _ in range(3)]
|
||||
for i, layer in enumerate(self.layers):
|
||||
self.add_sublayer("self_attn_{}".format(i), layer)
|
||||
self.ffns = [PositionwiseFeedForward(num_hidden, num_hidden*num_head, filter_size=1, use_cudnn = True) for _ in range(3)]
|
||||
self.ffns = [PositionwiseFeedForward(num_hidden, num_hidden*num_head, filter_size=1, use_cudnn=True) for _ in range(3)]
|
||||
for i, layer in enumerate(self.ffns):
|
||||
self.add_sublayer("ffns_{}".format(i), layer)
|
||||
|
||||
def forward(self, x, positional):
|
||||
def forward(self, x, positional, mask=None, query_mask=None):
|
||||
|
||||
if fluid.framework._dygraph_tracer()._train_mode:
|
||||
query_mask = get_non_pad_mask(positional)
|
||||
mask = get_attn_key_pad_mask(positional, x)
|
||||
seq_len_key = x.shape[1]
|
||||
query_mask = layers.expand(query_mask, [self.num_head, 1, seq_len_key])
|
||||
mask = layers.expand(mask, [self.num_head, 1, 1])
|
||||
else:
|
||||
query_mask, mask = None, None
|
||||
|
||||
|
||||
# Encoder pre_network
|
||||
x = self.encoder_prenet(x) #(N,T,C)
|
||||
|
||||
|
@ -43,9 +46,9 @@ class Encoder(dg.Layer):
|
|||
|
||||
x = positional * self.alpha + x #(N, T, C)
|
||||
|
||||
|
||||
|
||||
# Positional dropout
|
||||
x = layers.dropout(x, 0.1)
|
||||
x = layers.dropout(x, 0.1, dropout_implementation='upscale_in_train')
|
||||
|
||||
# Self attention encoder
|
||||
attentions = list()
|
||||
|
@ -54,4 +57,4 @@ class Encoder(dg.Layer):
|
|||
x = ffn(x)
|
||||
attentions.append(attention)
|
||||
|
||||
return x, query_mask, attentions
|
||||
return x, attentions
|
|
@ -14,7 +14,8 @@ class EncoderPrenet(dg.Layer):
|
|||
self.num_hidden = num_hidden
|
||||
self.use_cudnn = use_cudnn
|
||||
self.embedding = dg.Embedding( size = [len(symbols), embedding_size],
|
||||
padding_idx = None)
|
||||
padding_idx = 0,
|
||||
param_attr=fluid.initializer.Normal(loc=0.0, scale=1.0))
|
||||
self.conv_list = []
|
||||
k = math.sqrt(1 / embedding_size)
|
||||
self.conv_list.append(Conv1D(num_channels = embedding_size,
|
||||
|
@ -49,10 +50,12 @@ class EncoderPrenet(dg.Layer):
|
|||
bias_attr=fluid.ParamAttr(initializer = fluid.initializer.Uniform(low=-k, high=k)))
|
||||
|
||||
def forward(self, x):
|
||||
|
||||
x = self.embedding(x) #(batch_size, seq_len, embending_size)
|
||||
x = layers.transpose(x,[0,2,1])
|
||||
for batch_norm, conv in zip(self.batch_norm_list, self.conv_list):
|
||||
x = layers.dropout(layers.relu(batch_norm(conv(x))), 0.2)
|
||||
x = layers.dropout(layers.relu(batch_norm(conv(x))), 0.2,
|
||||
dropout_implementation='upscale_in_train')
|
||||
x = layers.transpose(x,[0,2,1]) #(N,T,C)
|
||||
x = self.projection(x)
|
||||
|
||||
|
|
|
@ -76,11 +76,13 @@ class PostConvNet(dg.Layer):
|
|||
batch_norm = self.batch_norm_list[i]
|
||||
conv = self.conv_list[i]
|
||||
|
||||
input = layers.dropout(layers.tanh(batch_norm(conv(input)[:,:,:len])), self.dropout)
|
||||
input = layers.dropout(layers.tanh(batch_norm(conv(input)[:,:,:len])), self.dropout,
|
||||
dropout_implementation='upscale_in_train')
|
||||
conv = self.conv_list[self.num_conv-1]
|
||||
input = conv(input)[:,:,:len]
|
||||
if self.batchnorm_last:
|
||||
batch_norm = self.batch_norm_list[self.num_conv-1]
|
||||
input = layers.dropout(batch_norm(input), self.dropout)
|
||||
input = layers.dropout(batch_norm(input), self.dropout,
|
||||
dropout_implementation='upscale_in_train')
|
||||
output = layers.transpose(input, [0,2,1])
|
||||
return output
|
|
@ -34,6 +34,6 @@ class PreNet(dg.Layer):
|
|||
Returns:
|
||||
x (Variable), Shape(B, T, C), the result after pernet.
|
||||
"""
|
||||
x = layers.dropout(layers.relu(self.linear1(x)), self.dropout_rate)
|
||||
x = layers.dropout(layers.relu(self.linear2(x)), self.dropout_rate)
|
||||
x = layers.dropout(layers.relu(self.linear1(x)), self.dropout_rate, dropout_implementation='upscale_in_train')
|
||||
x = layers.dropout(layers.relu(self.linear2(x)), self.dropout_rate, dropout_implementation='upscale_in_train')
|
||||
return x
|
||||
|
|
|
@ -10,15 +10,14 @@ class TransformerTTS(dg.Layer):
|
|||
self.decoder = Decoder(config['hidden_size'], config)
|
||||
self.config = config
|
||||
|
||||
def forward(self, characters, mel_input, pos_text, pos_mel):
|
||||
|
||||
key, c_mask, attns_enc = self.encoder(characters, pos_text)
|
||||
|
||||
mel_output, postnet_output, attn_probs, stop_preds, attns_dec = self.decoder(key, key, mel_input, c_mask, pos_mel)
|
||||
|
||||
def forward(self, characters, mel_input, pos_text, pos_mel, dec_slf_mask, enc_slf_mask=None, enc_query_mask=None, enc_dec_mask=None, dec_query_slf_mask=None, dec_query_mask=None):
|
||||
key, attns_enc = self.encoder(characters, pos_text, mask=enc_slf_mask, query_mask=enc_query_mask)
|
||||
|
||||
mel_output, postnet_output, attn_probs, stop_preds, attns_dec = self.decoder(key, key, mel_input, pos_mel,
|
||||
mask=dec_slf_mask, zero_mask=enc_dec_mask,
|
||||
m_self_mask=dec_query_slf_mask, m_mask=dec_query_mask )
|
||||
return mel_output, postnet_output, attn_probs, stop_preds, attns_enc, attns_dec
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
|
|
@ -35,7 +35,9 @@ def get_sinusoid_encoding_table(n_position, d_hid, padding_idx=None):
|
|||
return sinusoid_table
|
||||
|
||||
def get_non_pad_mask(seq):
|
||||
return layers.unsqueeze((seq != 0).astype(np.float32),[-1])
|
||||
mask = (seq != 0).astype(np.float32)
|
||||
mask = np.expand_dims(mask, axis=-1)
|
||||
return mask
|
||||
|
||||
def get_attn_key_pad_mask(seq_k, seq_q):
|
||||
''' For masking out the padding part of key sequence. '''
|
||||
|
@ -43,7 +45,21 @@ def get_attn_key_pad_mask(seq_k, seq_q):
|
|||
# Expand to fit the shape of key query attention matrix.
|
||||
len_q = seq_q.shape[1]
|
||||
padding_mask = (seq_k != 0).astype(np.float32)
|
||||
padding_mask = layers.expand(layers.unsqueeze(padding_mask,[1]), [1, len_q, 1])
|
||||
padding_mask = np.expand_dims(padding_mask, axis=1)
|
||||
padding_mask = padding_mask.repeat([len_q],axis=1)
|
||||
padding_mask = (padding_mask == 0).astype(np.float32) * (-2 ** 32 + 1)
|
||||
return padding_mask
|
||||
|
||||
def get_dec_attn_key_pad_mask(seq_k, seq_q):
|
||||
''' For masking out the padding part of key sequence. '''
|
||||
|
||||
# Expand to fit the shape of key query attention matrix.
|
||||
len_q = seq_q.shape[1]
|
||||
padding_mask = (seq_k == 0).astype(np.float32)
|
||||
padding_mask = np.expand_dims(padding_mask, axis=1)
|
||||
triu_tensor = get_triu_tensor(seq_q, seq_q)
|
||||
padding_mask = padding_mask.repeat([len_q],axis=1) + triu_tensor
|
||||
padding_mask = (padding_mask != 0).astype(np.float32) * (-2 ** 32 + 1)
|
||||
return padding_mask
|
||||
|
||||
def get_triu_tensor(seq_k, seq_q):
|
||||
|
|
|
@ -40,10 +40,10 @@ class DynamicGRU(dg.Layer):
|
|||
i = inputs.shape[1] - 1 - i
|
||||
input_ = inputs[:, i:i + 1, :]
|
||||
input_ = layers.reshape(
|
||||
input_, [-1, input_.shape[2]], inplace=False)
|
||||
input_, [-1, input_.shape[2]])
|
||||
hidden, reset, gate = self.gru_unit(input_, hidden)
|
||||
hidden_ = layers.reshape(
|
||||
hidden, [-1, 1, hidden.shape[1]], inplace=False)
|
||||
hidden, [-1, 1, hidden.shape[1]])
|
||||
res.append(hidden_)
|
||||
if self.is_reverse:
|
||||
res = res[::-1]
|
||||
|
|
|
@ -45,7 +45,7 @@ class PositionwiseFeedForward(dg.Layer):
|
|||
x = self.w_2(layers.relu(self.w_1(x)))
|
||||
|
||||
# dropout
|
||||
x = layers.dropout(x, self.dropout)
|
||||
x = layers.dropout(x, self.dropout, dropout_implementation='upscale_in_train')
|
||||
|
||||
x = layers.transpose(x, [0,2,1])
|
||||
# residual connection
|
||||
|
|
|
@ -211,7 +211,7 @@ class Conv1DGLU(dg.Layer):
|
|||
|
||||
residual = x
|
||||
x = fluid.layers.dropout(
|
||||
x, self.dropout, dropout_implementation="upscale_in_train")
|
||||
x, self.dropout)
|
||||
x = self.conv(x)
|
||||
|
||||
content, gate = fluid.layers.split(x, num_or_sections=2, dim=1)
|
||||
|
@ -241,7 +241,7 @@ class Conv1DGLU(dg.Layer):
|
|||
|
||||
# add step input and produce step output
|
||||
x = fluid.layers.dropout(
|
||||
x, self.dropout, dropout_implementation="upscale_in_train")
|
||||
x, self.dropout)
|
||||
x = self.conv.add_input(x)
|
||||
|
||||
content, gate = fluid.layers.split(x, num_or_sections=2, dim=1)
|
||||
|
|
|
@ -47,17 +47,13 @@ class ScaledDotProductAttention(dg.Layer):
|
|||
attention (Variable), Shape(n_head * B, T, C), the attention of key.
|
||||
"""
|
||||
# Compute attention score
|
||||
attention = layers.matmul(query, key, transpose_y=True) #transpose the last dim in y
|
||||
attention = attention / math.sqrt(self.d_key)
|
||||
attention = layers.matmul(query, key, transpose_y=True, alpha=self.d_key**-0.5) #transpose the last dim in y
|
||||
|
||||
# Mask key to ignore padding
|
||||
if mask is not None:
|
||||
attention = attention * mask
|
||||
mask = (mask == 0).astype(np.float32) * (-2 ** 32 + 1)
|
||||
attention = attention + mask
|
||||
|
||||
attention = layers.softmax(attention)
|
||||
attention = layers.dropout(attention, dropout)
|
||||
attention = layers.dropout(attention, dropout, dropout_implementation='upscale_in_train')
|
||||
|
||||
# Mask query to ignore padding
|
||||
if query_mask is not None:
|
||||
|
@ -103,15 +99,10 @@ class MultiheadAttention(dg.Layer):
|
|||
result (Variable), Shape(B, T, C), the result of mutihead attention.
|
||||
attention (Variable), Shape(n_head * B, T, C), the attention of key.
|
||||
"""
|
||||
|
||||
batch_size = key.shape[0]
|
||||
seq_len_key = key.shape[1]
|
||||
seq_len_query = query_input.shape[1]
|
||||
|
||||
# repeat masks h times
|
||||
if query_mask is not None:
|
||||
query_mask = layers.expand(query_mask, [self.num_head, 1, seq_len_key])
|
||||
if mask is not None:
|
||||
mask = layers.expand(mask, (self.num_head, 1, 1))
|
||||
|
||||
|
||||
# Make multihead attention
|
||||
|
@ -123,15 +114,15 @@ class MultiheadAttention(dg.Layer):
|
|||
key = layers.reshape(layers.transpose(key, [2, 0, 1, 3]), [-1, seq_len_key, self.d_k])
|
||||
value = layers.reshape(layers.transpose(value, [2, 0, 1, 3]), [-1, seq_len_key, self.d_k])
|
||||
query = layers.reshape(layers.transpose(query, [2, 0, 1, 3]), [-1, seq_len_query, self.d_q])
|
||||
|
||||
|
||||
result, attention = self.scal_attn(key, value, query, mask=mask, query_mask=query_mask)
|
||||
|
||||
|
||||
# concat all multihead result
|
||||
result = layers.reshape(result, [self.num_head, batch_size, seq_len_query, self.d_q])
|
||||
result = layers.reshape(layers.transpose(result, [1,2,0,3]),[batch_size, seq_len_query, -1])
|
||||
if self.is_concat:
|
||||
result = layers.concat([query_input,result], axis=-1)
|
||||
result = layers.dropout(self.fc(result), self.dropout)
|
||||
result = layers.dropout(self.fc(result), self.dropout, dropout_implementation='upscale_in_train')
|
||||
result = result + query_input
|
||||
|
||||
result = self.layer_norm(result)
|
||||
|
|
Loading…
Reference in New Issue