Merge branch 'add_TranTTS' into 'master'
right fastspeech version. See merge request !5
This commit is contained in: commit abc2b5377b
@@ -209,7 +209,7 @@ class AudioProcessor(object):
     def inv_melspectrogram(self, mel_spectrogram):
         S = self._denormalize(mel_spectrogram)
         S = self._db_to_amplitude(S + self.ref_level_db)
-        S = self._linear_to_mel(np.abs(S))
+        S = self._mel_to_linear(np.abs(S))
         if self.preemphasis:
            return self.apply_inv_preemphasis(self._griffin_lim(S ** self.power))
         return self._griffin_lim(S ** self.power)
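The one-line change above is the substantive fix: to invert a mel spectrogram, the mel-scale magnitudes must be projected back to linear-frequency bins before Griffin-Lim runs, whereas the old code applied the forward mel filterbank a second time. A minimal sketch of the two directions, assuming a librosa filterbank and a pseudo-inverse for the backward mapping (parakeet's actual _mel_to_linear implementation may differ):

    import numpy as np
    import librosa

    # hypothetical standalone analogues of the two AudioProcessor methods
    mel_basis = librosa.filters.mel(sr=22050, n_fft=2048, n_mels=80)  # shape (80, 1025)

    def linear_to_mel(spec):
        # forward: linear magnitudes (1025, T) -> mel magnitudes (80, T)
        return np.dot(mel_basis, spec)

    def mel_to_linear(mel_spec):
        # backward: approximate inversion via the filterbank pseudo-inverse,
        # clipped to stay positive before Griffin-Lim
        return np.maximum(1e-10, np.dot(np.linalg.pinv(mel_basis), mel_spec))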
@@ -1,148 +0,0 @@
-from pathlib import Path
-import numpy as np
-import pandas as pd
-import librosa
-
-from paddle import fluid
-from parakeet import g2p
-from parakeet import audio
-from parakeet.data.sampler import *
-from parakeet.data.datacargo import DataCargo
-from parakeet.data.dataset import Dataset
-from parakeet.data.batch import TextIDBatcher, SpecBatcher
-
-class LJSpeechLoader:
-    def __init__(self, config, nranks, rank, is_vocoder=False, shuffle=True):
-        place = fluid.CUDAPlace(rank) if config.use_gpu else fluid.CPUPlace()
-
-        LJSPEECH_ROOT = Path(config.data_path)
-        dataset = LJSpeech(LJSPEECH_ROOT, config)
-        sampler = DistributedSampler(len(dataset), nranks, rank, shuffle=shuffle)
-
-        assert config.batch_size % nranks == 0
-        each_bs = config.batch_size // nranks
-        if is_vocoder:
-            dataloader = DataCargo(dataset, sampler=sampler, batch_size=each_bs, shuffle=shuffle, collate_fn=batch_examples_vocoder, drop_last=True)
-        else:
-            dataloader = DataCargo(dataset, sampler=sampler, batch_size=each_bs, shuffle=shuffle, collate_fn=batch_examples, drop_last=True)
-
-        self.reader = fluid.io.DataLoader.from_generator(
-            capacity=32,
-            iterable=True,
-            use_double_buffer=True,
-            return_list=True)
-        self.reader.set_batch_generator(dataloader, place)
-
-
-class LJSpeech(Dataset):
-    def __init__(self, root, config):
-        super(LJSpeech, self).__init__()
-        assert isinstance(root, (str, Path)), "root should be a string or Path object"
-        self.root = root if isinstance(root, Path) else Path(root)
-        self.metadata = self._prepare_metadata()
-        self.config = config
-        self._ljspeech_processor = audio.AudioProcessor(
-            sample_rate=config.audio.sr,
-            num_mels=config.audio.num_mels,
-            min_level_db=config.audio.min_level_db,
-            ref_level_db=config.audio.ref_level_db,
-            n_fft=config.audio.n_fft,
-            win_length= config.audio.win_length,
-            hop_length= config.audio.hop_length,
-            power=config.audio.power,
-            preemphasis=config.audio.preemphasis,
-            signal_norm=True,
-            symmetric_norm=False,
-            max_norm=1.,
-            mel_fmin=0,
-            mel_fmax=None,
-            clip_norm=True,
-            griffin_lim_iters=60,
-            do_trim_silence=False,
-            sound_norm=False)
-
-    def _prepare_metadata(self):
-        csv_path = self.root.joinpath("metadata.csv")
-        metadata = pd.read_csv(csv_path, sep="|", header=None, quoting=3,
-                               names=["fname", "raw_text", "normalized_text"])
-        return metadata
-
-    def _get_example(self, metadatum):
-        """All the code for generating an Example from a metadatum. If you want a
-        different preprocessing pipeline, you can override this method.
-        This method may require several processor, each of which has a lot of options.
-        In this case, you'd better pass a composed transform and pass it to the init
-        method.
-        """
-
-        fname, raw_text, normalized_text = metadatum
-        wav_path = self.root.joinpath("wavs", fname + ".wav")
-
-        # load -> trim -> preemphasis -> stft -> magnitude -> mel_scale -> logscale -> normalize
-        wav = self._ljspeech_processor.load_wav(str(wav_path))
-        mag = self._ljspeech_processor.spectrogram(wav).astype(np.float32)
-        mel = self._ljspeech_processor.melspectrogram(wav).astype(np.float32)
-        phonemes = np.array(g2p.en.text_to_sequence(normalized_text), dtype=np.int64)
-        return (mag, mel, phonemes) # maybe we need to implement it as a map in the future
-
-    def __getitem__(self, index):
-        metadatum = self.metadata.iloc[index]
-        example = self._get_example(metadatum)
-        return example
-
-    def __iter__(self):
-        for i in range(len(self)):
-            yield self[i]
-
-    def __len__(self):
-        return len(self.metadata)
-
-
-def batch_examples(batch):
-    texts = []
-    mels = []
-    mel_inputs = []
-    text_lens = []
-    pos_texts = []
-    pos_mels = []
-    for data in batch:
-        _, mel, text = data
-        mel_inputs.append(np.concatenate([np.zeros([mel.shape[0], 1], np.float32), mel[:,:-1]], axis=-1))
-        text_lens.append(len(text))
-        pos_texts.append(np.arange(1, len(text) + 1))
-        pos_mels.append(np.arange(1, mel.shape[1] + 1))
-        mels.append(mel)
-        texts.append(text)
-
-    # Sort by text_len in descending order
-    texts = [i for i,_ in sorted(zip(texts, text_lens), key=lambda x: x[1], reverse=True)]
-    mels = [i for i,_ in sorted(zip(mels, text_lens), key=lambda x: x[1], reverse=True)]
-    mel_inputs = [i for i,_ in sorted(zip(mel_inputs, text_lens), key=lambda x: x[1], reverse=True)]
-    pos_texts = [i for i,_ in sorted(zip(pos_texts, text_lens), key=lambda x: x[1], reverse=True)]
-    pos_mels = [i for i,_ in sorted(zip(pos_mels, text_lens), key=lambda x: x[1], reverse=True)]
-    text_lens = sorted(text_lens, reverse=True)
-
-    # Pad sequence with largest len of the batch
-    texts = TextIDBatcher(pad_id=0)(texts) #(B, T)
-    pos_texts = TextIDBatcher(pad_id=0)(pos_texts) #(B,T)
-    pos_mels = TextIDBatcher(pad_id=0)(pos_mels) #(B,T)
-    mels = np.transpose(SpecBatcher(pad_value=0.)(mels), axes=(0,2,1)) #(B,T,num_mels)
-    mel_inputs = np.transpose(SpecBatcher(pad_value=0.)(mel_inputs), axes=(0,2,1))#(B,T,num_mels)
-    return (texts, mels, mel_inputs, pos_texts, pos_mels, np.array(text_lens))
-
-def batch_examples_vocoder(batch):
-    mels=[]
-    mags=[]
-    for data in batch:
-        mag, mel, _ = data
-        mels.append(mel)
-        mags.append(mag)
-
-    mels = np.transpose(SpecBatcher(pad_value=0.)(mels), axes=(0,2,1))
-    mags = np.transpose(SpecBatcher(pad_value=0.)(mags), axes=(0,2,1))
-
-    return (mels, mags)
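Two details of this removed loader are worth keeping in mind, since train.py below now imports LJSpeechLoader from parakeet.models.dataloader.ljspeech instead (presumably the relocated successor of this file). First, batch_size is treated as a global batch size and split evenly across ranks, hence the assert and each_bs = config.batch_size // nranks. Second, mel_inputs is the teacher-forcing decoder input: the target mel shifted right by one frame with a zero "go" frame prepended. A toy numpy illustration of that shift:

    import numpy as np

    mel = np.arange(12, dtype=np.float32).reshape(3, 4)  # toy target, (num_mels=3, T=4)
    mel_input = np.concatenate([np.zeros([mel.shape[0], 1], np.float32), mel[:, :-1]], axis=-1)
    # frame t of mel_input equals frame t-1 of mel; frame 0 is all zeros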
@@ -39,5 +39,8 @@ use_data_parallel: False
 data_path: ../../../dataset/LJSpeech-1.1
 transtts_path: ../transformerTTS/checkpoint/
-transformer_step: 10
+transformer_step: 200000
+save_path: ./checkpoint
 log_dir: ./log
+#checkpoint_path: ./checkpoint
+#ransformer_step: 97000
@@ -0,0 +1,33 @@
+audio:
+  num_mels: 80
+  n_fft: 2048
+  sr: 22050
+  preemphasis: 0.97
+  hop_length: 275
+  win_length: 1102
+  power: 1.2
+  min_level_db: -100
+  ref_level_db: 20
+  outputs_per_step: 1
+
+encoder_n_layer: 6
+encoder_head: 2
+encoder_conv1d_filter_size: 1536
+max_sep_len: 2048
+decoder_n_layer: 6
+decoder_head: 2
+decoder_conv1d_filter_size: 1536
+fs_hidden_size: 384
+duration_predictor_output_size: 256
+duration_predictor_filter_size: 3
+fft_conv1d_filter: 3
+fft_conv1d_padding: 1
+dropout: 0.1
+transformer_head: 4
+
+use_gpu: True
+alpha: 1.0
+
+checkpoint_path: checkpoint/
+fastspeech_step: 71000
+log_dir: ./log
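For orientation, the window parameters in this new synthesis config line up with the values hard-coded in the preprocess.py file removed further down in this diff: at 22050 Hz, win_length 1102 and hop_length 275 are a 50 ms window and a 12.5 ms hop, truncated to integers:

    sr = 22050
    win_length = int(sr * 0.05)    # 1102 samples = 50 ms
    hop_length = int(sr * 0.0125)  # 275 samples = 12.5 ms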
@@ -102,7 +102,8 @@ class LengthRegulator(dg.Layer):
         else:
             duration_predictor_output = layers.round(duration_predictor_output)
             output = self.LR(x, duration_predictor_output, alpha)
-            mel_pos = dg.to_variable([i+1 for i in range(output.shape[1])])
+            mel_pos = dg.to_variable(np.arange(1, output.shape[1]+1))
+            mel_pos = layers.unsqueeze(mel_pos, [0])
             return output, mel_pos

 class DurationPredictor(dg.Layer):
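Replacing the list comprehension with np.arange hands dg.to_variable a proper int64 ndarray, and the new unsqueeze adds the batch dimension the downstream position handling expects. In plain numpy terms:

    import numpy as np

    T = 5                              # output.shape[1], the expanded mel length
    mel_pos = np.arange(1, T + 1)      # [1 2 3 4 5], shape (T,)
    mel_pos = mel_pos[np.newaxis, :]   # shape (1, T), what layers.unsqueeze(mel_pos, [0]) produces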
@@ -1,11 +1,11 @@
-from utils import *
-from modules import FFTBlock, LengthRegulator
 import paddle.fluid.dygraph as dg
 import paddle.fluid as fluid
 from parakeet.g2p.text.symbols import symbols
 from parakeet.modules.utils import *
 from parakeet.modules.post_convnet import PostConvNet
 from parakeet.modules.layers import Linear
+from utils import *
+from modules import FFTBlock, LengthRegulator

 class Encoder(dg.Layer):
     def __init__(self,
@@ -203,8 +203,7 @@ class FastSpeech(dg.Layer):
             return mel_output, mel_output_postnet, duration_predictor_output, enc_slf_attn_list, dec_slf_attn_list
         else:
             length_regulator_output, decoder_pos = self.length_regulator(encoder_output, alpha=alpha)
-            decoder_output = self.decoder(length_regulator_output, decoder_pos)
-
+            decoder_output, _ = self.decoder(length_regulator_output, decoder_pos)

             mel_output = self.mel_linear(decoder_output)
             mel_output_postnet = self.postnet(mel_output) + mel_output
@@ -50,6 +50,9 @@ def add_config_options_to_parser(parser):
         help="the dropout in network.")
     parser.add_argument('--transformer_head', type=int, default=4,
         help="the attention head num of transformerTTS.")
+    parser.add_argument('--alpha', type=float, default=1.0,
+        help="the hyperparameter to determine the length of the expanded sequence\
+            mel, thereby controlling the voice speed.")
     parser.add_argument('--hidden_size', type=int, default=256,
         help="the hidden size in model of transformerTTS.")
@@ -68,6 +71,8 @@ def add_config_options_to_parser(parser):
         help="the learning rate for training.")
     parser.add_argument('--save_step', type=int, default=500,
         help="checkpointing interval during training.")
+    parser.add_argument('--fastspeech_step', type=int, default=160000,
+        help="Global step to restore checkpoint of fastspeech.")
     parser.add_argument('--use_gpu', type=bool, default=True,
         help="use gpu or not during training.")
     parser.add_argument('--use_data_parallel', type=bool, default=False,
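The new --alpha option is the FastSpeech speed knob: the length regulator scales the predicted per-phoneme durations by alpha before expanding the encoder output, so alpha below 1.0 yields a shorter, faster utterance and above 1.0 a slower one. A hedged sketch of the expansion (illustrative names, not parakeet's API):

    import numpy as np

    def expand(encodings, durations, alpha=1.0):
        # repeat each phoneme encoding round(duration * alpha) times
        reps = np.round(np.asarray(durations, dtype=np.float32) * alpha).astype(int)
        return np.repeat(encodings, reps, axis=0)

    phonemes = np.eye(3, dtype=np.float32)               # three toy phoneme encodings
    print(expand(phonemes, [2, 1, 3], alpha=1.0).shape)  # (6, 3)
    print(expand(phonemes, [2, 1, 3], alpha=0.5).shape)  # (3, 3): fewer frames, faster speech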
@@ -0,0 +1,76 @@
+import os
+from tensorboardX import SummaryWriter
+from collections import OrderedDict
+import jsonargparse
+from parse import add_config_options_to_parser
+from pprint import pprint
+import numpy as np
+import paddle.fluid as fluid
+import paddle.fluid.dygraph as dg
+from parakeet.g2p.en import text_to_sequence
+from parakeet import audio
+from network import FastSpeech
+
+def load_checkpoint(step, model_path):
+    model_dict, _ = fluid.dygraph.load_dygraph(os.path.join(model_path, step))
+    new_state_dict = OrderedDict()
+    for param in model_dict:
+        if param.startswith('_layers.'):
+            new_state_dict[param[8:]] = model_dict[param]
+        else:
+            new_state_dict[param] = model_dict[param]
+    return new_state_dict
+
+def synthesis(text_input, cfg):
+    place = (fluid.CUDAPlace(0) if cfg.use_gpu else fluid.CPUPlace())
+
+    # tensorboard
+    if not os.path.exists(cfg.log_dir):
+        os.mkdir(cfg.log_dir)
+    path = os.path.join(cfg.log_dir,'synthesis')
+
+    writer = SummaryWriter(path)
+
+    with dg.guard(place):
+        model = FastSpeech(cfg)
+        model.set_dict(load_checkpoint(str(cfg.fastspeech_step), os.path.join(cfg.checkpoint_path, "fastspeech")))
+        model.eval()
+
+        text = np.asarray(text_to_sequence(text_input))
+        text = fluid.layers.unsqueeze(dg.to_variable(text),[0])
+        pos_text = np.arange(1, text.shape[1]+1)
+        pos_text = fluid.layers.unsqueeze(dg.to_variable(pos_text),[0])
+
+        mel_output, mel_output_postnet = model(text, pos_text, alpha=cfg.alpha)
+
+        _ljspeech_processor = audio.AudioProcessor(
+            sample_rate=cfg.audio.sr,
+            num_mels=cfg.audio.num_mels,
+            min_level_db=cfg.audio.min_level_db,
+            ref_level_db=cfg.audio.ref_level_db,
+            n_fft=cfg.audio.n_fft,
+            win_length= cfg.audio.win_length,
+            hop_length= cfg.audio.hop_length,
+            power=cfg.audio.power,
+            preemphasis=cfg.audio.preemphasis,
+            signal_norm=True,
+            symmetric_norm=False,
+            max_norm=1.,
+            mel_fmin=0,
+            mel_fmax=None,
+            clip_norm=True,
+            griffin_lim_iters=60,
+            do_trim_silence=False,
+            sound_norm=False)
+
+        mel_output_postnet = fluid.layers.transpose(fluid.layers.squeeze(mel_output_postnet,[0]), [1,0])
+        wav = _ljspeech_processor.inv_melspectrogram(mel_output_postnet.numpy())
+        writer.add_audio(text_input, wav, 0, cfg.audio.sr)
+        print("Synthesis completed !!!")
+        writer.close()
+
+if __name__ == '__main__':
+    parser = jsonargparse.ArgumentParser(description="Synthesis model", formatter_class='default_argparse')
+    add_config_options_to_parser(parser)
+    cfg = parser.parse_args('-c ./config/synthesis.yaml'.split())
+    synthesis("Transformer model is so fast!", cfg)
@@ -5,34 +5,28 @@ import time
 import math
 import jsonargparse
 from pathlib import Path
+from parse import add_config_options_to_parser
+from pprint import pprint
 from tqdm import tqdm
+from collections import OrderedDict
 from tensorboardX import SummaryWriter
 import paddle.fluid.dygraph as dg
 import paddle.fluid.layers as layers
 import paddle.fluid as fluid
-from parse import add_config_options_to_parser
-from pprint import pprint
+from parakeet.models.dataloader.ljspeech import LJSpeechLoader
+from parakeet.models.transformerTTS.network import TransformerTTS
 from network import FastSpeech
 from utils import get_alignment
-from parakeet.models.dataloader.jlspeech import LJSpeechLoader
-from parakeet.models.transformerTTS.network import TransformerTTS

-class MyDataParallel(dg.parallel.DataParallel):
-    """
-    A data parallel proxy for model.
-    """
-
-    def __init__(self, layers, strategy):
-        super(MyDataParallel, self).__init__(layers, strategy)
-
-    def __getattr__(self, key):
-        if key in self.__dict__:
-            return object.__getattribute__(self, key)
-        elif key is "_layers":
-            return object.__getattribute__(self, "_sub_layers")["_layers"]
-        else:
-            return getattr(
-                object.__getattribute__(self, "_sub_layers")["_layers"], key)
+def load_checkpoint(step, model_path):
+    model_dict, opti_dict = fluid.dygraph.load_dygraph(os.path.join(model_path, step))
+    new_state_dict = OrderedDict()
+    for param in model_dict:
+        if param.startswith('_layers.'):
+            new_state_dict[param[8:]] = model_dict[param]
+        else:
+            new_state_dict[param] = model_dict[param]
+    return new_state_dict, opti_dict

 def main(cfg):
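This hunk is the pattern repeated across all the training and synthesis scripts in this merge: the hand-rolled MyDataParallel attribute proxy is deleted in favor of the stock DataParallel wrapper, and the name mismatch it papered over is handled once at load time instead. Checkpoints saved from a DataParallel-wrapped model carry a '_layers.' prefix on every parameter key; load_checkpoint strips it (len('_layers.') == 8, hence param[8:]) so a bare model can load the weights. A toy run of the renaming:

    from collections import OrderedDict

    saved = {'_layers.encoder.weight': 1, 'decoder.bias': 2}  # mixed keys, as in older checkpoints
    new_state_dict = OrderedDict()
    for param in saved:
        # drop the DataParallel wrapper prefix when present
        new_state_dict[param[8:] if param.startswith('_layers.') else param] = saved[param]
    print(list(new_state_dict))  # ['encoder.weight', 'decoder.bias']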
@@ -57,8 +51,7 @@ def main(cfg):
     with dg.guard(place):
         with fluid.unique_name.guard():
             transformerTTS = TransformerTTS(cfg)
-            model_path = os.path.join(cfg.transtts_path, "transformer")
-            model_dict, _ = fluid.dygraph.load_dygraph(os.path.join(model_path, str(cfg.transformer_step)))
+            model_dict, _ = load_checkpoint(str(cfg.transformer_step), os.path.join(cfg.transtts_path, "transformer"))

             transformerTTS.set_dict(model_dict)
             transformerTTS.eval()
@@ -67,27 +60,29 @@ def main(cfg):
            model.train()
            optimizer = fluid.optimizer.AdamOptimizer(learning_rate=dg.NoamDecay(1/(cfg.warm_up_step *( cfg.lr ** 2)), cfg.warm_up_step),
                parameter_list=model.parameters())
-           reader = LJSpeechLoader(cfg, nranks, local_rank).reader()
+           reader = LJSpeechLoader(cfg, nranks, local_rank, shuffle=True).reader()

            if cfg.checkpoint_path is not None:
-               model_dict, opti_dict = fluid.dygraph.load_dygraph(cfg.checkpoint_path)
+               model_dict, opti_dict = load_checkpoint(str(cfg.fastspeech_step), os.path.join(cfg.checkpoint_path, "fastspeech"))
                model.set_dict(model_dict)
                optimizer.set_dict(opti_dict)
+               global_step = cfg.fastspeech_step
                print("load checkpoint!!!")

            if cfg.use_data_parallel:
                strategy = dg.parallel.prepare_context()
-               model = MyDataParallel(model, strategy)
+               model = fluid.dygraph.parallel.DataParallel(model, strategy)

            for epoch in range(cfg.epochs):
                pbar = tqdm(reader)

                for i, data in enumerate(pbar):
                    pbar.set_description('Processing at epoch %d'%epoch)
-                   character, mel, mel_input, pos_text, pos_mel, text_length = data
+                   character, mel, mel_input, pos_text, pos_mel, text_length, mel_lens = data

                    _, _, attn_probs, _, _, _ = transformerTTS(character, mel_input, pos_text, pos_mel)
-                   alignment = dg.to_variable(get_alignment(attn_probs, cfg.transformer_head)).astype(np.float32)
+                   alignment = dg.to_variable(get_alignment(attn_probs, mel_lens, cfg.transformer_head)).astype(np.float32)

                    global_step += 1

                    #Forward
@@ -102,7 +97,6 @@ def main(cfg):
                    total_loss = mel_loss + mel_postnet_loss + duration_loss

                    if local_rank==0:
-                       #print('epoch:{}, step:{}, mel_loss:{}, mel_postnet_loss:{}, duration_loss:{}'.format(epoch, global_step, mel_loss.numpy(), mel_postnet_loss.numpy(), duration_loss.numpy()))
                        writer.add_scalar('mel_loss', mel_loss.numpy(), global_step)
                        writer.add_scalar('post_mel_loss', mel_postnet_loss.numpy(), global_step)
                        writer.add_scalar('duration_loss', duration_loss.numpy(), global_step)
@@ -1,9 +1,10 @@
 import numpy as np

-def get_alignment(attn_probs, n_head):
+def get_alignment(attn_probs, mel_lens, n_head):
     max_F = 0
     assert attn_probs[0].shape[0] % n_head == 0
     batch_size = int(attn_probs[0].shape[0] // n_head)
+    #max_attn = attn_probs[0].numpy()[0,batch_size]
     for i in range(len(attn_probs)):
         multi_attn = attn_probs[i].numpy()
         for j in range(n_head):
@@ -12,7 +13,7 @@ def get_alignment(attn_probs, n_head):
             if max_F < F:
                 max_F = F
                 max_attn = attn
-    alignment = compute_duration(max_attn)
+    alignment = compute_duration(max_attn, mel_lens)
     return alignment

 def score_F(attn):
@@ -20,11 +21,12 @@ def score_F(attn):
     mean = np.mean(max)
     return mean

-def compute_duration(attn):
+def compute_duration(attn, mel_lens):
     alignment = np.zeros([attn.shape[0],attn.shape[2]])
+    mel_lens = mel_lens.numpy()
     for i in range(attn.shape[0]):
-        for j in range(attn.shape[1]):
-            max_index = attn[i,j].tolist().index(attn[i,j].max())
+        for j in range(mel_lens[i]):
+            max_index = np.argmax(attn[i,j])
             alignment[i,max_index] += 1

     return alignment
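The mel_lens plumbing through get_alignment and compute_duration is the real fix in utils.py: durations are counted by letting every decoder frame vote for its argmax encoder position, and without the true frame counts the padded tail of each utterance would also vote, inflating the durations. np.argmax also replaces the roundabout tolist().index(max()) idiom. A self-contained numpy restatement (taking mel_lens as a plain array rather than a dygraph variable):

    import numpy as np

    def compute_duration(attn, mel_lens):
        # attn: (batch, T_mel, T_text) attention weights; mel_lens: true frame counts
        alignment = np.zeros([attn.shape[0], attn.shape[2]])
        for i in range(attn.shape[0]):
            for j in range(mel_lens[i]):                  # stop at the real length, skip padding
                alignment[i, np.argmax(attn[i, j])] += 1  # frame j votes for one text position
        return alignment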
@@ -24,11 +24,12 @@ save_step: 1000
 image_step: 2000
 use_gpu: True
 use_data_parallel: False
+stop_token: False

 data_path: ../../../dataset/LJSpeech-1.1
 save_path: ./checkpoint
 log_dir: ./log
 #checkpoint_path: ./checkpoint
-#transformer_step: 70000
+#ransformer_step: 97000

@@ -49,6 +49,7 @@ class EncoderPrenet(dg.Layer):
             x = layers.dropout(layers.relu(batch_norm(conv(x))), 0.2)
         x = layers.transpose(x,[0,2,1]) #(N,T,C)
         x = self.projection(x)
+
         return x

 class CBHG(dg.Layer):
@@ -44,13 +44,15 @@ def add_config_options_to_parser(parser):
     parser.add_argument('--max_len', type=int, default=400,
         help="The max length of audio when synthsis.")
     parser.add_argument('--transformer_step', type=int, default=160000,
-        help="Global step to restore checkpoint of transformer in synthesis.")
-    parser.add_argument('--postnet_step', type=int, default=100000,
-        help="Global step to restore checkpoint of postnet in synthesis.")
+        help="Global step to restore checkpoint of transformer.")
+    parser.add_argument('--postnet_step', type=int, default=90000,
+        help="Global step to restore checkpoint of postnet.")
     parser.add_argument('--use_gpu', type=bool, default=True,
         help="use gpu or not during training.")
     parser.add_argument('--use_data_parallel', type=bool, default=False,
         help="use data parallel or not during training.")
+    parser.add_argument('--stop_token', type=bool, default=False,
+        help="use stop token loss in network or not.")

     parser.add_argument('--data_path', type=str, default='./dataset/LJSpeech-1.1',
         help="the path of dataset.")
@@ -1,123 +0,0 @@
-from pathlib import Path
-import numpy as np
-import pandas as pd
-import librosa
-
-from parakeet import g2p
-from parakeet import audio
-
-from parakeet.data.sampler import SequentialSampler, RandomSampler, BatchSampler
-from parakeet.data.dataset import Dataset
-from parakeet.data.datacargo import DataCargo
-from parakeet.data.batch import TextIDBatcher, SpecBatcher
-
-_ljspeech_processor = audio.AudioProcessor(
-    sample_rate=22050,
-    num_mels=80,
-    min_level_db=-100,
-    ref_level_db=20,
-    n_fft=2048,
-    win_length= int(22050 * 0.05),
-    hop_length= int(22050 * 0.0125),
-    power=1.2,
-    preemphasis=0.97,
-    signal_norm=True,
-    symmetric_norm=False,
-    max_norm=1.,
-    mel_fmin=0,
-    mel_fmax=None,
-    clip_norm=True,
-    griffin_lim_iters=60,
-    do_trim_silence=False,
-    sound_norm=False)
-
-class LJSpeech(Dataset):
-    def __init__(self, root):
-        super(LJSpeech, self).__init__()
-        assert isinstance(root, (str, Path)), "root should be a string or Path object"
-        self.root = root if isinstance(root, Path) else Path(root)
-        self.metadata = self._prepare_metadata()
-
-    def _prepare_metadata(self):
-        csv_path = self.root.joinpath("metadata.csv")
-        metadata = pd.read_csv(csv_path, sep="|", header=None, quoting=3,
-                               names=["fname", "raw_text", "normalized_text"])
-        return metadata
-
-    def _get_example(self, metadatum):
-        """All the code for generating an Example from a metadatum. If you want a
-        different preprocessing pipeline, you can override this method.
-        This method may require several processor, each of which has a lot of options.
-        In this case, you'd better pass a composed transform and pass it to the init
-        method.
-        """
-
-        fname, raw_text, normalized_text = metadatum
-        wav_path = self.root.joinpath("wavs", fname + ".wav")
-
-        # load -> trim -> preemphasis -> stft -> magnitude -> mel_scale -> logscale -> normalize
-        wav = _ljspeech_processor.load_wav(str(wav_path))
-        mag = _ljspeech_processor.spectrogram(wav).astype(np.float32)
-        mel = _ljspeech_processor.melspectrogram(wav).astype(np.float32)
-        phonemes = np.array(g2p.en.text_to_sequence(normalized_text), dtype=np.int64)
-        return (mag, mel, phonemes) # maybe we need to implement it as a map in the future
-
-    def __getitem__(self, index):
-        metadatum = self.metadata.iloc[index]
-        example = self._get_example(metadatum)
-        return example
-
-    def __iter__(self):
-        for i in range(len(self)):
-            yield self[i]
-
-    def __len__(self):
-        return len(self.metadata)
-
-
-def batch_examples(batch):
-    texts = []
-    mels = []
-    mel_inputs = []
-    text_lens = []
-    pos_texts = []
-    pos_mels = []
-    for data in batch:
-        _, mel, text = data
-        mel_inputs.append(np.concatenate([np.zeros([mel.shape[0], 1], np.float32), mel[:,:-1]], axis=-1))
-        text_lens.append(len(text))
-        pos_texts.append(np.arange(1, len(text) + 1))
-        pos_mels.append(np.arange(1, mel.shape[1] + 1))
-        mels.append(mel)
-        texts.append(text)
-
-    # Sort by text_len in descending order
-    texts = [i for i,_ in sorted(zip(texts, text_lens), key=lambda x: x[1], reverse=True)]
-    mels = [i for i,_ in sorted(zip(mels, text_lens), key=lambda x: x[1], reverse=True)]
-    mel_inputs = [i for i,_ in sorted(zip(mel_inputs, text_lens), key=lambda x: x[1], reverse=True)]
-    pos_texts = [i for i,_ in sorted(zip(pos_texts, text_lens), key=lambda x: x[1], reverse=True)]
-    pos_mels = [i for i,_ in sorted(zip(pos_mels, text_lens), key=lambda x: x[1], reverse=True)]
-    text_lens = sorted(text_lens, reverse=True)
-
-    # Pad sequence with largest len of the batch
-    texts = TextIDBatcher(pad_id=0)(texts)
-    pos_texts = TextIDBatcher(pad_id=0)(pos_texts)
-    pos_mels = TextIDBatcher(pad_id=0)(pos_mels)
-    mels = np.transpose(SpecBatcher(pad_value=0.)(mels), axes=(0,2,1))
-    mel_inputs = np.transpose(SpecBatcher(pad_value=0.)(mel_inputs), axes=(0,2,1))
-    return (texts, mels, mel_inputs, pos_texts, pos_mels, np.array(text_lens))
-
-def batch_examples_vocoder(batch):
-    mels=[]
-    mags=[]
-    for data in batch:
-        mag, mel, _ = data
-        mels.append(mel)
-        mags.append(mag)
-
-    mels = np.transpose(SpecBatcher(pad_value=0.)(mels), axes=(0,2,1))
-    mags = np.transpose(SpecBatcher(pad_value=0.)(mags), axes=(0,2,1))
-
-    return (mels, mags)
@@ -7,15 +7,22 @@ from tqdm import tqdm
 from tensorboardX import SummaryWriter
 import paddle.fluid as fluid
 import paddle.fluid.dygraph as dg
-from preprocess import _ljspeech_processor
 from pathlib import Path
 import jsonargparse
 from parse import add_config_options_to_parser
 from pprint import pprint
+from collections import OrderedDict
+from parakeet import audio

 def load_checkpoint(step, model_path):
-    model_dict, opti_dict = fluid.dygraph.load_dygraph(os.path.join(model_path, step))
-    return model_dict
+    model_dict, _ = fluid.dygraph.load_dygraph(os.path.join(model_path, step))
+    new_state_dict = OrderedDict()
+    for param in model_dict:
+        if param.startswith('_layers.'):
+            new_state_dict[param[8:]] = model_dict[param]
+        else:
+            new_state_dict[param] = model_dict[param]
+    return new_state_dict

 def synthesis(text_input, cfg):
     place = (fluid.CUDAPlace(0) if cfg.use_gpu else fluid.CPUPlace())
@@ -30,7 +37,7 @@ def synthesis(text_input, cfg):
     with dg.guard(place):
         with fluid.unique_name.guard():
             model = TransformerTTS(cfg)
-            model.set_dict(load_checkpoint(str(cfg.transformer_step), os.path.join(cfg.checkpoint_path, "transformer")))
+            model.set_dict(load_checkpoint(str(cfg.transformer_step), os.path.join(cfg.checkpoint_path, "nostop_token/transformer")))
             model.eval()

         with fluid.unique_name.guard():
@@ -54,11 +61,32 @@ def synthesis(text_input, cfg):
             mel_input = fluid.layers.concat([mel_input, postnet_pred[:,-1:,:]], axis=1)
         mag_pred = model_postnet(postnet_pred)

+        _ljspeech_processor = audio.AudioProcessor(
+            sample_rate=cfg.audio.sr,
+            num_mels=cfg.audio.num_mels,
+            min_level_db=cfg.audio.min_level_db,
+            ref_level_db=cfg.audio.ref_level_db,
+            n_fft=cfg.audio.n_fft,
+            win_length= cfg.audio.win_length,
+            hop_length= cfg.audio.hop_length,
+            power=cfg.audio.power,
+            preemphasis=cfg.audio.preemphasis,
+            signal_norm=True,
+            symmetric_norm=False,
+            max_norm=1.,
+            mel_fmin=0,
+            mel_fmax=None,
+            clip_norm=True,
+            griffin_lim_iters=60,
+            do_trim_silence=False,
+            sound_norm=False)
+
         wav = _ljspeech_processor.inv_spectrogram(fluid.layers.transpose(fluid.layers.squeeze(mag_pred,[0]), [1,0]).numpy())
         writer.add_audio(text_input, wav, 0, cfg.audio.sr)
         if not os.path.exists(cfg.sample_path):
             os.mkdir(cfg.sample_path)
         write(os.path.join(cfg.sample_path,'test.wav'), cfg.audio.sr, wav)
+        writer.close()

 if __name__ == '__main__':
     parser = jsonargparse.ArgumentParser(description="Synthesis model", formatter_class='default_argparse')
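The inlined AudioProcessor replaces the module-level instance that used to live in the deleted preprocess.py; inv_spectrogram then reconstructs the waveform from the predicted linear magnitudes with 60 iterations of Griffin-Lim, per griffin_lim_iters above. A rough librosa equivalent, assuming already-denormalized magnitudes and this diff's window settings:

    import numpy as np
    import librosa

    mag = np.abs(np.random.randn(1025, 200)).astype(np.float32)  # stand-in for the (1 + n_fft//2, T) prediction
    # raising to power=1.2 sharpens the spectrogram before phase estimation, mirroring S ** self.power
    wav = librosa.griffinlim(mag ** 1.2, n_iter=60, hop_length=275, win_length=1102)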
@@ -1,33 +1,23 @@
-from network import *
 from tensorboardX import SummaryWriter
 import os
 from tqdm import tqdm
 from pathlib import Path
+from collections import OrderedDict
 import jsonargparse
 from parse import add_config_options_to_parser
 from pprint import pprint
-from parakeet.models.dataloader.jlspeech import LJSpeechLoader
+from parakeet.models.dataloader.ljspeech import LJSpeechLoader
+from network import *

-class MyDataParallel(dg.parallel.DataParallel):
-    """
-    A data parallel proxy for model.
-    """
-
-    def __init__(self, layers, strategy):
-        super(MyDataParallel, self).__init__(layers, strategy)
-
-    def __getattr__(self, key):
-        if key in self.__dict__:
-            return object.__getattribute__(self, key)
-        elif key is "_layers":
-            return object.__getattribute__(self, "_sub_layers")["_layers"]
-        else:
-            return getattr(
-                object.__getattribute__(self, "_sub_layers")["_layers"], key)

 def load_checkpoint(step, model_path):
     model_dict, opti_dict = fluid.dygraph.load_dygraph(os.path.join(model_path, step))
-    return model_dict, opti_dict
+    new_state_dict = OrderedDict()
+    for param in model_dict:
+        if param.startswith('_layers.'):
+            new_state_dict[param[8:]] = model_dict[param]
+        else:
+            new_state_dict[param] = model_dict[param]
+    return new_state_dict, opti_dict

 def main(cfg):
@@ -66,7 +56,7 @@ def main(cfg):

         if cfg.use_data_parallel:
             strategy = dg.parallel.prepare_context()
-            model = MyDataParallel(model, strategy)
+            model = fluid.dygraph.parallel.DataParallel(model, strategy)

         reader = LJSpeechLoader(cfg, nranks, local_rank, is_vocoder=True).reader()
@@ -1,46 +1,33 @@
 import os
 from tqdm import tqdm
-import paddle.fluid.dygraph as dg
-import paddle.fluid.layers as layers
-from network import *
 from tensorboardX import SummaryWriter
 from pathlib import Path
+from collections import OrderedDict
 import jsonargparse
 from parse import add_config_options_to_parser
 from pprint import pprint
 from matplotlib import cm
+import paddle.fluid.dygraph as dg
+import paddle.fluid.layers as layers
 from parakeet.modules.utils import cross_entropy
-from parakeet.models.dataloader.jlspeech import LJSpeechLoader
+from parakeet.models.dataloader.ljspeech import LJSpeechLoader
+from network import *

-class MyDataParallel(dg.parallel.DataParallel):
-    """
-    A data parallel proxy for model.
-    """
-
-    def __init__(self, layers, strategy):
-        super(MyDataParallel, self).__init__(layers, strategy)
-
-    def __getattr__(self, key):
-        if key in self.__dict__:
-            return object.__getattribute__(self, key)
-        elif key is "_layers":
-            return object.__getattribute__(self, "_sub_layers")["_layers"]
-        else:
-            return getattr(
-                object.__getattribute__(self, "_sub_layers")["_layers"], key)

 def load_checkpoint(step, model_path):
     model_dict, opti_dict = fluid.dygraph.load_dygraph(os.path.join(model_path, step))
-    return model_dict, opti_dict
+    new_state_dict = OrderedDict()
+    for param in model_dict:
+        if param.startswith('_layers.'):
+            new_state_dict[param[8:]] = model_dict[param]
+        else:
+            new_state_dict[param] = model_dict[param]
+    return new_state_dict, opti_dict


 def main(cfg):
     local_rank = dg.parallel.Env().local_rank if cfg.use_data_parallel else 0
     nranks = dg.parallel.Env().nranks if cfg.use_data_parallel else 1

-    fluid.default_startup_program().random_seed = 1
-    fluid.default_main_program().random_seed = 1

     if local_rank == 0:
         # Print the whole config setting.
         pprint(jsonargparse.namespace_to_dict(cfg))
@@ -74,28 +61,27 @@ def main(cfg):

         if cfg.use_data_parallel:
             strategy = dg.parallel.prepare_context()
-            model = MyDataParallel(model, strategy)
+            model = fluid.dygraph.parallel.DataParallel(model, strategy)

         for epoch in range(cfg.epochs):
             pbar = tqdm(reader)
             for i, data in enumerate(pbar):
                 pbar.set_description('Processing at epoch %d'%epoch)
-                character, mel, mel_input, pos_text, pos_mel, text_length = data
+                character, mel, mel_input, pos_text, pos_mel, text_length, _ = data

                 global_step += 1
                 mel_pred, postnet_pred, attn_probs, stop_preds, attn_enc, attn_dec = model(character, mel_input, pos_text, pos_mel)

                 label = (pos_mel == 0).astype(np.float32)
-                #label = np.zeros(stop_preds.shape).astype(np.float32)
-                #text_length = text_length.numpy()
-                #for i in range(label.shape[0]):
-                #    label[i][text_length[i] - 1] = 1

                 mel_loss = layers.mean(layers.abs(layers.elementwise_sub(mel_pred, mel)))
                 post_mel_loss = layers.mean(layers.abs(layers.elementwise_sub(postnet_pred, mel)))
+                loss = mel_loss + post_mel_loss
+                # Note: When used stop token loss the learning did not work.
+                if cfg.stop_token:
                     stop_loss = cross_entropy(stop_preds, label)
-                    loss = mel_loss + post_mel_loss + stop_loss
+                    loss = loss + stop_loss

                 if local_rank==0:
                     writer.add_scalars('training_loss', {
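Finally, the stop-token loss becomes opt-in behind the new stop_token flag (default False here and in the config above), with the in-code note recording that training degraded when it was enabled. The stop labels themselves are unchanged: a frame is a positive example exactly where pos_mel is padding:

    import numpy as np

    pos_mel = np.array([[1, 2, 3, 4, 0, 0]])   # frame positions; zeros mark padding
    label = (pos_mel == 0).astype(np.float32)  # [[0. 0. 0. 0. 1. 1.]] -> stop targets
    # loss = mel L1 + postnet L1, plus cross_entropy(stop_preds, label) only if cfg.stop_token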