Rebuild code; TransformerTTS is now correct. FastSpeech will follow later.

lifuchen 2020-01-08 03:55:06 +00:00 committed by chenfeiyu
parent 5b632d18af
commit e0aef2e081
21 changed files with 297 additions and 250 deletions

View File

@@ -9,7 +9,7 @@ class AudioProcessor(object):
sample_rate=None, # int, sampling rate sample_rate=None, # int, sampling rate
num_mels=None, # int, bands of mel spectrogram num_mels=None, # int, bands of mel spectrogram
min_level_db=None, # float, minimum level db min_level_db=None, # float, minimum level db
ref_level_db=None, # float, reference level dbn ref_level_db=None, # float, reference level db
n_fft=None, # int: number of samples in a frame for stft n_fft=None, # int: number of samples in a frame for stft
win_length=None, # int: the same meaning with n_fft win_length=None, # int: the same meaning with n_fft
hop_length=None, # int: number of samples between neighboring frame hop_length=None, # int: number of samples between neighboring frame
@@ -22,7 +22,7 @@ class AudioProcessor(object):
mel_fmax=None, # int: mel spectrogram's maximum frequency mel_fmax=None, # int: mel spectrogram's maximum frequency
clip_norm=True, # bool: clip spectrogram's norm clip_norm=True, # bool: clip spectrogram's norm
griffin_lim_iters=None, # int: griffin_lim_iters=None, # int:
do_trim_silence=False, # bool: trim silience do_trim_silence=False, # bool: trim silence
sound_norm=False, sound_norm=False,
**kwargs): **kwargs):
self.sample_rate = sample_rate self.sample_rate = sample_rate

View File

@@ -12,19 +12,19 @@ from parakeet.data.dataset import Dataset
from parakeet.data.batch import TextIDBatcher, SpecBatcher from parakeet.data.batch import TextIDBatcher, SpecBatcher
class LJSpeechLoader: class LJSpeechLoader:
def __init__(self, config, nranks, rank, is_vocoder=False): def __init__(self, config, nranks, rank, is_vocoder=False, shuffle=True):
place = fluid.CUDAPlace(rank) if config.use_gpu else fluid.CPUPlace() place = fluid.CUDAPlace(rank) if config.use_gpu else fluid.CPUPlace()
LJSPEECH_ROOT = Path(config.data_path) LJSPEECH_ROOT = Path(config.data_path)
dataset = LJSpeech(LJSPEECH_ROOT, config) dataset = LJSpeech(LJSPEECH_ROOT, config)
sampler = DistributedSampler(len(dataset), nranks, rank) sampler = DistributedSampler(len(dataset), nranks, rank, shuffle=shuffle)
assert config.batch_size % nranks == 0 assert config.batch_size % nranks == 0
each_bs = config.batch_size // nranks each_bs = config.batch_size // nranks
if is_vocoder: if is_vocoder:
dataloader = DataCargo(dataset, sampler=sampler, batch_size=each_bs, shuffle=True, collate_fn=batch_examples_vocoder, drop_last=True) dataloader = DataCargo(dataset, sampler=sampler, batch_size=each_bs, shuffle=shuffle, collate_fn=batch_examples_vocoder, drop_last=True)
else: else:
dataloader = DataCargo(dataset, sampler=sampler, batch_size=each_bs, shuffle=True, collate_fn=batch_examples, drop_last=True) dataloader = DataCargo(dataset, sampler=sampler, batch_size=each_bs, shuffle=shuffle, collate_fn=batch_examples, drop_last=True)
self.reader = fluid.io.DataLoader.from_generator( self.reader = fluid.io.DataLoader.from_generator(
capacity=32, capacity=32,
@@ -41,6 +41,25 @@ class LJSpeech(Dataset):
self.root = root if isinstance(root, Path) else Path(root) self.root = root if isinstance(root, Path) else Path(root)
self.metadata = self._prepare_metadata() self.metadata = self._prepare_metadata()
self.config = config self.config = config
self._ljspeech_processor = audio.AudioProcessor(
sample_rate=config.audio.sr,
num_mels=config.audio.num_mels,
min_level_db=config.audio.min_level_db,
ref_level_db=config.audio.ref_level_db,
n_fft=config.audio.n_fft,
win_length= config.audio.win_length,
hop_length= config.audio.hop_length,
power=config.audio.power,
preemphasis=config.audio.preemphasis,
signal_norm=True,
symmetric_norm=False,
max_norm=1.,
mel_fmin=0,
mel_fmax=None,
clip_norm=True,
griffin_lim_iters=60,
do_trim_silence=False,
sound_norm=False)
def _prepare_metadata(self): def _prepare_metadata(self):
csv_path = self.root.joinpath("metadata.csv") csv_path = self.root.joinpath("metadata.csv")
@@ -59,29 +78,10 @@ class LJSpeech(Dataset):
fname, raw_text, normalized_text = metadatum fname, raw_text, normalized_text = metadatum
wav_path = self.root.joinpath("wavs", fname + ".wav") wav_path = self.root.joinpath("wavs", fname + ".wav")
_ljspeech_processor = audio.AudioProcessor(
sample_rate=22050,
num_mels=80,
min_level_db=-100,
ref_level_db=20,
n_fft=2048,
win_length= int(22050 * 0.05),
hop_length= int(22050 * 0.0125),
power=1.2,
preemphasis=0.97,
signal_norm=True,
symmetric_norm=False,
max_norm=1.,
mel_fmin=0,
mel_fmax=None,
clip_norm=True,
griffin_lim_iters=60,
do_trim_silence=False,
sound_norm=False)
# load -> trim -> preemphasis -> stft -> magnitude -> mel_scale -> logscale -> normalize # load -> trim -> preemphasis -> stft -> magnitude -> mel_scale -> logscale -> normalize
wav = _ljspeech_processor.load_wav(str(wav_path)) wav = self._ljspeech_processor.load_wav(str(wav_path))
mag = _ljspeech_processor.spectrogram(wav).astype(np.float32) mag = self._ljspeech_processor.spectrogram(wav).astype(np.float32)
mel = _ljspeech_processor.melspectrogram(wav).astype(np.float32) mel = self._ljspeech_processor.melspectrogram(wav).astype(np.float32)
phonemes = np.array(g2p.en.text_to_sequence(normalized_text), dtype=np.int64) phonemes = np.array(g2p.en.text_to_sequence(normalized_text), dtype=np.int64)
return (mag, mel, phonemes) # maybe we need to implement it as a map in the future return (mag, mel, phonemes) # maybe we need to implement it as a map in the future
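
For reference, a minimal librosa-based sketch of the load -> preemphasis -> stft -> magnitude -> mel_scale -> logscale pipeline noted in the comment above, using the constants from the removed inline processor (22050 Hz, n_fft 2048); the function name and dB scaling are illustrative, not Parakeet's API:

import librosa
import numpy as np

def extract_features(wav_path, sr=22050, n_fft=2048, preemph=0.97):
    hop, win = int(sr * 0.0125), int(sr * 0.05)   # values from the removed code above
    wav, _ = librosa.load(wav_path, sr=sr)
    wav = np.append(wav[0], wav[1:] - preemph * wav[:-1])      # preemphasis
    mag = np.abs(librosa.stft(wav, n_fft=n_fft, hop_length=hop, win_length=win))
    mel = librosa.feature.melspectrogram(S=mag, sr=sr, n_mels=80)
    mag = 20 * np.log10(np.maximum(1e-5, mag))                 # log scale (dB)
    mel = 20 * np.log10(np.maximum(1e-5, mel))
    return mag.astype(np.float32), mel.astype(np.float32)
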
@@ -123,11 +123,11 @@ def batch_examples(batch):
text_lens = sorted(text_lens, reverse=True) text_lens = sorted(text_lens, reverse=True)
# Pad sequence with largest len of the batch # Pad sequence with largest len of the batch
texts = TextIDBatcher(pad_id=0)(texts) texts = TextIDBatcher(pad_id=0)(texts) #(B, T)
pos_texts = TextIDBatcher(pad_id=0)(pos_texts) pos_texts = TextIDBatcher(pad_id=0)(pos_texts) #(B,T)
pos_mels = TextIDBatcher(pad_id=0)(pos_mels) pos_mels = TextIDBatcher(pad_id=0)(pos_mels) #(B,T)
mels = np.transpose(SpecBatcher(pad_value=0.)(mels), axes=(0,2,1)) mels = np.transpose(SpecBatcher(pad_value=0.)(mels), axes=(0,2,1)) #(B,T,num_mels)
mel_inputs = np.transpose(SpecBatcher(pad_value=0.)(mel_inputs), axes=(0,2,1)) mel_inputs = np.transpose(SpecBatcher(pad_value=0.)(mel_inputs), axes=(0,2,1))#(B,T,num_mels)
return (texts, mels, mel_inputs, pos_texts, pos_mels, np.array(text_lens)) return (texts, mels, mel_inputs, pos_texts, pos_mels, np.array(text_lens))
def batch_examples_vocoder(batch): def batch_examples_vocoder(batch):
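
A numpy sketch of what the two batchers above presumably do — pad ID sequences to the batch maximum with pad_id, pad spectrograms along time with pad_value, then apply the (0, 2, 1) transpose; helper names here are hypothetical:

import numpy as np

def pad_ids(seqs, pad_id=0):
    # variable-length 1-D int arrays -> (B, T_max)
    T = max(len(s) for s in seqs)
    return np.stack([np.pad(s, (0, T - len(s)), constant_values=pad_id) for s in seqs])

def pad_specs(specs, pad_value=0.):
    # list of (num_mels, T_i) -> (B, num_mels, T_max) -> (B, T_max, num_mels)
    T = max(s.shape[-1] for s in specs)
    out = np.stack([np.pad(s, ((0, 0), (0, T - s.shape[-1])), constant_values=pad_value)
                    for s in specs])
    return np.transpose(out, (0, 2, 1))
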

View File

@@ -1,124 +0,0 @@
import torch
from torch.nn import functional as F
from torch.utils.data import Dataset, DataLoader
import numpy as np
import math
import os
import hparams
import Audio
from text import text_to_sequence
from utils import process_text, pad_1D, pad_2D
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
class FastSpeechDataset(Dataset):
""" LJSpeech """
def __init__(self):
self.text = process_text(os.path.join("data", "train.txt"))
def __len__(self):
return len(self.text)
def __getitem__(self, idx):
mel_gt_name = os.path.join(
hparams.mel_ground_truth, "ljspeech-mel-%05d.npy" % (idx+1))
mel_gt_target = np.load(mel_gt_name)
D = np.load(os.path.join(hparams.alignment_path, str(idx)+".npy"))
character = self.text[idx][0:len(self.text[idx])-1]
character = np.array(text_to_sequence(
character, hparams.text_cleaners))
sample = {"text": character,
"mel_target": mel_gt_target,
"D": D}
return sample
def reprocess(batch, cut_list):
texts = [batch[ind]["text"] for ind in cut_list]
mel_targets = [batch[ind]["mel_target"] for ind in cut_list]
Ds = [batch[ind]["D"] for ind in cut_list]
length_text = np.array([])
for text in texts:
length_text = np.append(length_text, text.shape[0])
src_pos = list()
max_len = int(max(length_text))
for length_src_row in length_text:
src_pos.append(np.pad([i+1 for i in range(int(length_src_row))],
(0, max_len-int(length_src_row)), 'constant'))
src_pos = np.array(src_pos)
length_mel = np.array(list())
for mel in mel_targets:
length_mel = np.append(length_mel, mel.shape[0])
mel_pos = list()
max_mel_len = int(max(length_mel))
for length_mel_row in length_mel:
mel_pos.append(np.pad([i+1 for i in range(int(length_mel_row))],
(0, max_mel_len-int(length_mel_row)), 'constant'))
mel_pos = np.array(mel_pos)
texts = pad_1D(texts)
Ds = pad_1D(Ds)
mel_targets = pad_2D(mel_targets)
out = {"text": texts,
"mel_target": mel_targets,
"D": Ds,
"mel_pos": mel_pos,
"src_pos": src_pos,
"mel_max_len": max_mel_len}
return out
def collate_fn(batch):
len_arr = np.array([d["text"].shape[0] for d in batch])
index_arr = np.argsort(-len_arr)
batchsize = len(batch)
real_batchsize = int(math.sqrt(batchsize))
cut_list = list()
for i in range(real_batchsize):
cut_list.append(index_arr[i*real_batchsize:(i+1)*real_batchsize])
output = list()
for i in range(real_batchsize):
output.append(reprocess(batch, cut_list[i]))
return output
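
The removed collate_fn above sorts a large batch by text length and splits it into sqrt(B) sub-batches of sqrt(B) examples each, so every sub-batch holds sequences of similar length and wastes little padding. The index bookkeeping in isolation:

import math
import numpy as np

def make_cut_list(batch):
    # sort indices by descending text length, then split into sqrt(B) groups
    lens = np.array([d["text"].shape[0] for d in batch])
    order = np.argsort(-lens)
    k = int(math.sqrt(len(batch)))
    return [order[i * k:(i + 1) * k] for i in range(k)]
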
if __name__ == "__main__":
# Test
dataset = FastSpeechDataset()
training_loader = DataLoader(dataset,
batch_size=1,
shuffle=False,
collate_fn=collate_fn,
drop_last=True,
num_workers=0)
total_step = hparams.epochs * len(training_loader) * hparams.batch_size
cnt = 0
for i, batchs in enumerate(training_loader):
for j, data_of_batch in enumerate(batchs):
mel_target = torch.from_numpy(
data_of_batch["mel_target"]).float().to(device)
D = torch.from_numpy(data_of_batch["D"]).int().to(device)
# print(mel_target.size())
# print(D.sum())
print(cnt)
if mel_target.size(1) == D.sum().item():
cnt += 1
print(cnt)

View File

@@ -11,20 +11,33 @@ from parakeet.modules.feed_forward import PositionwiseFeedForward
class FFTBlock(dg.Layer): class FFTBlock(dg.Layer):
"""FFT Block"""
def __init__(self, d_model, d_inner, n_head, d_k, d_v, filter_size, padding, dropout=0.2): def __init__(self, d_model, d_inner, n_head, d_k, d_v, filter_size, padding, dropout=0.2):
super(FFTBlock, self).__init__() super(FFTBlock, self).__init__()
self.slf_attn = MultiheadAttention(d_model, d_k, d_v, num_head=n_head, dropout=dropout) self.slf_attn = MultiheadAttention(d_model, d_k, d_v, num_head=n_head, dropout=dropout)
self.pos_ffn = PositionwiseFeedForward(d_model, d_inner, filter_size =filter_size, padding =padding, dropout=dropout) self.pos_ffn = PositionwiseFeedForward(d_model, d_inner, filter_size =filter_size, padding =padding, dropout=dropout)
def forward(self, enc_input, non_pad_mask=None, slf_attn_mask=None): def forward(self, enc_input, non_pad_mask=None, slf_attn_mask=None):
enc_output, enc_slf_attn = self.slf_attn(enc_input, enc_input, enc_input, mask=slf_attn_mask) """
enc_output *= non_pad_mask Feed Forward Transformer block in FastSpeech.
Args:
enc_input (Variable): Shape(B, T, C), dtype: float32. The character embedding input.
T means the timesteps of input.
non_pad_mask (Variable): Shape(B, T, 1), dtype: int64. The mask of sequence.
slf_attn_mask (Variable): Shape(B, len_q, len_k), dtype: int64. The mask of self attention.
len_q means the sequence length of query, len_k means the sequence length of key.
enc_output = self.pos_ffn(enc_output) Returns:
enc_output *= non_pad_mask output (Variable), Shape(B, T, C), the output after self-attention & ffn.
slf_attn (Variable), Shape(B * n_head, T, T), the self attention.
"""
output, slf_attn = self.slf_attn(enc_input, enc_input, enc_input, mask=slf_attn_mask)
output *= non_pad_mask
return enc_output, enc_slf_attn output = self.pos_ffn(output)
output *= non_pad_mask
return output, slf_attn
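
A hypothetical construction and call matching the docstring above; all sizes and the input tensor names are made-up examples, not values from this commit:

block = FFTBlock(d_model=256, d_inner=1024, n_head=2, d_k=64, d_v=64,
                 filter_size=1, padding=0, dropout=0.2)
# enc_input: (B, T, 256) float32; non_pad: (B, T, 1); attn_mask: (B, T, T)
out, attn = block(enc_input, non_pad_mask=non_pad, slf_attn_mask=attn_mask)
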
class LengthRegulator(dg.Layer): class LengthRegulator(dg.Layer):
@@ -70,6 +83,20 @@ class LengthRegulator(dg.Layer):
def forward(self, x, alpha=1.0, target=None): def forward(self, x, alpha=1.0, target=None):
"""
Length Regulator block in FastSpeech.
Args:
x (Variable): Shape(B, T, C), dtype: float32. The encoder output.
alpha (Constant): dtype: float32. The hyperparameter that determines the length of
the expanded mel sequence, thereby controlling the voice speed.
target (Variable, optional): Shape(B, T_text),
dtype: int64. The phoneme durations computed by a pretrained TransformerTTS.
Returns:
output (Variable), Shape(B, T, C), the output after expansion.
duration_predictor_output (Variable), Shape(B, T, C), the output of duration predictor.
"""
duration_predictor_output = self.duration_predictor(x) duration_predictor_output = self.duration_predictor(x)
if fluid.framework._dygraph_tracer()._train_mode: if fluid.framework._dygraph_tracer()._train_mode:
output = self.LR(x, target) output = self.LR(x, target)
@@ -81,7 +108,6 @@ class LengthRegulator(dg.Layer):
return output, mel_pos return output, mel_pos
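
The expansion itself is just a per-phoneme repeat of the encoder states; a self-contained numpy sketch of the idea (the real LR works on padded batches):

import numpy as np

def length_regulate(x, durations):
    # x: (T_text, C) encoder states; durations: (T_text,) frames per phoneme
    return np.repeat(x, durations, axis=0)      # -> (sum(durations), C)

x = np.arange(6, dtype=np.float32).reshape(3, 2)
print(length_regulate(x, np.array([2, 1, 3])).shape)   # (6, 2): 2+1+3 frames
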
class DurationPredictor(dg.Layer): class DurationPredictor(dg.Layer):
""" Duration Predictor """
def __init__(self, input_size, out_channels, filter_size, dropout=0.1): def __init__(self, input_size, out_channels, filter_size, dropout=0.1):
super(DurationPredictor, self).__init__() super(DurationPredictor, self).__init__()
self.input_size = input_size self.input_size = input_size
@@ -105,7 +131,14 @@ class DurationPredictor(dg.Layer):
self.linear =dg.Linear(self.out_channels, 1) self.linear =dg.Linear(self.out_channels, 1)
def forward(self, encoder_output): def forward(self, encoder_output):
"""
Duration Predictor block in FastSpeech.
Args:
encoder_output (Variable): Shape(B, T, C), dtype: float32. The encoder output.
Returns:
out (Variable), Shape(B, T, C), the output of duration predictor.
"""
# encoder_output.shape(N, T, C) # encoder_output.shape(N, T, C)
out = layers.dropout(layers.relu(self.layer_norm1(self.conv1(encoder_output))), self.dropout) out = layers.dropout(layers.relu(self.layer_norm1(self.conv1(encoder_output))), self.dropout)
out = layers.dropout(layers.relu(self.layer_norm2(self.conv2(out))), self.dropout) out = layers.dropout(layers.relu(self.layer_norm2(self.conv2(out))), self.dropout)

View File

@@ -35,6 +35,20 @@ class Encoder(dg.Layer):
self.add_sublayer('fft_{}'.format(i), layer) self.add_sublayer('fft_{}'.format(i), layer)
def forward(self, character, text_pos): def forward(self, character, text_pos):
"""
Encoder layer of FastSpeech.
Args:
character (Variable): Shape(B, T_text), dtype: int64. The input text
characters. T_text means the timesteps of input characters.
text_pos (Variable): Shape(B, T_text), dtype: int64. The input text
position. T_text means the timesteps of input characters.
Returns:
enc_output (Variable), Shape(B, text_T, C), the encoder output.
non_pad_mask (Variable), Shape(B, T_text, 1), the non-pad mask.
enc_slf_attn_list (list<Variable>), Len(n_layers), Shape(B * n_head, text_T, text_T), the encoder self attention list.
"""
enc_slf_attn_list = [] enc_slf_attn_list = []
# -- prepare masks # -- prepare masks
# shape character (N, T) # shape character (N, T)
@@ -80,6 +94,18 @@ class Decoder(dg.Layer):
self.add_sublayer('fft_{}'.format(i), layer) self.add_sublayer('fft_{}'.format(i), layer)
def forward(self, enc_seq, enc_pos): def forward(self, enc_seq, enc_pos):
"""
Decoder layer of FastSpeech.
Args:
enc_seq (Variable), Shape(B, text_T, C), dtype: float32.
The output of length regulator.
enc_pos (Variable, optional): Shape(B, T_mel),
dtype: int64. The spectrum position. T_mel means the timesteps of input spectrum.
Returns:
dec_output (Variable), Shape(B, mel_T, C), the decoder output.
dec_slf_attn_list (Variable), Shape(B, mel_T, mel_T), the decoder self attention list.
"""
dec_slf_attn_list = [] dec_slf_attn_list = []
# -- Prepare masks # -- Prepare masks
@@ -141,6 +167,31 @@ class FastSpeech(dg.Layer):
dropout=0.1) dropout=0.1)
def forward(self, character, text_pos, mel_pos=None, length_target=None, alpha=1.0): def forward(self, character, text_pos, mel_pos=None, length_target=None, alpha=1.0):
"""
FastSpeech model.
Args:
character (Variable): Shape(B, T_text), dtype: int64. The input text
characters. T_text means the timesteps of input characters.
text_pos (Variable): Shape(B, T_text), dtype: int64. The input text
position. T_text means the timesteps of input characters.
mel_pos (Variable, optional): Shape(B, T_mel),
dtype: int64. The spectrum position. T_mel means the timesteps of input spectrum.
length_target (Variable, optional): Shape(B, T_text),
dtype: int64. The phoneme durations computed by a pretrained TransformerTTS.
alpha (Constant):
dtype: float32. The hyperparameter that determines the length of the expanded mel
sequence, thereby controlling the voice speed.
Returns:
mel_output (Variable), Shape(B, mel_T, C), the mel output before postnet.
mel_output_postnet (Variable), Shape(B, mel_T, C), the mel output after postnet.
duration_predictor_output (Variable), Shape(B, text_T), the phoneme durations computed
by the duration predictor.
enc_slf_attn_list (Variable), Shape(B, text_T, text_T), the encoder self attention list.
dec_slf_attn_list (Variable), Shape(B, mel_T, mel_T), the decoder self attention list.
"""
encoder_output, non_pad_mask, enc_slf_attn_list = self.encoder(character, text_pos) encoder_output, non_pad_mask, enc_slf_attn_list = self.encoder(character, text_pos)
if fluid.framework._dygraph_tracer()._train_mode: if fluid.framework._dygraph_tracer()._train_mode:
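
A hypothetical pair of calls matching the docstring above — training uses ground-truth durations and mel positions, inference drops both and controls speed through alpha (variable names are placeholders):

# training mode: length_target comes from a pretrained TransformerTTS alignment
mel, mel_post, dur, enc_attn, dec_attn = model(
    character, text_pos, mel_pos=mel_pos, length_target=alignment)
# inference mode: alpha > 1.0 slows the voice down, alpha < 1.0 speeds it up
mel, mel_post, dur, enc_attn, dec_attn = model(character, text_pos, alpha=1.0)
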

View File

@@ -9,9 +9,9 @@ def add_config_options_to_parser(parser):
help="the sampling rate of audio data file.") help="the sampling rate of audio data file.")
parser.add_argument('--audio.preemphasis', type=float, default=0.97, parser.add_argument('--audio.preemphasis', type=float, default=0.97,
help="the preemphasis coefficient.") help="the preemphasis coefficient.")
parser.add_argument('--audio.hop_length', type=float, default=128, parser.add_argument('--audio.hop_length', type=int, default=128,
help="the number of samples to advance between frames.") help="the number of samples to advance between frames.")
parser.add_argument('--audio.win_length', type=float, default=1024, parser.add_argument('--audio.win_length', type=int, default=1024,
help="the length (width) of the window function.") help="the length (width) of the window function.")
parser.add_argument('--audio.power', type=float, default=1.4, parser.add_argument('--audio.power', type=float, default=1.4,
help="the power to raise before griffin-lim.") help="the power to raise before griffin-lim.")

View File

@@ -66,8 +66,8 @@ def main(cfg):
model = FastSpeech(cfg) model = FastSpeech(cfg)
model.train() model.train()
optimizer = fluid.optimizer.AdamOptimizer(learning_rate=dg.NoamDecay(1/(cfg.warm_up_step *( cfg.lr ** 2)), cfg.warm_up_step)) optimizer = fluid.optimizer.AdamOptimizer(learning_rate=dg.NoamDecay(1/(cfg.warm_up_step *( cfg.lr ** 2)), cfg.warm_up_step),
parameter_list=model.parameters())
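
The 1/(warm_up_step * lr**2) factor is chosen so the NoamDecay peak at step == warm_up_step is exactly cfg.lr; a quick check of the schedule, assuming Paddle's d_model^-0.5 * min(step^-0.5, step * warmup^-1.5) formula:

def noam_lr(step, warmup=4000, lr0=0.001):
    d_model = 1.0 / (warmup * lr0 ** 2)
    return d_model ** -0.5 * min(step ** -0.5, step * warmup ** -1.5)

print(noam_lr(4000))    # 0.001  -> peak equals cfg.lr at the warm-up boundary
print(noam_lr(16000))   # 0.0005 -> then decays as 1/sqrt(step)
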
reader = LJSpeechLoader(cfg, nranks, local_rank).reader() reader = LJSpeechLoader(cfg, nranks, local_rank).reader()
if cfg.checkpoint_path is not None: if cfg.checkpoint_path is not None:

View File

@@ -13,7 +13,8 @@ audio:
hidden_size: 256 hidden_size: 256
embedding_size: 512 embedding_size: 512
warm_up_step: 4000
grad_clip_thresh: 1.0
batch_size: 32 batch_size: 32
epochs: 10000 epochs: 10000
lr: 0.001 lr: 0.001

View File

@@ -11,22 +11,23 @@ audio:
outputs_per_step: 1 outputs_per_step: 1
hidden_size: 384 #256 hidden_size: 256
embedding_size: 384 #512 embedding_size: 512
warm_up_step: 4000
grad_clip_thresh: 1.0
batch_size: 32 batch_size: 32
epochs: 10000 epochs: 10000
lr: 0.001 lr: 0.001
save_step: 10 save_step: 1000
image_step: 2000 image_step: 2000
use_gpu: True use_gpu: True
use_data_parallel: True use_data_parallel: False
data_path: ../../../dataset/LJSpeech-1.1 data_path: ../../../dataset/LJSpeech-1.1
save_path: ./checkpoint save_path: ./checkpoint
log_dir: ./log log_dir: ./log
#checkpoint_path: ./checkpoint/transformer/1

View File

@@ -1,29 +0,0 @@
from pathlib import Path
import numpy as np
from paddle import fluid
from parakeet.data.sampler import DistributedSampler
from parakeet.data.datacargo import DataCargo
from preprocess import batch_examples, LJSpeech, batch_examples_vocoder
class LJSpeechLoader:
def __init__(self, config, nranks, rank, is_vocoder=False):
place = fluid.CUDAPlace(rank) if config.use_gpu else fluid.CPUPlace()
LJSPEECH_ROOT = Path(config.data_path)
dataset = LJSpeech(LJSPEECH_ROOT)
sampler = DistributedSampler(len(dataset), nranks, rank)
assert config.batch_size % nranks == 0
each_bs = config.batch_size // nranks
if is_vocoder:
dataloader = DataCargo(dataset, sampler=sampler, batch_size=each_bs, shuffle=True, collate_fn=batch_examples_vocoder, drop_last=True)
else:
dataloader = DataCargo(dataset, sampler=sampler, batch_size=each_bs, shuffle=True, collate_fn=batch_examples, drop_last=True)
self.reader = fluid.io.DataLoader.from_generator(
capacity=32,
iterable=True,
use_double_buffer=True,
return_list=True)
self.reader.set_batch_generator(dataloader, place)

View File

@@ -3,11 +3,12 @@ from parakeet.g2p.text.symbols import symbols
import paddle.fluid.dygraph as dg import paddle.fluid.dygraph as dg
import paddle.fluid as fluid import paddle.fluid as fluid
import paddle.fluid.layers as layers import paddle.fluid.layers as layers
from parakeet.modules.layers import Conv1D, Pool1D from parakeet.modules.layers import Conv, Pool1D
from parakeet.modules.dynamicGRU import DynamicGRU from parakeet.modules.dynamicGRU import DynamicGRU
import numpy as np import numpy as np
class EncoderPrenet(dg.Layer): class EncoderPrenet(dg.Layer):
def __init__(self, embedding_size, num_hidden, use_cudnn=True): def __init__(self, embedding_size, num_hidden, use_cudnn=True):
super(EncoderPrenet, self).__init__() super(EncoderPrenet, self).__init__()
@@ -18,19 +19,19 @@ class EncoderPrenet(dg.Layer):
param_attr = fluid.ParamAttr(name='weight'), param_attr = fluid.ParamAttr(name='weight'),
padding_idx = None) padding_idx = None)
self.conv_list = [] self.conv_list = []
self.conv_list.append(Conv1D(in_channels = embedding_size, self.conv_list.append(Conv(in_channels = embedding_size,
out_channels = num_hidden, out_channels = num_hidden,
filter_size = 5, filter_size = 5,
padding = int(np.floor(5/2)), padding = int(np.floor(5/2)),
use_cudnn = use_cudnn, use_cudnn = use_cudnn,
data_format = "NCT")) data_format = "NCT"))
for _ in range(2): for _ in range(2):
self.conv_list = Conv1D(in_channels = num_hidden, self.conv_list.append(Conv(in_channels = num_hidden,
out_channels = num_hidden, out_channels = num_hidden,
filter_size = 5, filter_size = 5,
padding = int(np.floor(5/2)), padding = int(np.floor(5/2)),
use_cudnn = use_cudnn, use_cudnn = use_cudnn,
data_format = "NCT") data_format = "NCT"))
for i, layer in enumerate(self.conv_list): for i, layer in enumerate(self.conv_list):
self.add_sublayer("conv_list_{}".format(i), layer) self.add_sublayer("conv_list_{}".format(i), layer)
@@ -71,13 +72,13 @@ class CBHG(dg.Layer):
self.hidden_size = hidden_size self.hidden_size = hidden_size
self.projection_size = projection_size self.projection_size = projection_size
self.conv_list = [] self.conv_list = []
self.conv_list.append(Conv1D(in_channels = projection_size, self.conv_list.append(Conv(in_channels = projection_size,
out_channels = hidden_size, out_channels = hidden_size,
filter_size = 1, filter_size = 1,
padding = int(np.floor(1/2)), padding = int(np.floor(1/2)),
data_format = "NCT")) data_format = "NCT"))
for i in range(2,K+1): for i in range(2,K+1):
self.conv_list.append(Conv1D(in_channels = hidden_size, self.conv_list.append(Conv(in_channels = hidden_size,
out_channels = hidden_size, out_channels = hidden_size,
filter_size = i, filter_size = i,
padding = int(np.floor(i/2)), padding = int(np.floor(i/2)),
@@ -100,13 +101,13 @@ class CBHG(dg.Layer):
conv_outdim = hidden_size * K conv_outdim = hidden_size * K
self.conv_projection_1 = Conv1D(in_channels = conv_outdim, self.conv_projection_1 = Conv(in_channels = conv_outdim,
out_channels = hidden_size, out_channels = hidden_size,
filter_size = 3, filter_size = 3,
padding = int(np.floor(3/2)), padding = int(np.floor(3/2)),
data_format = "NCT") data_format = "NCT")
self.conv_projection_2 = Conv1D(in_channels = hidden_size, self.conv_projection_2 = Conv(in_channels = hidden_size,
out_channels = projection_size, out_channels = projection_size,
filter_size = 3, filter_size = 3,
padding = int(np.floor(3/2)), padding = int(np.floor(3/2)),

View File

@@ -20,13 +20,12 @@ class Encoder(dg.Layer):
self.pos_emb = dg.Embedding(size=[1024, num_hidden], self.pos_emb = dg.Embedding(size=[1024, num_hidden],
padding_idx=0, padding_idx=0,
param_attr=fluid.ParamAttr( param_attr=fluid.ParamAttr(
name='weight',
initializer=fluid.initializer.NumpyArrayInitializer(self.pos_inp), initializer=fluid.initializer.NumpyArrayInitializer(self.pos_inp),
trainable=False)) trainable=False))
self.encoder_prenet = EncoderPrenet(embedding_size = embedding_size, self.encoder_prenet = EncoderPrenet(embedding_size = embedding_size,
num_hidden = num_hidden, num_hidden = num_hidden,
use_cudnn=config.use_gpu) use_cudnn=config.use_gpu)
self.layers = [MultiheadAttention(num_hidden, num_hidden, num_hidden) for _ in range(3)] self.layers = [MultiheadAttention(num_hidden, num_hidden//4, num_hidden//4) for _ in range(3)]
for i, layer in enumerate(self.layers): for i, layer in enumerate(self.layers):
self.add_sublayer("self_attn_{}".format(i), layer) self.add_sublayer("self_attn_{}".format(i), layer)
self.ffns = [PositionwiseFeedForward(num_hidden, num_hidden*4, filter_size=1, use_cudnn = config.use_gpu) for _ in range(3)] self.ffns = [PositionwiseFeedForward(num_hidden, num_hidden*4, filter_size=1, use_cudnn = config.use_gpu) for _ in range(3)]
@@ -40,6 +39,7 @@ class Encoder(dg.Layer):
else: else:
query_mask, mask = None, None query_mask, mask = None, None
# Encoder pre_network # Encoder pre_network
x = self.encoder_prenet(x) #(N,T,C) x = self.encoder_prenet(x) #(N,T,C)
@@ -81,10 +81,10 @@ class Decoder(dg.Layer):
dropout_rate=0.2) dropout_rate=0.2)
self.linear = dg.Linear(num_hidden, num_hidden) self.linear = dg.Linear(num_hidden, num_hidden)
self.selfattn_layers = [MultiheadAttention(num_hidden, num_hidden, num_hidden) for _ in range(3)] self.selfattn_layers = [MultiheadAttention(num_hidden, num_hidden//4, num_hidden//4) for _ in range(3)]
for i, layer in enumerate(self.selfattn_layers): for i, layer in enumerate(self.selfattn_layers):
self.add_sublayer("self_attn_{}".format(i), layer) self.add_sublayer("self_attn_{}".format(i), layer)
self.attn_layers = [MultiheadAttention(num_hidden, num_hidden, num_hidden) for _ in range(3)] self.attn_layers = [MultiheadAttention(num_hidden, num_hidden//4, num_hidden//4) for _ in range(3)]
for i, layer in enumerate(self.attn_layers): for i, layer in enumerate(self.attn_layers):
self.add_sublayer("attn_{}".format(i), layer) self.add_sublayer("attn_{}".format(i), layer)
self.ffns = [PositionwiseFeedForward(num_hidden, num_hidden*4, filter_size=1) for _ in range(3)] self.ffns = [PositionwiseFeedForward(num_hidden, num_hidden*4, filter_size=1) for _ in range(3)]
@@ -104,18 +104,18 @@ class Decoder(dg.Layer):
if fluid.framework._dygraph_tracer()._train_mode: if fluid.framework._dygraph_tracer()._train_mode:
m_mask = get_non_pad_mask(positional) m_mask = get_non_pad_mask(positional)
mask = get_attn_key_pad_mask(positional, query) mask = get_attn_key_pad_mask((positional==0).astype(np.float32), query)
triu_tensor = dg.to_variable(get_triu_tensor(query.numpy(), query.numpy())).astype(np.float32) triu_tensor = dg.to_variable(get_triu_tensor(query.numpy(), query.numpy())).astype(np.float32)
mask = mask + triu_tensor mask = mask + triu_tensor
mask = fluid.layers.cast(mask != 0, np.float32) mask = fluid.layers.cast(mask == 0, np.float32)
# (batch_size, decoder_len, encoder_len) # (batch_size, decoder_len, encoder_len)
zero_mask = get_attn_key_pad_mask(layers.squeeze(c_mask,[-1]), query) zero_mask = get_attn_key_pad_mask(layers.squeeze(c_mask,[-1]), query)
else: else:
mask = get_triu_tensor(query.numpy(), query.numpy()).astype(np.float32) mask = get_triu_tensor(query.numpy(), query.numpy()).astype(np.float32)
mask = fluid.layers.cast(dg.to_variable(mask != 0), np.float32) mask = fluid.layers.cast(dg.to_variable(mask == 0), np.float32)
m_mask, zero_mask = None, None m_mask, zero_mask = None, None
# Decoder pre-network # Decoder pre-network
query = self.decoder_prenet(query) query = self.decoder_prenet(query)
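
Under the flipped convention (the cast(mask == 0) above), the final mask is 1 where a query may attend and 0 where the key is padding or lies in the future; a numpy sketch of the combined pad + triu mask this hunk builds (function name hypothetical):

import numpy as np

def decoder_self_mask(positional):
    # positional: (B, T) int positions, 0 marks padding
    T = positional.shape[1]
    pad = (positional == 0)[:, None, :]             # (B, 1, T): padded keys
    future = np.triu(np.ones((T, T)), k=1)[None]    # (1, T, T): j > i
    blocked = pad | (future > 0)
    return (~blocked).astype(np.float32)            # 1 = attend, 0 = ignore

print(decoder_self_mask(np.array([[1, 2, 3, 0]]))[0])
# lower-triangular over the 3 real steps; the padded last column is all zero
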
@@ -164,6 +164,7 @@ class TransformerTTS(dg.Layer):
# key (batch_size, seq_len, channel) # key (batch_size, seq_len, channel)
# c_mask (batch_size, seq_len) # c_mask (batch_size, seq_len)
# attns_enc (channel / 2, seq_len, seq_len) # attns_enc (channel / 2, seq_len, seq_len)
key, c_mask, attns_enc = self.encoder(characters, pos_text) key, c_mask, attns_enc = self.encoder(characters, pos_text)
# mel_output/postnet_output (batch_size, mel_len, n_mel) # mel_output/postnet_output (batch_size, mel_len, n_mel)

View File

@@ -9,9 +9,9 @@ def add_config_options_to_parser(parser):
help="the sampling rate of audio data file.") help="the sampling rate of audio data file.")
parser.add_argument('--audio.preemphasis', type=float, default=0.97, parser.add_argument('--audio.preemphasis', type=float, default=0.97,
help="the preemphasis coefficient.") help="the preemphasis coefficient.")
parser.add_argument('--audio.hop_length', type=float, default=128, parser.add_argument('--audio.hop_length', type=int, default=128,
help="the number of samples to advance between frames.") help="the number of samples to advance between frames.")
parser.add_argument('--audio.win_length', type=float, default=1024, parser.add_argument('--audio.win_length', type=int, default=1024,
help="the length (width) of the window function.") help="the length (width) of the window function.")
parser.add_argument('--audio.power', type=float, default=1.4, parser.add_argument('--audio.power', type=float, default=1.4,
help="the power to raise before griffin-lim.") help="the power to raise before griffin-lim.")
@@ -27,6 +27,10 @@ def add_config_options_to_parser(parser):
parser.add_argument('--embedding_size', type=int, default=512, parser.add_argument('--embedding_size', type=int, default=512,
help="the embedding vector size.") help="the embedding vector size.")
parser.add_argument('--warm_up_step', type=int, default=4000,
help="the warm up step of learning rate.")
parser.add_argument('--grad_clip_thresh', type=float, default=1.0,
help="the threshold of grad clip.")
parser.add_argument('--batch_size', type=int, default=32, parser.add_argument('--batch_size', type=int, default=32,
help="batch size for training.") help="batch size for training.")
parser.add_argument('--epochs', type=int, default=10000, parser.add_argument('--epochs', type=int, default=10000,

View File

@@ -6,7 +6,7 @@ from pathlib import Path
import jsonargparse import jsonargparse
from parse import add_config_options_to_parser from parse import add_config_options_to_parser
from pprint import pprint from pprint import pprint
from data import LJSpeechLoader from parakeet.models.dataloader.jlspeech import LJSpeechLoader
class MyDataParallel(dg.parallel.DataParallel): class MyDataParallel(dg.parallel.DataParallel):
""" """
@@ -50,7 +50,9 @@ def main(cfg):
model = ModelPostNet(cfg) model = ModelPostNet(cfg)
model.train() model.train()
optimizer = fluid.optimizer.AdamOptimizer(learning_rate=dg.NoamDecay(1/(4000 *( cfg.lr ** 2)), 4000)) optimizer = fluid.optimizer.AdamOptimizer(learning_rate=dg.NoamDecay(1/(cfg.warm_up_step *( cfg.lr ** 2)), cfg.warm_up_step),
parameter_list=model.parameters())
if cfg.checkpoint_path is not None: if cfg.checkpoint_path is not None:
model_dict, opti_dict = fluid.dygraph.load_dygraph(cfg.checkpoint_path) model_dict, opti_dict = fluid.dygraph.load_dygraph(cfg.checkpoint_path)
@@ -75,13 +77,16 @@ def main(cfg):
mag_pred = model(mel) mag_pred = model(mel)
loss = layers.mean(layers.abs(layers.elementwise_sub(mag_pred, mag))) loss = layers.mean(layers.abs(layers.elementwise_sub(mag_pred, mag)))
if cfg.use_data_parallel: if cfg.use_data_parallel:
loss = model.scale_loss(loss) loss = model.scale_loss(loss)
loss.backward() loss.backward()
model.apply_collective_grads() model.apply_collective_grads()
else: else:
loss.backward() loss.backward()
optimizer.minimize(loss, grad_clip = fluid.dygraph_grad_clip.GradClipByGlobalNorm(1)) optimizer.minimize(loss, grad_clip = fluid.dygraph_grad_clip.GradClipByGlobalNorm(cfg.grad_clip_thresh))
print("===============",model.pre_proj.conv.weight.numpy())
print("===============",model.pre_proj.conv.weight.gradient())
model.clear_gradients() model.clear_gradients()
if local_rank==0: if local_rank==0:

View File

@@ -34,6 +34,9 @@ def main(cfg):
local_rank = dg.parallel.Env().local_rank if cfg.use_data_parallel else 0 local_rank = dg.parallel.Env().local_rank if cfg.use_data_parallel else 0
nranks = dg.parallel.Env().nranks if cfg.use_data_parallel else 1 nranks = dg.parallel.Env().nranks if cfg.use_data_parallel else 1
fluid.default_startup_program().random_seed = 1
fluid.default_main_program().random_seed = 1
if local_rank == 0: if local_rank == 0:
# Print the whole config setting. # Print the whole config setting.
pprint(jsonargparse.namespace_to_dict(cfg)) pprint(jsonargparse.namespace_to_dict(cfg))
@@ -53,7 +56,8 @@ def main(cfg):
model = TransformerTTS(cfg) model = TransformerTTS(cfg)
model.train() model.train()
optimizer = fluid.optimizer.AdamOptimizer(learning_rate=dg.NoamDecay(1/(4000 *( cfg.lr ** 2)), 4000)) optimizer = fluid.optimizer.AdamOptimizer(learning_rate=dg.NoamDecay(1/(cfg.warm_up_step *( cfg.lr ** 2)), cfg.warm_up_step),
parameter_list=model.parameters())
reader = LJSpeechLoader(cfg, nranks, local_rank).reader() reader = LJSpeechLoader(cfg, nranks, local_rank).reader()
@@ -69,6 +73,8 @@ def main(cfg):
for epoch in range(cfg.epochs): for epoch in range(cfg.epochs):
pbar = tqdm(reader) pbar = tqdm(reader)
for i, data in enumerate(pbar): for i, data in enumerate(pbar):
pbar.set_description('Processing at epoch %d'%epoch) pbar.set_description('Processing at epoch %d'%epoch)
character, mel, mel_input, pos_text, pos_mel, text_length = data character, mel, mel_input, pos_text, pos_mel, text_length = data
@@ -86,7 +92,7 @@ def main(cfg):
post_mel_loss = layers.mean(layers.abs(layers.elementwise_sub(postnet_pred, mel))) post_mel_loss = layers.mean(layers.abs(layers.elementwise_sub(postnet_pred, mel)))
stop_loss = cross_entropy(stop_preds, dg.to_variable(label)) stop_loss = cross_entropy(stop_preds, dg.to_variable(label))
loss = mel_loss + post_mel_loss + stop_loss loss = mel_loss + post_mel_loss + stop_loss
if local_rank==0: if local_rank==0:
writer.add_scalars('training_loss', { writer.add_scalars('training_loss', {
'mel_loss':mel_loss.numpy(), 'mel_loss':mel_loss.numpy(),
@@ -116,16 +122,16 @@ def main(cfg):
for j in range(4): for j in range(4):
x = np.uint8(cm.viridis(prob.numpy()[j*16]) * 255) x = np.uint8(cm.viridis(prob.numpy()[j*16]) * 255)
writer.add_image('Attention_dec_%d_0'%global_step, x, i*4+j, dataformats="HWC") writer.add_image('Attention_dec_%d_0'%global_step, x, i*4+j, dataformats="HWC")
if cfg.use_data_parallel: if cfg.use_data_parallel:
loss = model.scale_loss(loss) loss = model.scale_loss(loss)
loss.backward() loss.backward()
model.apply_collective_grads() model.apply_collective_grads()
else: else:
loss.backward() loss.backward()
optimizer.minimize(loss, grad_clip = fluid.dygraph_grad_clip.GradClipByGlobalNorm(1)) optimizer.minimize(loss, grad_clip = fluid.dygraph_grad_clip.GradClipByGlobalNorm(cfg.grad_clip_thresh))
model.clear_gradients() model.clear_gradients()
# save checkpoint # save checkpoint
if local_rank==0 and global_step % cfg.save_step == 0: if local_rank==0 and global_step % cfg.save_step == 0:
if not os.path.exists(cfg.save_path): if not os.path.exists(cfg.save_path):

View File

@@ -25,6 +25,14 @@ class DynamicGRU(dg.Layer):
self.is_reverse = is_reverse self.is_reverse = is_reverse
def forward(self, inputs): def forward(self, inputs):
"""
Dynamic GRU block.
Args:
inputs (Variable): Shape(B, T, C), dtype: float32. The input value.
Returns:
output (Variable), Shape(B, T, C), the result computed by the GRU.
"""
hidden = self.h_0 hidden = self.h_0
res = [] res = []
for i in range(inputs.shape[1]): for i in range(inputs.shape[1]):
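
The hunk is cut off inside the loop; the general pattern is a manual unroll of a recurrent cell over the time axis. A self-contained numpy sketch of that pattern (cell stands in for GRUUnit):

import numpy as np

def dynamic_rnn(inputs, cell, h0, is_reverse=False):
    # inputs: (B, T, C); cell: fn(x_t, h) -> h
    T = inputs.shape[1]
    steps = range(T - 1, -1, -1) if is_reverse else range(T)
    h, outs = h0, []
    for t in steps:
        h = cell(inputs[:, t, :], h)
        outs.append(h[:, None, :])           # keep a (B, 1, C) slice per step
    if is_reverse:
        outs = outs[::-1]                    # restore temporal order
    return np.concatenate(outs, axis=1)      # (B, T, C)
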

View File

@@ -1,6 +1,9 @@
import paddle.fluid.dygraph as dg import paddle.fluid.dygraph as dg
import paddle.fluid.layers as layers import paddle.fluid.layers as layers
from parakeet.modules.layers import Conv1D import paddle.fluid as fluid
import math
from parakeet.modules.layers import Conv
class PositionwiseFeedForward(dg.Layer): class PositionwiseFeedForward(dg.Layer):
''' A two-feed-forward-layer module ''' ''' A two-feed-forward-layer module '''
@@ -9,14 +12,15 @@ class PositionwiseFeedForward(dg.Layer):
self.num_hidden = num_hidden self.num_hidden = num_hidden
self.use_cudnn = use_cudnn self.use_cudnn = use_cudnn
self.dropout = dropout self.dropout = dropout
self.w_1 = Conv1D(in_channels = d_in, self.w_1 = Conv(in_channels = d_in,
out_channels = num_hidden, out_channels = num_hidden,
filter_size = filter_size, filter_size = filter_size,
padding=padding, padding=padding,
use_cudnn = use_cudnn, use_cudnn = use_cudnn,
data_format = "NTC") data_format = "NTC")
self.w_2 = Conv1D(in_channels = num_hidden,
self.w_2 = Conv(in_channels = num_hidden,
out_channels = d_in, out_channels = d_in,
filter_size = filter_size, filter_size = filter_size,
padding=padding, padding=padding,
@@ -25,6 +29,14 @@ class PositionwiseFeedForward(dg.Layer):
self.layer_norm = dg.LayerNorm(d_in) self.layer_norm = dg.LayerNorm(d_in)
def forward(self, input): def forward(self, input):
"""
Feed Forward Network.
Args:
input (Variable): Shape(B, T, C), dtype: float32. The input value.
Returns:
output (Variable), Shape(B, T, C), the result after FFN.
"""
#FFN Network #FFN Network
x = self.w_2(layers.relu(self.w_1(input))) x = self.w_2(layers.relu(self.w_1(input)))
@@ -35,6 +47,6 @@ class PositionwiseFeedForward(dg.Layer):
x = x + input x = x + input
#layer normalization #layer normalization
x = self.layer_norm(x) output = self.layer_norm(x)
return x return output
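
With filter_size=1 the two convs act per position, so the whole block reduces to LayerNorm(x + W2·relu(W1·x)); a numpy sketch of that computation:

import numpy as np

def ffn(x, W1, b1, W2, b2, eps=1e-5):
    h = np.maximum(0.0, x @ W1 + b1)        # (B, T, num_hidden)
    y = x + (h @ W2 + b2)                   # residual back to (B, T, d_in)
    mu = y.mean(-1, keepdims=True)
    var = y.var(-1, keepdims=True)
    return (y - mu) / np.sqrt(var + eps)    # layer norm over channels
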

View File

@@ -6,6 +6,42 @@ from paddle import fluid
import paddle.fluid.dygraph as dg import paddle.fluid.dygraph as dg
class Conv(dg.Layer):
def __init__(self, in_channels, out_channels, filter_size=1,
padding=0, dilation=1, stride=1, use_cudnn=True,
data_format="NCT", is_bias=True):
super(Conv, self).__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.filter_size = filter_size
self.padding = padding
self.dilation = dilation
self.stride = stride
self.use_cudnn = use_cudnn
self.data_format = data_format
self.is_bias = is_bias
self.weight_attr = fluid.ParamAttr(initializer=fluid.initializer.XavierInitializer())
self.bias_attr = None
if is_bias is not False:
k = math.sqrt(1 / in_channels)
self.bias_attr = fluid.ParamAttr(initializer=fluid.initializer.Uniform(low=-k, high=k))
self.conv = Conv1D( in_channels = in_channels,
out_channels = out_channels,
filter_size = filter_size,
padding = padding,
dilation = dilation,
stride = stride,
param_attr = self.weight_attr,
bias_attr = self.bias_attr,
use_cudnn = use_cudnn,
data_format = data_format)
def forward(self, x):
x = self.conv(x)
return x
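
The wrapper bundles Xavier weight init plus a uniform bias in [-k, k] with k = sqrt(1/in_channels), the same default PyTorch uses for Conv1d. A hypothetical usage with made-up sizes:

conv = Conv(in_channels=80, out_channels=256, filter_size=5,
            padding=2, data_format="NCT")   # channels-first 1-D layout
y = conv(x)                                 # x: (B, 80, T) -> y: (B, 256, T)
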
class Conv1D(dg.Layer): class Conv1D(dg.Layer):
""" """
A convolution 1D block implemented with Conv2D. For simplicity and A convolution 1D block implemented with Conv2D. For simplicity and

View File

@@ -10,22 +10,35 @@ class ScaledDotProductAttention(dg.Layer):
self.d_key = d_key self.d_key = d_key
# note that this mask convention differs from the pytorch version
def forward(self, key, value, query, mask=None, query_mask=None): def forward(self, key, value, query, mask=None, query_mask=None, dropout=0.1):
"""
Scaled Dot Product Attention.
Args:
key (Variable): Shape(B, T, C), dtype: float32. The input key of attention.
value (Variable): Shape(B, T, C), dtype: float32. The input value of attention.
query (Variable): Shape(B, T, C), dtype: float32. The input query of attention.
mask (Variable): Shape(B, len_q, len_k), dtype: float32. The mask of key.
query_mask (Variable): Shape(B, len_q, 1), dtype: float32. The mask of query.
dropout (Constant): dtype: float32. The probability of dropout.
Returns:
result (Variable), Shape(B, T, C), the result of scaled dot product attention.
attention (Variable), Shape(n_head * B, T, T), the attention of key.
"""
# Compute attention score # Compute attention score
attention = layers.matmul(query, key, transpose_y=True) #transpose the last dim in y attention = layers.matmul(query, key, transpose_y=True) #transpose the last dim in y
attention = attention / math.sqrt(self.d_key) attention = attention / math.sqrt(self.d_key)
# Mask key to ignore padding # Mask key to ignore padding
if mask is not None: if mask is not None:
attention = attention * (mask == 0).astype(np.float32) attention = attention * mask
mask = mask * (-2 ** 32 + 1) mask = (mask == 0).astype(np.float32) * (-2 ** 32 + 1)
attention = attention + mask attention = attention + mask
attention = layers.softmax(attention) attention = layers.softmax(attention)
attention = layers.dropout(attention, 0.0) attention = layers.dropout(attention, dropout)
# Mask query to ignore padding # Mask query to ignore padding
# Not sure how this works
if query_mask is not None: if query_mask is not None:
attention = attention * query_mask attention = attention * query_mask
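
The new convention treats the mask as multiplicative (1 = attend, 0 = ignore) and pushes masked scores toward -2^32 before the softmax; a numpy sketch of the same arithmetic:

import numpy as np

def masked_softmax(scores, mask):
    # mask: (B, len_q, len_k), 1 = attend, 0 = ignore
    scores = scores * mask + (mask == 0) * (-2.0 ** 32 + 1)
    e = np.exp(scores - scores.max(-1, keepdims=True))
    return e / e.sum(-1, keepdims=True)
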
@@ -52,6 +65,19 @@ class MultiheadAttention(dg.Layer):
self.layer_norm = dg.LayerNorm(num_hidden) self.layer_norm = dg.LayerNorm(num_hidden)
def forward(self, key, value, query_input, mask=None, query_mask=None): def forward(self, key, value, query_input, mask=None, query_mask=None):
"""
Multihead Attention.
Args:
key (Variable): Shape(B, T, C), dtype: float32. The input key of attention.
value (Variable): Shape(B, T, C), dtype: float32. The input value of attention.
query_input (Variable): Shape(B, T, C), dtype: float32. The input query of attention.
mask (Variable): Shape(B, len_q, len_k), dtype: float32. The mask of key.
query_mask (Variable): Shape(B, len_q, 1), dtype: float32. The mask of query.
Returns:
result (Variable), Shape(B, T, C), the result of multihead attention.
attention (Variable), Shape(n_head * B, T, T), the attention of key.
"""
batch_size = key.shape[0] batch_size = key.shape[0]
seq_len_key = key.shape[1] seq_len_key = key.shape[1]
seq_len_query = query_input.shape[1] seq_len_query = query_input.shape[1]
@@ -62,6 +88,7 @@ class MultiheadAttention(dg.Layer):
if mask is not None: if mask is not None:
mask = layers.expand(mask, (self.num_head, 1, 1)) mask = layers.expand(mask, (self.num_head, 1, 1))
# Make multihead attention # Make multihead attention
# key & value.shape = (batch_size, seq_len, feature)(feature = num_head * num_hidden_per_attn) # key & value.shape = (batch_size, seq_len, feature)(feature = num_head * num_hidden_per_attn)
key = layers.reshape(self.key(key), [batch_size, seq_len_key, self.num_head, self.d_k]) key = layers.reshape(self.key(key), [batch_size, seq_len_key, self.num_head, self.d_k])
@@ -71,6 +98,7 @@ class MultiheadAttention(dg.Layer):
key = layers.reshape(layers.transpose(key, [2, 0, 1, 3]), [-1, seq_len_key, self.d_k]) key = layers.reshape(layers.transpose(key, [2, 0, 1, 3]), [-1, seq_len_key, self.d_k])
value = layers.reshape(layers.transpose(value, [2, 0, 1, 3]), [-1, seq_len_key, self.d_k]) value = layers.reshape(layers.transpose(value, [2, 0, 1, 3]), [-1, seq_len_key, self.d_k])
query = layers.reshape(layers.transpose(query, [2, 0, 1, 3]), [-1, seq_len_query, self.d_q]) query = layers.reshape(layers.transpose(query, [2, 0, 1, 3]), [-1, seq_len_query, self.d_q])
result, attention = self.scal_attn(key, value, query, mask=mask, query_mask=query_mask) result, attention = self.scal_attn(key, value, query, mask=mask, query_mask=query_mask)
# concat all multihead result # concat all multihead result
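
The reshape/transpose pair above folds the head axis into the batch so one matmul serves all heads; a numpy round-trip check of the split and merge (sizes are examples):

import numpy as np

B, T, n_head, d_k = 2, 5, 4, 64
x = np.random.randn(B, T, n_head * d_k).astype("float32")
# split: (B, T, H*d) -> (B, T, H, d) -> (H, B, T, d) -> (H*B, T, d)
heads = x.reshape(B, T, n_head, d_k).transpose(2, 0, 1, 3).reshape(-1, T, d_k)
# merge after attention: invert the same moves
merged = heads.reshape(n_head, B, T, d_k).transpose(1, 2, 0, 3).reshape(B, T, -1)
assert np.allclose(x, merged)
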

View File

@@ -1,7 +1,7 @@
import paddle.fluid.dygraph as dg import paddle.fluid.dygraph as dg
import paddle.fluid as fluid import paddle.fluid as fluid
import paddle.fluid.layers as layers import paddle.fluid.layers as layers
from parakeet.modules.layers import Conv1D from parakeet.modules.layers import Conv
class PostConvNet(dg.Layer): class PostConvNet(dg.Layer):
def __init__(self, def __init__(self,
@@ -17,7 +17,7 @@ class PostConvNet(dg.Layer):
self.dropout = dropout self.dropout = dropout
self.conv_list = [] self.conv_list = []
self.conv_list.append(Conv1D(in_channels = n_mels * outputs_per_step, self.conv_list.append(Conv(in_channels = n_mels * outputs_per_step,
out_channels = num_hidden, out_channels = num_hidden,
filter_size = filter_size, filter_size = filter_size,
padding = padding, padding = padding,
@@ -25,14 +25,14 @@ class PostConvNet(dg.Layer):
data_format = "NCT")) data_format = "NCT"))
for _ in range(1, num_conv-1): for _ in range(1, num_conv-1):
self.conv_list.append(Conv1D(in_channels = num_hidden, self.conv_list.append(Conv(in_channels = num_hidden,
out_channels = num_hidden, out_channels = num_hidden,
filter_size = filter_size, filter_size = filter_size,
padding = padding, padding = padding,
use_cudnn = use_cudnn, use_cudnn = use_cudnn,
data_format = "NCT") ) data_format = "NCT") )
self.conv_list.append(Conv1D(in_channels = num_hidden, self.conv_list.append(Conv(in_channels = num_hidden,
out_channels = n_mels * outputs_per_step, out_channels = n_mels * outputs_per_step,
filter_size = filter_size, filter_size = filter_size,
padding = padding, padding = padding,
@@ -59,9 +59,17 @@ class PostConvNet(dg.Layer):
def forward(self, input): def forward(self, input):
"""
Post Conv Net.
Args:
input (Variable): Shape(B, T, C), dtype: float32. The input value.
Returns:
output (Variable), Shape(B, T, C), the result after postconvnet.
"""
input = layers.transpose(input, [0,2,1]) input = layers.transpose(input, [0,2,1])
len = input.shape[-1] len = input.shape[-1]
for batch_norm, conv in zip(self.batch_norm_list, self.conv_list): for batch_norm, conv in zip(self.batch_norm_list, self.conv_list):
input = layers.dropout(layers.tanh(batch_norm(conv(input)[:,:,:len])), self.dropout) input = layers.dropout(layers.tanh(batch_norm(conv(input)[:,:,:len])), self.dropout)
input = layers.transpose(input, [0,2,1]) output = layers.transpose(input, [0,2,1])
return input return output
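
The [:, :, :len] slice above exists because a 1-D conv whose padding exceeds (filter_size - 1) / 2 emits more frames than it consumes; quick arithmetic, assuming causal padding pad = filter_size - 1 (an assumption — the actual padding is a constructor argument):

T, filter_size = 100, 5
pad = filter_size - 1                       # assumed causal padding
out_len = T + 2 * pad - filter_size + 1     # standard 1-D conv length formula
print(out_len)                              # 104 -> sliced back to the first 100
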

View File

@@ -2,9 +2,6 @@ import paddle.fluid.dygraph as dg
import paddle.fluid.layers as layers import paddle.fluid.layers as layers
class PreNet(dg.Layer): class PreNet(dg.Layer):
"""
Pre Net before passing through the network
"""
def __init__(self, input_size, hidden_size, output_size, dropout_rate=0.2): def __init__(self, input_size, hidden_size, output_size, dropout_rate=0.2):
""" """
:param input_size: dimension of input :param input_size: dimension of input
@@ -21,6 +18,14 @@ class PreNet(dg.Layer):
self.linear2 = dg.Linear(hidden_size, output_size) self.linear2 = dg.Linear(hidden_size, output_size)
def forward(self, x): def forward(self, x):
"""
Pre Net before passing through the network.
Args:
x (Variable): Shape(B, T, C), dtype: float32. The input value.
Returns:
x (Variable), Shape(B, T, C), the result after prenet.
"""
x = layers.dropout(layers.relu(self.linear1(x)), self.dropout_rate) x = layers.dropout(layers.relu(self.linear1(x)), self.dropout_rate)
x = layers.dropout(layers.relu(self.linear2(x)), self.dropout_rate) x = layers.dropout(layers.relu(self.linear2(x)), self.dropout_rate)
return x return x