rebuild code and TransformerTTS is right. FastSpeech will later.
This commit is contained in:
parent
5b632d18af
commit
e0aef2e081
|
@ -9,7 +9,7 @@ class AudioProcessor(object):
|
|||
sample_rate=None, # int, sampling rate
|
||||
num_mels=None, # int, bands of mel spectrogram
|
||||
min_level_db=None, # float, minimum level db
|
||||
ref_level_db=None, # float, reference level dbn
|
||||
ref_level_db=None, # float, reference level db
|
||||
n_fft=None, # int: number of samples in a frame for stft
|
||||
win_length=None, # int: the same meaning with n_fft
|
||||
hop_length=None, # int: number of samples between neighboring frame
|
||||
|
@ -22,7 +22,7 @@ class AudioProcessor(object):
|
|||
mel_fmax=None, # int: mel spectrogram's maximum frequency
|
||||
clip_norm=True, # bool: clip spectrogram's norm
|
||||
griffin_lim_iters=None, # int:
|
||||
do_trim_silence=False, # bool: trim silience
|
||||
do_trim_silence=False, # bool: trim silence
|
||||
sound_norm=False,
|
||||
**kwargs):
|
||||
self.sample_rate = sample_rate
|
||||
|
|
|
@ -12,19 +12,19 @@ from parakeet.data.dataset import Dataset
|
|||
from parakeet.data.batch import TextIDBatcher, SpecBatcher
|
||||
|
||||
class LJSpeechLoader:
|
||||
def __init__(self, config, nranks, rank, is_vocoder=False):
|
||||
def __init__(self, config, nranks, rank, is_vocoder=False, shuffle=True):
|
||||
place = fluid.CUDAPlace(rank) if config.use_gpu else fluid.CPUPlace()
|
||||
|
||||
LJSPEECH_ROOT = Path(config.data_path)
|
||||
dataset = LJSpeech(LJSPEECH_ROOT, config)
|
||||
sampler = DistributedSampler(len(dataset), nranks, rank)
|
||||
sampler = DistributedSampler(len(dataset), nranks, rank, shuffle=shuffle)
|
||||
|
||||
assert config.batch_size % nranks == 0
|
||||
each_bs = config.batch_size // nranks
|
||||
if is_vocoder:
|
||||
dataloader = DataCargo(dataset, sampler=sampler, batch_size=each_bs, shuffle=True, collate_fn=batch_examples_vocoder, drop_last=True)
|
||||
dataloader = DataCargo(dataset, sampler=sampler, batch_size=each_bs, shuffle=shuffle, collate_fn=batch_examples_vocoder, drop_last=True)
|
||||
else:
|
||||
dataloader = DataCargo(dataset, sampler=sampler, batch_size=each_bs, shuffle=True, collate_fn=batch_examples, drop_last=True)
|
||||
dataloader = DataCargo(dataset, sampler=sampler, batch_size=each_bs, shuffle=shuffle, collate_fn=batch_examples, drop_last=True)
|
||||
|
||||
self.reader = fluid.io.DataLoader.from_generator(
|
||||
capacity=32,
|
||||
|
@ -41,6 +41,25 @@ class LJSpeech(Dataset):
|
|||
self.root = root if isinstance(root, Path) else Path(root)
|
||||
self.metadata = self._prepare_metadata()
|
||||
self.config = config
|
||||
self._ljspeech_processor = audio.AudioProcessor(
|
||||
sample_rate=config.audio.sr,
|
||||
num_mels=config.audio.num_mels,
|
||||
min_level_db=config.audio.min_level_db,
|
||||
ref_level_db=config.audio.ref_level_db,
|
||||
n_fft=config.audio.n_fft,
|
||||
win_length= config.audio.win_length,
|
||||
hop_length= config.audio.hop_length,
|
||||
power=config.audio.power,
|
||||
preemphasis=config.audio.preemphasis,
|
||||
signal_norm=True,
|
||||
symmetric_norm=False,
|
||||
max_norm=1.,
|
||||
mel_fmin=0,
|
||||
mel_fmax=None,
|
||||
clip_norm=True,
|
||||
griffin_lim_iters=60,
|
||||
do_trim_silence=False,
|
||||
sound_norm=False)
|
||||
|
||||
def _prepare_metadata(self):
|
||||
csv_path = self.root.joinpath("metadata.csv")
|
||||
|
@ -59,29 +78,10 @@ class LJSpeech(Dataset):
|
|||
fname, raw_text, normalized_text = metadatum
|
||||
wav_path = self.root.joinpath("wavs", fname + ".wav")
|
||||
|
||||
_ljspeech_processor = audio.AudioProcessor(
|
||||
sample_rate=22050,
|
||||
num_mels=80,
|
||||
min_level_db=-100,
|
||||
ref_level_db=20,
|
||||
n_fft=2048,
|
||||
win_length= int(22050 * 0.05),
|
||||
hop_length= int(22050 * 0.0125),
|
||||
power=1.2,
|
||||
preemphasis=0.97,
|
||||
signal_norm=True,
|
||||
symmetric_norm=False,
|
||||
max_norm=1.,
|
||||
mel_fmin=0,
|
||||
mel_fmax=None,
|
||||
clip_norm=True,
|
||||
griffin_lim_iters=60,
|
||||
do_trim_silence=False,
|
||||
sound_norm=False)
|
||||
# load -> trim -> preemphasis -> stft -> magnitude -> mel_scale -> logscale -> normalize
|
||||
wav = _ljspeech_processor.load_wav(str(wav_path))
|
||||
mag = _ljspeech_processor.spectrogram(wav).astype(np.float32)
|
||||
mel = _ljspeech_processor.melspectrogram(wav).astype(np.float32)
|
||||
wav = self._ljspeech_processor.load_wav(str(wav_path))
|
||||
mag = self._ljspeech_processor.spectrogram(wav).astype(np.float32)
|
||||
mel = self._ljspeech_processor.melspectrogram(wav).astype(np.float32)
|
||||
phonemes = np.array(g2p.en.text_to_sequence(normalized_text), dtype=np.int64)
|
||||
return (mag, mel, phonemes) # maybe we need to implement it as a map in the future
|
||||
|
||||
|
@ -123,11 +123,11 @@ def batch_examples(batch):
|
|||
text_lens = sorted(text_lens, reverse=True)
|
||||
|
||||
# Pad sequence with largest len of the batch
|
||||
texts = TextIDBatcher(pad_id=0)(texts)
|
||||
pos_texts = TextIDBatcher(pad_id=0)(pos_texts)
|
||||
pos_mels = TextIDBatcher(pad_id=0)(pos_mels)
|
||||
mels = np.transpose(SpecBatcher(pad_value=0.)(mels), axes=(0,2,1))
|
||||
mel_inputs = np.transpose(SpecBatcher(pad_value=0.)(mel_inputs), axes=(0,2,1))
|
||||
texts = TextIDBatcher(pad_id=0)(texts) #(B, T)
|
||||
pos_texts = TextIDBatcher(pad_id=0)(pos_texts) #(B,T)
|
||||
pos_mels = TextIDBatcher(pad_id=0)(pos_mels) #(B,T)
|
||||
mels = np.transpose(SpecBatcher(pad_value=0.)(mels), axes=(0,2,1)) #(B,T,num_mels)
|
||||
mel_inputs = np.transpose(SpecBatcher(pad_value=0.)(mel_inputs), axes=(0,2,1))#(B,T,num_mels)
|
||||
return (texts, mels, mel_inputs, pos_texts, pos_mels, np.array(text_lens))
|
||||
|
||||
def batch_examples_vocoder(batch):
|
||||
|
|
|
@ -1,124 +0,0 @@
|
|||
import torch
|
||||
from torch.nn import functional as F
|
||||
from torch.utils.data import Dataset, DataLoader
|
||||
|
||||
import numpy as np
|
||||
import math
|
||||
import os
|
||||
|
||||
import hparams
|
||||
import Audio
|
||||
from text import text_to_sequence
|
||||
from utils import process_text, pad_1D, pad_2D
|
||||
|
||||
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||
|
||||
|
||||
class FastSpeechDataset(Dataset):
|
||||
""" LJSpeech """
|
||||
|
||||
def __init__(self):
|
||||
self.text = process_text(os.path.join("data", "train.txt"))
|
||||
|
||||
def __len__(self):
|
||||
return len(self.text)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
mel_gt_name = os.path.join(
|
||||
hparams.mel_ground_truth, "ljspeech-mel-%05d.npy" % (idx+1))
|
||||
mel_gt_target = np.load(mel_gt_name)
|
||||
D = np.load(os.path.join(hparams.alignment_path, str(idx)+".npy"))
|
||||
|
||||
character = self.text[idx][0:len(self.text[idx])-1]
|
||||
character = np.array(text_to_sequence(
|
||||
character, hparams.text_cleaners))
|
||||
|
||||
sample = {"text": character,
|
||||
"mel_target": mel_gt_target,
|
||||
"D": D}
|
||||
|
||||
return sample
|
||||
|
||||
|
||||
def reprocess(batch, cut_list):
|
||||
texts = [batch[ind]["text"] for ind in cut_list]
|
||||
mel_targets = [batch[ind]["mel_target"] for ind in cut_list]
|
||||
Ds = [batch[ind]["D"] for ind in cut_list]
|
||||
|
||||
length_text = np.array([])
|
||||
for text in texts:
|
||||
length_text = np.append(length_text, text.shape[0])
|
||||
|
||||
src_pos = list()
|
||||
max_len = int(max(length_text))
|
||||
for length_src_row in length_text:
|
||||
src_pos.append(np.pad([i+1 for i in range(int(length_src_row))],
|
||||
(0, max_len-int(length_src_row)), 'constant'))
|
||||
src_pos = np.array(src_pos)
|
||||
|
||||
length_mel = np.array(list())
|
||||
for mel in mel_targets:
|
||||
length_mel = np.append(length_mel, mel.shape[0])
|
||||
|
||||
mel_pos = list()
|
||||
max_mel_len = int(max(length_mel))
|
||||
for length_mel_row in length_mel:
|
||||
mel_pos.append(np.pad([i+1 for i in range(int(length_mel_row))],
|
||||
(0, max_mel_len-int(length_mel_row)), 'constant'))
|
||||
mel_pos = np.array(mel_pos)
|
||||
|
||||
texts = pad_1D(texts)
|
||||
Ds = pad_1D(Ds)
|
||||
mel_targets = pad_2D(mel_targets)
|
||||
|
||||
out = {"text": texts,
|
||||
"mel_target": mel_targets,
|
||||
"D": Ds,
|
||||
"mel_pos": mel_pos,
|
||||
"src_pos": src_pos,
|
||||
"mel_max_len": max_mel_len}
|
||||
|
||||
return out
|
||||
|
||||
|
||||
def collate_fn(batch):
|
||||
len_arr = np.array([d["text"].shape[0] for d in batch])
|
||||
index_arr = np.argsort(-len_arr)
|
||||
batchsize = len(batch)
|
||||
real_batchsize = int(math.sqrt(batchsize))
|
||||
|
||||
cut_list = list()
|
||||
for i in range(real_batchsize):
|
||||
cut_list.append(index_arr[i*real_batchsize:(i+1)*real_batchsize])
|
||||
|
||||
output = list()
|
||||
for i in range(real_batchsize):
|
||||
output.append(reprocess(batch, cut_list[i]))
|
||||
|
||||
return output
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Test
|
||||
dataset = FastSpeechDataset()
|
||||
training_loader = DataLoader(dataset,
|
||||
batch_size=1,
|
||||
shuffle=False,
|
||||
collate_fn=collate_fn,
|
||||
drop_last=True,
|
||||
num_workers=0)
|
||||
total_step = hparams.epochs * len(training_loader) * hparams.batch_size
|
||||
|
||||
cnt = 0
|
||||
for i, batchs in enumerate(training_loader):
|
||||
for j, data_of_batch in enumerate(batchs):
|
||||
mel_target = torch.from_numpy(
|
||||
data_of_batch["mel_target"]).float().to(device)
|
||||
D = torch.from_numpy(data_of_batch["D"]).int().to(device)
|
||||
# print(mel_target.size())
|
||||
# print(D.sum())
|
||||
print(cnt)
|
||||
if mel_target.size(1) == D.sum().item():
|
||||
cnt += 1
|
||||
|
||||
print(cnt)
|
|
@ -11,20 +11,33 @@ from parakeet.modules.feed_forward import PositionwiseFeedForward
|
|||
|
||||
|
||||
class FFTBlock(dg.Layer):
|
||||
"""FFT Block"""
|
||||
def __init__(self, d_model, d_inner, n_head, d_k, d_v, filter_size, padding, dropout=0.2):
|
||||
super(FFTBlock, self).__init__()
|
||||
self.slf_attn = MultiheadAttention(d_model, d_k, d_v, num_head=n_head, dropout=dropout)
|
||||
self.pos_ffn = PositionwiseFeedForward(d_model, d_inner, filter_size =filter_size, padding =padding, dropout=dropout)
|
||||
|
||||
def forward(self, enc_input, non_pad_mask=None, slf_attn_mask=None):
|
||||
enc_output, enc_slf_attn = self.slf_attn(enc_input, enc_input, enc_input, mask=slf_attn_mask)
|
||||
enc_output *= non_pad_mask
|
||||
"""
|
||||
Feed Forward Transformer block in FastSpeech.
|
||||
|
||||
Args:
|
||||
enc_input (Variable): Shape(B, T, C), dtype: float32. The embedding characters input.
|
||||
T means the timesteps of input.
|
||||
non_pad_mask (Variable): Shape(B, T, 1), dtype: int64. The mask of sequence.
|
||||
slf_attn_mask (Variable): Shape(B, len_q, len_k), dtype: int64. The mask of self attention.
|
||||
len_q means the sequence length of query, len_k means the sequence length of key.
|
||||
|
||||
enc_output = self.pos_ffn(enc_output)
|
||||
enc_output *= non_pad_mask
|
||||
Returns:
|
||||
output (Variable), Shape(B, T, C), the output after self-attention & ffn.
|
||||
slf_attn (Variable), Shape(B * n_head, T, T), the self attention.
|
||||
"""
|
||||
output, slf_attn = self.slf_attn(enc_input, enc_input, enc_input, mask=slf_attn_mask)
|
||||
output *= non_pad_mask
|
||||
|
||||
return enc_output, enc_slf_attn
|
||||
output = self.pos_ffn(output)
|
||||
output *= non_pad_mask
|
||||
|
||||
return output, slf_attn
|
||||
|
||||
|
||||
class LengthRegulator(dg.Layer):
|
||||
|
@ -70,6 +83,20 @@ class LengthRegulator(dg.Layer):
|
|||
|
||||
|
||||
def forward(self, x, alpha=1.0, target=None):
|
||||
"""
|
||||
Length Regulator block in FastSpeech.
|
||||
|
||||
Args:
|
||||
x (Variable): Shape(B, T, C), dtype: float32. The encoder output.
|
||||
alpha (Constant): dtype: float32. The hyperparameter to determine the length of
|
||||
the expanded sequence mel, thereby controlling the voice speed.
|
||||
target (Variable): (Variable, optional): Shape(B, T_text),
|
||||
dtype: int64. The duration of phoneme compute from pretrained transformerTTS.
|
||||
|
||||
Returns:
|
||||
output (Variable), Shape(B, T, C), the output after exppand.
|
||||
duration_predictor_output (Variable), Shape(B, T, C), the output of duration predictor.
|
||||
"""
|
||||
duration_predictor_output = self.duration_predictor(x)
|
||||
if fluid.framework._dygraph_tracer()._train_mode:
|
||||
output = self.LR(x, target)
|
||||
|
@ -81,7 +108,6 @@ class LengthRegulator(dg.Layer):
|
|||
return output, mel_pos
|
||||
|
||||
class DurationPredictor(dg.Layer):
|
||||
""" Duration Predictor """
|
||||
def __init__(self, input_size, out_channels, filter_size, dropout=0.1):
|
||||
super(DurationPredictor, self).__init__()
|
||||
self.input_size = input_size
|
||||
|
@ -105,7 +131,14 @@ class DurationPredictor(dg.Layer):
|
|||
self.linear =dg.Linear(self.out_channels, 1)
|
||||
|
||||
def forward(self, encoder_output):
|
||||
"""
|
||||
Duration Predictor block in FastSpeech.
|
||||
|
||||
Args:
|
||||
encoder_output (Variable): Shape(B, T, C), dtype: float32. The encoder output.
|
||||
Returns:
|
||||
out (Variable), Shape(B, T, C), the output of duration predictor.
|
||||
"""
|
||||
# encoder_output.shape(N, T, C)
|
||||
out = layers.dropout(layers.relu(self.layer_norm1(self.conv1(encoder_output))), self.dropout)
|
||||
out = layers.dropout(layers.relu(self.layer_norm2(self.conv2(out))), self.dropout)
|
||||
|
|
|
@ -35,6 +35,20 @@ class Encoder(dg.Layer):
|
|||
self.add_sublayer('fft_{}'.format(i), layer)
|
||||
|
||||
def forward(self, character, text_pos):
|
||||
"""
|
||||
Encoder layer of FastSpeech.
|
||||
|
||||
Args:
|
||||
character (Variable): Shape(B, T_text), dtype: float32. The input text
|
||||
characters. T_text means the timesteps of input characters.
|
||||
text_pos (Variable): Shape(B, T_text), dtype: int64. The input text
|
||||
position. T_text means the timesteps of input characters.
|
||||
|
||||
Returns:
|
||||
enc_output (Variable), Shape(B, text_T, C), the encoder output.
|
||||
non_pad_mask (Variable), Shape(B, T_text, 1), the mask with non pad.
|
||||
enc_slf_attn_list (list<Variable>), Len(n_layers), Shape(B * n_head, text_T, text_T), the encoder self attention list.
|
||||
"""
|
||||
enc_slf_attn_list = []
|
||||
# -- prepare masks
|
||||
# shape character (N, T)
|
||||
|
@ -80,6 +94,18 @@ class Decoder(dg.Layer):
|
|||
self.add_sublayer('fft_{}'.format(i), layer)
|
||||
|
||||
def forward(self, enc_seq, enc_pos):
|
||||
"""
|
||||
Decoder layer of FastSpeech.
|
||||
|
||||
Args:
|
||||
enc_seq (Variable), Shape(B, text_T, C), dtype: float32.
|
||||
The output of length regulator.
|
||||
enc_pos (Variable, optional): Shape(B, T_mel),
|
||||
dtype: int64. The spectrum position. T_mel means the timesteps of input spectrum.
|
||||
Returns:
|
||||
dec_output (Variable), Shape(B, mel_T, C), the decoder output.
|
||||
dec_slf_attn_list (Variable), Shape(B, mel_T, mel_T), the decoder self attention list.
|
||||
"""
|
||||
dec_slf_attn_list = []
|
||||
|
||||
# -- Prepare masks
|
||||
|
@ -141,6 +167,31 @@ class FastSpeech(dg.Layer):
|
|||
dropout=0.1)
|
||||
|
||||
def forward(self, character, text_pos, mel_pos=None, length_target=None, alpha=1.0):
|
||||
"""
|
||||
FastSpeech model.
|
||||
|
||||
Args:
|
||||
character (Variable): Shape(B, T_text), dtype: float32. The input text
|
||||
characters. T_text means the timesteps of input characters.
|
||||
text_pos (Variable): Shape(B, T_text), dtype: int64. The input text
|
||||
position. T_text means the timesteps of input characters.
|
||||
mel_pos (Variable, optional): Shape(B, T_mel),
|
||||
dtype: int64. The spectrum position. T_mel means the timesteps of input spectrum.
|
||||
length_target (Variable, optional): Shape(B, T_text),
|
||||
dtype: int64. The duration of phoneme compute from pretrained transformerTTS.
|
||||
alpha (Constant):
|
||||
dtype: float32. The hyperparameter to determine the length of the expanded sequence
|
||||
mel, thereby controlling the voice speed.
|
||||
|
||||
Returns:
|
||||
mel_output (Variable), Shape(B, mel_T, C), the mel output before postnet.
|
||||
mel_output_postnet (Variable), Shape(B, mel_T, C), the mel output after postnet.
|
||||
duration_predictor_output (Variable), Shape(B, text_T), the duration of phoneme compute
|
||||
with duration predictor.
|
||||
enc_slf_attn_list (Variable), Shape(B, text_T, text_T), the encoder self attention list.
|
||||
dec_slf_attn_list (Variable), Shape(B, mel_T, mel_T), the decoder self attention list.
|
||||
"""
|
||||
|
||||
encoder_output, non_pad_mask, enc_slf_attn_list = self.encoder(character, text_pos)
|
||||
if fluid.framework._dygraph_tracer()._train_mode:
|
||||
|
||||
|
|
|
@ -9,9 +9,9 @@ def add_config_options_to_parser(parser):
|
|||
help="the sampling rate of audio data file.")
|
||||
parser.add_argument('--audio.preemphasis', type=float, default=0.97,
|
||||
help="the preemphasis coefficient.")
|
||||
parser.add_argument('--audio.hop_length', type=float, default=128,
|
||||
parser.add_argument('--audio.hop_length', type=int, default=128,
|
||||
help="the number of samples to advance between frames.")
|
||||
parser.add_argument('--audio.win_length', type=float, default=1024,
|
||||
parser.add_argument('--audio.win_length', type=int, default=1024,
|
||||
help="the length (width) of the window function.")
|
||||
parser.add_argument('--audio.power', type=float, default=1.4,
|
||||
help="the power to raise before griffin-lim.")
|
||||
|
|
|
@ -66,8 +66,8 @@ def main(cfg):
|
|||
|
||||
model = FastSpeech(cfg)
|
||||
model.train()
|
||||
optimizer = fluid.optimizer.AdamOptimizer(learning_rate=dg.NoamDecay(1/(cfg.warm_up_step *( cfg.lr ** 2)), cfg.warm_up_step))
|
||||
|
||||
optimizer = fluid.optimizer.AdamOptimizer(learning_rate=dg.NoamDecay(1/(cfg.warm_up_step *( cfg.lr ** 2)), cfg.warm_up_step),
|
||||
parameter_list=model.parameters())
|
||||
reader = LJSpeechLoader(cfg, nranks, local_rank).reader()
|
||||
|
||||
if cfg.checkpoint_path is not None:
|
||||
|
|
|
@ -13,7 +13,8 @@ audio:
|
|||
hidden_size: 256
|
||||
embedding_size: 512
|
||||
|
||||
|
||||
warm_up_step: 4000
|
||||
grad_clip_thresh: 1.0
|
||||
batch_size: 32
|
||||
epochs: 10000
|
||||
lr: 0.001
|
||||
|
|
|
@ -11,22 +11,23 @@ audio:
|
|||
outputs_per_step: 1
|
||||
|
||||
|
||||
hidden_size: 384 #256
|
||||
embedding_size: 384 #512
|
||||
hidden_size: 256
|
||||
embedding_size: 512
|
||||
|
||||
|
||||
warm_up_step: 4000
|
||||
grad_clip_thresh: 1.0
|
||||
batch_size: 32
|
||||
epochs: 10000
|
||||
lr: 0.001
|
||||
save_step: 10
|
||||
save_step: 1000
|
||||
image_step: 2000
|
||||
use_gpu: True
|
||||
use_data_parallel: True
|
||||
use_data_parallel: False
|
||||
|
||||
data_path: ../../../dataset/LJSpeech-1.1
|
||||
save_path: ./checkpoint
|
||||
log_dir: ./log
|
||||
|
||||
|
||||
#checkpoint_path: ./checkpoint/transformer/1
|
||||
|
||||
|
|
@ -1,29 +0,0 @@
|
|||
from pathlib import Path
|
||||
import numpy as np
|
||||
from paddle import fluid
|
||||
from parakeet.data.sampler import DistributedSampler
|
||||
from parakeet.data.datacargo import DataCargo
|
||||
from preprocess import batch_examples, LJSpeech, batch_examples_vocoder
|
||||
|
||||
class LJSpeechLoader:
|
||||
def __init__(self, config, nranks, rank, is_vocoder=False):
|
||||
place = fluid.CUDAPlace(rank) if config.use_gpu else fluid.CPUPlace()
|
||||
|
||||
LJSPEECH_ROOT = Path(config.data_path)
|
||||
dataset = LJSpeech(LJSPEECH_ROOT)
|
||||
sampler = DistributedSampler(len(dataset), nranks, rank)
|
||||
|
||||
assert config.batch_size % nranks == 0
|
||||
each_bs = config.batch_size // nranks
|
||||
if is_vocoder:
|
||||
dataloader = DataCargo(dataset, sampler=sampler, batch_size=each_bs, shuffle=True, collate_fn=batch_examples_vocoder, drop_last=True)
|
||||
else:
|
||||
dataloader = DataCargo(dataset, sampler=sampler, batch_size=each_bs, shuffle=True, collate_fn=batch_examples, drop_last=True)
|
||||
|
||||
self.reader = fluid.io.DataLoader.from_generator(
|
||||
capacity=32,
|
||||
iterable=True,
|
||||
use_double_buffer=True,
|
||||
return_list=True)
|
||||
self.reader.set_batch_generator(dataloader, place)
|
||||
|
|
@ -3,11 +3,12 @@ from parakeet.g2p.text.symbols import symbols
|
|||
import paddle.fluid.dygraph as dg
|
||||
import paddle.fluid as fluid
|
||||
import paddle.fluid.layers as layers
|
||||
from parakeet.modules.layers import Conv1D, Pool1D
|
||||
from parakeet.modules.layers import Conv, Pool1D
|
||||
from parakeet.modules.dynamicGRU import DynamicGRU
|
||||
import numpy as np
|
||||
|
||||
|
||||
|
||||
class EncoderPrenet(dg.Layer):
|
||||
def __init__(self, embedding_size, num_hidden, use_cudnn=True):
|
||||
super(EncoderPrenet, self).__init__()
|
||||
|
@ -18,19 +19,19 @@ class EncoderPrenet(dg.Layer):
|
|||
param_attr = fluid.ParamAttr(name='weight'),
|
||||
padding_idx = None)
|
||||
self.conv_list = []
|
||||
self.conv_list.append(Conv1D(in_channels = embedding_size,
|
||||
self.conv_list.append(Conv(in_channels = embedding_size,
|
||||
out_channels = num_hidden,
|
||||
filter_size = 5,
|
||||
padding = int(np.floor(5/2)),
|
||||
use_cudnn = use_cudnn,
|
||||
data_format = "NCT"))
|
||||
for _ in range(2):
|
||||
self.conv_list = Conv1D(in_channels = num_hidden,
|
||||
self.conv_list.append(Conv(in_channels = num_hidden,
|
||||
out_channels = num_hidden,
|
||||
filter_size = 5,
|
||||
padding = int(np.floor(5/2)),
|
||||
use_cudnn = use_cudnn,
|
||||
data_format = "NCT")
|
||||
data_format = "NCT"))
|
||||
|
||||
for i, layer in enumerate(self.conv_list):
|
||||
self.add_sublayer("conv_list_{}".format(i), layer)
|
||||
|
@ -71,13 +72,13 @@ class CBHG(dg.Layer):
|
|||
self.hidden_size = hidden_size
|
||||
self.projection_size = projection_size
|
||||
self.conv_list = []
|
||||
self.conv_list.append(Conv1D(in_channels = projection_size,
|
||||
self.conv_list.append(Conv(in_channels = projection_size,
|
||||
out_channels = hidden_size,
|
||||
filter_size = 1,
|
||||
padding = int(np.floor(1/2)),
|
||||
data_format = "NCT"))
|
||||
for i in range(2,K+1):
|
||||
self.conv_list.append(Conv1D(in_channels = hidden_size,
|
||||
self.conv_list.append(Conv(in_channels = hidden_size,
|
||||
out_channels = hidden_size,
|
||||
filter_size = i,
|
||||
padding = int(np.floor(i/2)),
|
||||
|
@ -100,13 +101,13 @@ class CBHG(dg.Layer):
|
|||
|
||||
conv_outdim = hidden_size * K
|
||||
|
||||
self.conv_projection_1 = Conv1D(in_channels = conv_outdim,
|
||||
self.conv_projection_1 = Conv(in_channels = conv_outdim,
|
||||
out_channels = hidden_size,
|
||||
filter_size = 3,
|
||||
padding = int(np.floor(3/2)),
|
||||
data_format = "NCT")
|
||||
|
||||
self.conv_projection_2 = Conv1D(in_channels = hidden_size,
|
||||
self.conv_projection_2 = Conv(in_channels = hidden_size,
|
||||
out_channels = projection_size,
|
||||
filter_size = 3,
|
||||
padding = int(np.floor(3/2)),
|
||||
|
|
|
@ -20,13 +20,12 @@ class Encoder(dg.Layer):
|
|||
self.pos_emb = dg.Embedding(size=[1024, num_hidden],
|
||||
padding_idx=0,
|
||||
param_attr=fluid.ParamAttr(
|
||||
name='weight',
|
||||
initializer=fluid.initializer.NumpyArrayInitializer(self.pos_inp),
|
||||
trainable=False))
|
||||
self.encoder_prenet = EncoderPrenet(embedding_size = embedding_size,
|
||||
num_hidden = num_hidden,
|
||||
use_cudnn=config.use_gpu)
|
||||
self.layers = [MultiheadAttention(num_hidden, num_hidden, num_hidden) for _ in range(3)]
|
||||
self.layers = [MultiheadAttention(num_hidden, num_hidden//4, num_hidden//4) for _ in range(3)]
|
||||
for i, layer in enumerate(self.layers):
|
||||
self.add_sublayer("self_attn_{}".format(i), layer)
|
||||
self.ffns = [PositionwiseFeedForward(num_hidden, num_hidden*4, filter_size=1, use_cudnn = config.use_gpu) for _ in range(3)]
|
||||
|
@ -40,6 +39,7 @@ class Encoder(dg.Layer):
|
|||
else:
|
||||
query_mask, mask = None, None
|
||||
|
||||
|
||||
# Encoder pre_network
|
||||
x = self.encoder_prenet(x) #(N,T,C)
|
||||
|
||||
|
@ -81,10 +81,10 @@ class Decoder(dg.Layer):
|
|||
dropout_rate=0.2)
|
||||
self.linear = dg.Linear(num_hidden, num_hidden)
|
||||
|
||||
self.selfattn_layers = [MultiheadAttention(num_hidden, num_hidden, num_hidden) for _ in range(3)]
|
||||
self.selfattn_layers = [MultiheadAttention(num_hidden, num_hidden//4, num_hidden//4) for _ in range(3)]
|
||||
for i, layer in enumerate(self.selfattn_layers):
|
||||
self.add_sublayer("self_attn_{}".format(i), layer)
|
||||
self.attn_layers = [MultiheadAttention(num_hidden, num_hidden, num_hidden) for _ in range(3)]
|
||||
self.attn_layers = [MultiheadAttention(num_hidden, num_hidden//4, num_hidden//4) for _ in range(3)]
|
||||
for i, layer in enumerate(self.attn_layers):
|
||||
self.add_sublayer("attn_{}".format(i), layer)
|
||||
self.ffns = [PositionwiseFeedForward(num_hidden, num_hidden*4, filter_size=1) for _ in range(3)]
|
||||
|
@ -104,18 +104,18 @@ class Decoder(dg.Layer):
|
|||
|
||||
if fluid.framework._dygraph_tracer()._train_mode:
|
||||
m_mask = get_non_pad_mask(positional)
|
||||
mask = get_attn_key_pad_mask(positional, query)
|
||||
mask = get_attn_key_pad_mask((positional==0).astype(np.float32), query)
|
||||
triu_tensor = dg.to_variable(get_triu_tensor(query.numpy(), query.numpy())).astype(np.float32)
|
||||
mask = mask + triu_tensor
|
||||
mask = fluid.layers.cast(mask != 0, np.float32)
|
||||
mask = fluid.layers.cast(mask == 0, np.float32)
|
||||
|
||||
|
||||
# (batch_size, decoder_len, encoder_len)
|
||||
zero_mask = get_attn_key_pad_mask(layers.squeeze(c_mask,[-1]), query)
|
||||
else:
|
||||
mask = get_triu_tensor(query.numpy(), query.numpy()).astype(np.float32)
|
||||
mask = fluid.layers.cast(dg.to_variable(mask != 0), np.float32)
|
||||
mask = fluid.layers.cast(dg.to_variable(mask == 0), np.float32)
|
||||
m_mask, zero_mask = None, None
|
||||
|
||||
# Decoder pre-network
|
||||
query = self.decoder_prenet(query)
|
||||
|
||||
|
@ -164,6 +164,7 @@ class TransformerTTS(dg.Layer):
|
|||
# key (batch_size, seq_len, channel)
|
||||
# c_mask (batch_size, seq_len)
|
||||
# attns_enc (channel / 2, seq_len, seq_len)
|
||||
|
||||
key, c_mask, attns_enc = self.encoder(characters, pos_text)
|
||||
|
||||
# mel_output/postnet_output (batch_size, mel_len, n_mel)
|
||||
|
|
|
@ -9,9 +9,9 @@ def add_config_options_to_parser(parser):
|
|||
help="the sampling rate of audio data file.")
|
||||
parser.add_argument('--audio.preemphasis', type=float, default=0.97,
|
||||
help="the preemphasis coefficient.")
|
||||
parser.add_argument('--audio.hop_length', type=float, default=128,
|
||||
parser.add_argument('--audio.hop_length', type=int, default=128,
|
||||
help="the number of samples to advance between frames.")
|
||||
parser.add_argument('--audio.win_length', type=float, default=1024,
|
||||
parser.add_argument('--audio.win_length', type=int, default=1024,
|
||||
help="the length (width) of the window function.")
|
||||
parser.add_argument('--audio.power', type=float, default=1.4,
|
||||
help="the power to raise before griffin-lim.")
|
||||
|
@ -27,6 +27,10 @@ def add_config_options_to_parser(parser):
|
|||
parser.add_argument('--embedding_size', type=int, default=512,
|
||||
help="the embedding vector size.")
|
||||
|
||||
parser.add_argument('--warm_up_step', type=int, default=4000,
|
||||
help="the warm up step of learning rate.")
|
||||
parser.add_argument('--grad_clip_thresh', type=float, default=1.0,
|
||||
help="the threshold of grad clip.")
|
||||
parser.add_argument('--batch_size', type=int, default=32,
|
||||
help="batch size for training.")
|
||||
parser.add_argument('--epochs', type=int, default=10000,
|
||||
|
|
|
@ -6,7 +6,7 @@ from pathlib import Path
|
|||
import jsonargparse
|
||||
from parse import add_config_options_to_parser
|
||||
from pprint import pprint
|
||||
from data import LJSpeechLoader
|
||||
from parakeet.models.dataloader.jlspeech import LJSpeechLoader
|
||||
|
||||
class MyDataParallel(dg.parallel.DataParallel):
|
||||
"""
|
||||
|
@ -50,7 +50,9 @@ def main(cfg):
|
|||
model = ModelPostNet(cfg)
|
||||
|
||||
model.train()
|
||||
optimizer = fluid.optimizer.AdamOptimizer(learning_rate=dg.NoamDecay(1/(4000 *( cfg.lr ** 2)), 4000))
|
||||
optimizer = fluid.optimizer.AdamOptimizer(learning_rate=dg.NoamDecay(1/(cfg.warm_up_step *( cfg.lr ** 2)), cfg.warm_up_step),
|
||||
parameter_list=model.parameters())
|
||||
|
||||
|
||||
if cfg.checkpoint_path is not None:
|
||||
model_dict, opti_dict = fluid.dygraph.load_dygraph(cfg.checkpoint_path)
|
||||
|
@ -75,13 +77,16 @@ def main(cfg):
|
|||
|
||||
mag_pred = model(mel)
|
||||
loss = layers.mean(layers.abs(layers.elementwise_sub(mag_pred, mag)))
|
||||
|
||||
if cfg.use_data_parallel:
|
||||
loss = model.scale_loss(loss)
|
||||
loss.backward()
|
||||
model.apply_collective_grads()
|
||||
else:
|
||||
loss.backward()
|
||||
optimizer.minimize(loss, grad_clip = fluid.dygraph_grad_clip.GradClipByGlobalNorm(1))
|
||||
optimizer.minimize(loss, grad_clip = fluid.dygraph_grad_clip.GradClipByGlobalNorm(cfg.grad_clip_thresh))
|
||||
print("===============",model.pre_proj.conv.weight.numpy())
|
||||
print("===============",model.pre_proj.conv.weight.gradient())
|
||||
model.clear_gradients()
|
||||
|
||||
if local_rank==0:
|
||||
|
|
|
@ -34,6 +34,9 @@ def main(cfg):
|
|||
local_rank = dg.parallel.Env().local_rank if cfg.use_data_parallel else 0
|
||||
nranks = dg.parallel.Env().nranks if cfg.use_data_parallel else 1
|
||||
|
||||
fluid.default_startup_program().random_seed = 1
|
||||
fluid.default_main_program().random_seed = 1
|
||||
|
||||
if local_rank == 0:
|
||||
# Print the whole config setting.
|
||||
pprint(jsonargparse.namespace_to_dict(cfg))
|
||||
|
@ -53,7 +56,8 @@ def main(cfg):
|
|||
model = TransformerTTS(cfg)
|
||||
|
||||
model.train()
|
||||
optimizer = fluid.optimizer.AdamOptimizer(learning_rate=dg.NoamDecay(1/(4000 *( cfg.lr ** 2)), 4000))
|
||||
optimizer = fluid.optimizer.AdamOptimizer(learning_rate=dg.NoamDecay(1/(cfg.warm_up_step *( cfg.lr ** 2)), cfg.warm_up_step),
|
||||
parameter_list=model.parameters())
|
||||
|
||||
reader = LJSpeechLoader(cfg, nranks, local_rank).reader()
|
||||
|
||||
|
@ -69,6 +73,8 @@ def main(cfg):
|
|||
|
||||
for epoch in range(cfg.epochs):
|
||||
pbar = tqdm(reader)
|
||||
|
||||
|
||||
for i, data in enumerate(pbar):
|
||||
pbar.set_description('Processing at epoch %d'%epoch)
|
||||
character, mel, mel_input, pos_text, pos_mel, text_length = data
|
||||
|
@ -86,7 +92,7 @@ def main(cfg):
|
|||
post_mel_loss = layers.mean(layers.abs(layers.elementwise_sub(postnet_pred, mel)))
|
||||
stop_loss = cross_entropy(stop_preds, dg.to_variable(label))
|
||||
loss = mel_loss + post_mel_loss + stop_loss
|
||||
|
||||
|
||||
if local_rank==0:
|
||||
writer.add_scalars('training_loss', {
|
||||
'mel_loss':mel_loss.numpy(),
|
||||
|
@ -116,16 +122,16 @@ def main(cfg):
|
|||
for j in range(4):
|
||||
x = np.uint8(cm.viridis(prob.numpy()[j*16]) * 255)
|
||||
writer.add_image('Attention_dec_%d_0'%global_step, x, i*4+j, dataformats="HWC")
|
||||
|
||||
|
||||
if cfg.use_data_parallel:
|
||||
loss = model.scale_loss(loss)
|
||||
loss.backward()
|
||||
model.apply_collective_grads()
|
||||
else:
|
||||
loss.backward()
|
||||
optimizer.minimize(loss, grad_clip = fluid.dygraph_grad_clip.GradClipByGlobalNorm(1))
|
||||
optimizer.minimize(loss, grad_clip = fluid.dygraph_grad_clip.GradClipByGlobalNorm(cfg.grad_clip_thresh))
|
||||
model.clear_gradients()
|
||||
|
||||
|
||||
# save checkpoint
|
||||
if local_rank==0 and global_step % cfg.save_step == 0:
|
||||
if not os.path.exists(cfg.save_path):
|
||||
|
|
|
@ -25,6 +25,14 @@ class DynamicGRU(dg.Layer):
|
|||
self.is_reverse = is_reverse
|
||||
|
||||
def forward(self, inputs):
|
||||
"""
|
||||
Dynamic GRU block.
|
||||
|
||||
Args:
|
||||
input (Variable): Shape(B, T, C), dtype: float32. The input value.
|
||||
Returns:
|
||||
output (Variable), Shape(B, T, C), the result compute by GRU.
|
||||
"""
|
||||
hidden = self.h_0
|
||||
res = []
|
||||
for i in range(inputs.shape[1]):
|
||||
|
|
|
@ -1,6 +1,9 @@
|
|||
import paddle.fluid.dygraph as dg
|
||||
import paddle.fluid.layers as layers
|
||||
from parakeet.modules.layers import Conv1D
|
||||
import paddle.fluid as fluid
|
||||
import math
|
||||
from parakeet.modules.layers import Conv
|
||||
|
||||
|
||||
class PositionwiseFeedForward(dg.Layer):
|
||||
''' A two-feed-forward-layer module '''
|
||||
|
@ -9,14 +12,15 @@ class PositionwiseFeedForward(dg.Layer):
|
|||
self.num_hidden = num_hidden
|
||||
self.use_cudnn = use_cudnn
|
||||
self.dropout = dropout
|
||||
|
||||
self.w_1 = Conv1D(in_channels = d_in,
|
||||
|
||||
self.w_1 = Conv(in_channels = d_in,
|
||||
out_channels = num_hidden,
|
||||
filter_size = filter_size,
|
||||
padding=padding,
|
||||
use_cudnn = use_cudnn,
|
||||
data_format = "NTC")
|
||||
self.w_2 = Conv1D(in_channels = num_hidden,
|
||||
|
||||
self.w_2 = Conv(in_channels = num_hidden,
|
||||
out_channels = d_in,
|
||||
filter_size = filter_size,
|
||||
padding=padding,
|
||||
|
@ -25,6 +29,14 @@ class PositionwiseFeedForward(dg.Layer):
|
|||
self.layer_norm = dg.LayerNorm(d_in)
|
||||
|
||||
def forward(self, input):
|
||||
"""
|
||||
Feed Forward Network.
|
||||
|
||||
Args:
|
||||
input (Variable): Shape(B, T, C), dtype: float32. The input value.
|
||||
Returns:
|
||||
output (Variable), Shape(B, T, C), the result after FFN.
|
||||
"""
|
||||
#FFN Networt
|
||||
x = self.w_2(layers.relu(self.w_1(input)))
|
||||
|
||||
|
@ -35,6 +47,6 @@ class PositionwiseFeedForward(dg.Layer):
|
|||
x = x + input
|
||||
|
||||
#layer normalization
|
||||
x = self.layer_norm(x)
|
||||
output = self.layer_norm(x)
|
||||
|
||||
return x
|
||||
return output
|
|
@ -6,6 +6,42 @@ from paddle import fluid
|
|||
import paddle.fluid.dygraph as dg
|
||||
|
||||
|
||||
class Conv(dg.Layer):
|
||||
def __init__(self, in_channels, out_channels, filter_size=1,
|
||||
padding=0, dilation=1, stride=1, use_cudnn=True,
|
||||
data_format="NCT", is_bias=True):
|
||||
super(Conv, self).__init__()
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels
|
||||
self.filter_size = filter_size
|
||||
self.padding = padding
|
||||
self.dilation = dilation
|
||||
self.stride = stride
|
||||
self.use_cudnn = use_cudnn
|
||||
self.data_format = data_format
|
||||
self.is_bias = is_bias
|
||||
|
||||
self.weight_attr = fluid.ParamAttr(initializer=fluid.initializer.XavierInitializer())
|
||||
self.bias_attr = None
|
||||
if is_bias is not False:
|
||||
k = math.sqrt(1 / in_channels)
|
||||
self.bias_attr = fluid.ParamAttr(initializer=fluid.initializer.Uniform(low=-k, high=k))
|
||||
|
||||
self.conv = Conv1D( in_channels = in_channels,
|
||||
out_channels = out_channels,
|
||||
filter_size = filter_size,
|
||||
padding = padding,
|
||||
dilation = dilation,
|
||||
stride = stride,
|
||||
param_attr = self.weight_attr,
|
||||
bias_attr = self.bias_attr,
|
||||
use_cudnn = use_cudnn,
|
||||
data_format = data_format)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv(x)
|
||||
return x
|
||||
|
||||
class Conv1D(dg.Layer):
|
||||
"""
|
||||
A convolution 1D block implemented with Conv2D. Form simplicity and
|
||||
|
|
|
@ -10,22 +10,35 @@ class ScaledDotProductAttention(dg.Layer):
|
|||
self.d_key = d_key
|
||||
|
||||
# please attention this mask is diff from pytorch
|
||||
def forward(self, key, value, query, mask=None, query_mask=None):
|
||||
def forward(self, key, value, query, mask=None, query_mask=None, dropout=0.1):
|
||||
"""
|
||||
Scaled Dot Product Attention.
|
||||
|
||||
Args:
|
||||
key (Variable): Shape(B, T, C), dtype: float32. The input key of attention.
|
||||
value (Variable): Shape(B, T, C), dtype: float32. The input value of attention.
|
||||
query (Variable): Shape(B, T, C), dtype: float32. The input query of attention.
|
||||
mask (Variable): Shape(B, len_q, len_k), dtype: float32. The mask of key.
|
||||
query_mask (Variable): Shape(B, len_q, 1), dtype: float32. The mask of query.
|
||||
dropout (Constant): dtype: float32. The probability of dropout.
|
||||
Returns:
|
||||
result (Variable), Shape(B, T, C), the result of mutihead attention.
|
||||
attention (Variable), Shape(n_head * B, T, C), the attention of key.
|
||||
"""
|
||||
# Compute attention score
|
||||
attention = layers.matmul(query, key, transpose_y=True) #transpose the last dim in y
|
||||
attention = attention / math.sqrt(self.d_key)
|
||||
|
||||
# Mask key to ignore padding
|
||||
if mask is not None:
|
||||
attention = attention * (mask == 0).astype(np.float32)
|
||||
mask = mask * (-2 ** 32 + 1)
|
||||
attention = attention * mask
|
||||
mask = (mask == 0).astype(np.float32) * (-2 ** 32 + 1)
|
||||
attention = attention + mask
|
||||
|
||||
|
||||
attention = layers.softmax(attention)
|
||||
attention = layers.dropout(attention, 0.0)
|
||||
attention = layers.dropout(attention, dropout)
|
||||
# Mask query to ignore padding
|
||||
# Not sure how to work
|
||||
if query_mask is not None:
|
||||
attention = attention * query_mask
|
||||
|
||||
|
@ -52,6 +65,19 @@ class MultiheadAttention(dg.Layer):
|
|||
self.layer_norm = dg.LayerNorm(num_hidden)
|
||||
|
||||
def forward(self, key, value, query_input, mask=None, query_mask=None):
|
||||
"""
|
||||
Multihead Attention.
|
||||
|
||||
Args:
|
||||
key (Variable): Shape(B, T, C), dtype: float32. The input key of attention.
|
||||
value (Variable): Shape(B, T, C), dtype: float32. The input value of attention.
|
||||
query_input (Variable): Shape(B, T, C), dtype: float32. The input query of attention.
|
||||
mask (Variable): Shape(B, len_q, len_k), dtype: float32. The mask of key.
|
||||
query_mask (Variable): Shape(B, len_q, 1), dtype: float32. The mask of query.
|
||||
Returns:
|
||||
result (Variable), Shape(B, T, C), the result of mutihead attention.
|
||||
attention (Variable), Shape(n_head * B, T, C), the attention of key.
|
||||
"""
|
||||
batch_size = key.shape[0]
|
||||
seq_len_key = key.shape[1]
|
||||
seq_len_query = query_input.shape[1]
|
||||
|
@ -62,6 +88,7 @@ class MultiheadAttention(dg.Layer):
|
|||
if mask is not None:
|
||||
mask = layers.expand(mask, (self.num_head, 1, 1))
|
||||
|
||||
|
||||
# Make multihead attention
|
||||
# key & value.shape = (batch_size, seq_len, feature)(feature = num_head * num_hidden_per_attn)
|
||||
key = layers.reshape(self.key(key), [batch_size, seq_len_key, self.num_head, self.d_k])
|
||||
|
@ -71,6 +98,7 @@ class MultiheadAttention(dg.Layer):
|
|||
key = layers.reshape(layers.transpose(key, [2, 0, 1, 3]), [-1, seq_len_key, self.d_k])
|
||||
value = layers.reshape(layers.transpose(value, [2, 0, 1, 3]), [-1, seq_len_key, self.d_k])
|
||||
query = layers.reshape(layers.transpose(query, [2, 0, 1, 3]), [-1, seq_len_query, self.d_q])
|
||||
|
||||
result, attention = self.scal_attn(key, value, query, mask=mask, query_mask=query_mask)
|
||||
|
||||
# concat all multihead result
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
import paddle.fluid.dygraph as dg
|
||||
import paddle.fluid as fluid
|
||||
import paddle.fluid.layers as layers
|
||||
from parakeet.modules.layers import Conv1D
|
||||
from parakeet.modules.layers import Conv
|
||||
|
||||
class PostConvNet(dg.Layer):
|
||||
def __init__(self,
|
||||
|
@ -17,7 +17,7 @@ class PostConvNet(dg.Layer):
|
|||
|
||||
self.dropout = dropout
|
||||
self.conv_list = []
|
||||
self.conv_list.append(Conv1D(in_channels = n_mels * outputs_per_step,
|
||||
self.conv_list.append(Conv(in_channels = n_mels * outputs_per_step,
|
||||
out_channels = num_hidden,
|
||||
filter_size = filter_size,
|
||||
padding = padding,
|
||||
|
@ -25,14 +25,14 @@ class PostConvNet(dg.Layer):
|
|||
data_format = "NCT"))
|
||||
|
||||
for _ in range(1, num_conv-1):
|
||||
self.conv_list.append(Conv1D(in_channels = num_hidden,
|
||||
self.conv_list.append(Conv(in_channels = num_hidden,
|
||||
out_channels = num_hidden,
|
||||
filter_size = filter_size,
|
||||
padding = padding,
|
||||
use_cudnn = use_cudnn,
|
||||
data_format = "NCT") )
|
||||
|
||||
self.conv_list.append(Conv1D(in_channels = num_hidden,
|
||||
self.conv_list.append(Conv(in_channels = num_hidden,
|
||||
out_channels = n_mels * outputs_per_step,
|
||||
filter_size = filter_size,
|
||||
padding = padding,
|
||||
|
@ -59,9 +59,17 @@ class PostConvNet(dg.Layer):
|
|||
|
||||
|
||||
def forward(self, input):
|
||||
"""
|
||||
Post Conv Net.
|
||||
|
||||
Args:
|
||||
input (Variable): Shape(B, T, C), dtype: float32. The input value.
|
||||
Returns:
|
||||
output (Variable), Shape(B, T, C), the result after postconvnet.
|
||||
"""
|
||||
input = layers.transpose(input, [0,2,1])
|
||||
len = input.shape[-1]
|
||||
for batch_norm, conv in zip(self.batch_norm_list, self.conv_list):
|
||||
input = layers.dropout(layers.tanh(batch_norm(conv(input)[:,:,:len])), self.dropout)
|
||||
input = layers.transpose(input, [0,2,1])
|
||||
return input
|
||||
output = layers.transpose(input, [0,2,1])
|
||||
return output
|
|
@ -2,9 +2,6 @@ import paddle.fluid.dygraph as dg
|
|||
import paddle.fluid.layers as layers
|
||||
|
||||
class PreNet(dg.Layer):
|
||||
"""
|
||||
Pre Net before passing through the network
|
||||
"""
|
||||
def __init__(self, input_size, hidden_size, output_size, dropout_rate=0.2):
|
||||
"""
|
||||
:param input_size: dimension of input
|
||||
|
@ -21,6 +18,14 @@ class PreNet(dg.Layer):
|
|||
self.linear2 = dg.Linear(hidden_size, output_size)
|
||||
|
||||
def forward(self, x):
|
||||
"""
|
||||
Pre Net before passing through the network.
|
||||
|
||||
Args:
|
||||
x (Variable): Shape(B, T, C), dtype: float32. The input value.
|
||||
Returns:
|
||||
x (Variable), Shape(B, T, C), the result after pernet.
|
||||
"""
|
||||
x = layers.dropout(layers.relu(self.linear1(x)), self.dropout_rate)
|
||||
x = layers.dropout(layers.relu(self.linear2(x)), self.dropout_rate)
|
||||
return x
|
||||
|
|
Loading…
Reference in New Issue