completed fastspeech and modified save/load
This commit is contained in:
parent
f312b2f05c
commit
c1b837dc17
|
@ -0,0 +1,52 @@
|
|||
data:
  batch_size: 8
  train_clip_seconds: 0.5
  sample_rate: 22050
  hop_length: 256
  win_length: 1024
  n_fft: 2048

  n_mels: 80
  valid_size: 16


conditioner:
  upsampling_factors: [16, 16]

teacher:
  n_loop: 10
  n_layer: 3
  filter_size: 2
  residual_channels: 128
  loss_type: "mog"
  output_dim: 3
  log_scale_min: -9

student:
  n_loops: [10, 10, 10, 10, 10, 10]
  n_layers: [1, 1, 1, 1, 1, 1]
  filter_size: 3
  residual_channels: 64
  log_scale_min: -7

stft:
  n_fft: 2048
  win_length: 1024
  hop_length: 256

loss:
  lmd: 4

train:
  learning_rate: 0.0005
  anneal_rate: 0.5
  anneal_interval: 200000
  gradient_max_norm: 100.0

  checkpoint_interval: 1000
  eval_interval: 1000

  max_iterations: 2000000
|
@ -26,36 +26,69 @@ The model consists of encoder, decoder and length regulator three parts.
|
|||
├── train.py              # script for model training
```

## Train Transformer
## Saving & Loading
`train.py` has 3 arguments: `--checkpoint`, `--iteration` and `--output`.

1. `--output` is the directory for saving results.
During training, checkpoints are saved in `checkpoints/` in `output` and the tensorboard log is saved in `log/` in `output`.
During synthesis, results are saved in `samples/` in `output` and the tensorboard log is saved in `log/` in `output`.

2. `--checkpoint` and `--iteration` are used to load from an existing checkpoint. Loading an existing checkpoint follows these rules:
If `--checkpoint` is provided, the checkpoint specified by `--checkpoint` is loaded.
If `--checkpoint` is not provided, we try to load the checkpoint specified by `--iteration` from the checkpoint directory. If `--iteration` is not provided either, we try to load the latest checkpoint from the checkpoint directory.
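To make these rules concrete, here is a minimal, illustrative sketch of how a checkpoint path could be resolved. The helper name, the `step-<iteration>` prefix and the `.pdparams` suffix are assumptions based on the checkpoints referenced in this repository; this is not the actual implementation of `parakeet.utils.io.load_parameters`.

```python
import os
import glob

def resolve_checkpoint(output_dir, checkpoint=None, iteration=None):
    """Illustrative only: pick a checkpoint prefix following the rules above."""
    if checkpoint is not None:
        # 1) an explicit --checkpoint always wins
        return checkpoint
    checkpoint_dir = os.path.join(output_dir, "checkpoints")
    if iteration is not None:
        # 2) otherwise load the checkpoint named after --iteration
        return os.path.join(checkpoint_dir, "step-{}".format(iteration))
    # 3) otherwise fall back to the latest step-* checkpoint, if any
    candidates = glob.glob(os.path.join(checkpoint_dir, "step-*.pdparams"))
    if not candidates:
        return None
    steps = [int(os.path.basename(p)[len("step-"):-len(".pdparams")]) for p in candidates]
    return os.path.join(checkpoint_dir, "step-{}".format(max(steps)))
```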
## Compute Alignment

Before training the FastSpeech model, you need the alignment (diagonal attention) information between the text and the mel spectrogram. We use the alignments obtained from a pre-trained TransformerTTS model; you can run `alignments/get_alignments.py` to get them.
|
||||
|
||||
```bash
cd alignments
python get_alignments.py \
--use_gpu=1 \
--output='./alignments' \
--data=${DATAPATH} \
--config=${CONFIG} \
--checkpoint_transformer=${CHECKPOINT} \
```
|
||||
where `${DATAPATH}` is the path where the LJSpeech data is saved, `${CHECKPOINT}` is the path of the pre-trained TransformerTTS model, and `${CONFIG}` is the config yaml file of the TransformerTTS checkpoint. It is necessary for you to prepare a pre-trained TransformerTTS checkpoint.

For more help on arguments:
``python get_alignments.py --help``.

Or you can use your own alignment information; in that case, process the data into the following format:
|
||||
```python
{'fname1': alignment1,
 'fname2': alignment2,
 ...}
```
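For instance, if you computed the alignments with your own aligner, you could serialize them in this format as a pickle file and pass it via `--alignments_path`. This is a sketch: the file names are real LJSpeech identifiers, but the array contents below are placeholders, and each alignment array must match what `get_alignment` would otherwise produce.

```python
import pickle
from collections import OrderedDict

import numpy as np

alignments = OrderedDict()
# one entry per utterance: LJSpeech file name -> alignment array (placeholder values here)
alignments["LJ001-0001"] = np.array([3, 5, 4, 7], dtype=np.int64)
alignments["LJ001-0002"] = np.array([2, 6, 5, 3], dtype=np.int64)

# the FastSpeech data loader simply pickle.load()s this file
with open("alignments/alignments.txt", "wb") as f:
    pickle.dump(alignments, f)
```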
## Train FastSpeech
|
||||
|
||||
FastSpeech model can be trained with ``train.py``.
|
||||
```bash
python train.py \
--use_gpu=1 \
--use_data_parallel=0 \
--data_path=${DATAPATH} \
--transtts_path='../transformer_tts/checkpoint' \
--transformer_step=160000 \
--config_path='config/fastspeech.yaml' \
--data=${DATAPATH} \
--alignments_path=${ALIGNMENTS_PATH} \
--output='./experiment' \
--config='configs/ljspeech.yaml' \
```
|
||||
Or you can run the script file directly.
```bash
sh train.sh
```
If you want to train on multiple GPUs, you must set ``--use_data_parallel=1``, and then start training as follows:
If you want to train on multiple GPUs, start training as follows:
|
||||
|
||||
```bash
export CUDA_VISIBLE_DEVICES=0,1,2,3
python -m paddle.distributed.launch --selected_gpus=0,1,2,3 --log_dir ./mylog train.py \
--use_gpu=1 \
--use_data_parallel=1 \
--data_path=${DATAPATH} \
--transtts_path='../transformer_tts/checkpoint' \
--transformer_step=160000 \
--config_path='config/fastspeech.yaml' \
--data=${DATAPATH} \
--alignments_path=${ALIGNMENTS_PATH} \
--output='./experiment' \
--config='configs/ljspeech.yaml' \
```
|
||||
|
||||
If you wish to resume from an existing model, please set ``--checkpoint_path`` and ``--fastspeech_step``.
If you wish to resume from an existing model, see [Saving-&-Loading](#Saving-&-Loading) for details of checkpoint loading.

For more help on arguments:
``python train.py --help``.
|
||||
|
@ -66,9 +99,13 @@ After training the FastSpeech, audio can be synthesized with ``synthesis.py``.
|
|||
python synthesis.py \
--use_gpu=1 \
--alpha=1.0 \
--checkpoint_path='checkpoint/' \
--fastspeech_step=112000 \
--checkpoint='./checkpoint/fastspeech/step-120000' \
--config='configs/ljspeech.yaml' \
--config_clarinet='../clarinet/configs/config.yaml' \
--checkpoint_clarinet='../clarinet/checkpoint/step-500000' \
--output='./synthesis' \
```
|
||||
We use Clarinet to synthesize the waveform, so it is necessary for you to prepare a pre-trained [Clarinet checkpoint](https://paddlespeech.bj.bcebos.com/Parakeet/clarinet_ljspeech_ckpt_1.0.zip).
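For context, the mel spectrogram predicted by FastSpeech is not fed to Clarinet as-is; it is mapped back to linear scale and re-normalized first. The steps below are condensed from `synthesis.py` in this commit; the stand-in array and variable names are illustrative only.

```python
import numpy as np

# illustrative stand-in for the FastSpeech output: a (1, T, num_mels) log-mel array
mel_output_postnet = np.random.randn(1, 200, 80).astype(np.float32)

# FastSpeech is trained on natural-log mels, so map back to linear scale first
mel = np.exp(mel_output_postnet)

# then re-normalize to the [0, 1] range the Clarinet checkpoint expects
min_level, ref_level = 1e-5, 20                      # hard-coded in synthesis_with_clarinet
mel = 20 * np.log10(np.maximum(min_level, mel))      # amplitude -> dB
mel = np.clip((mel - ref_level + 100) / 100, 0, 1)

# this normalized mel is what conditions the Clarinet student via model.synthesis(...)
```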
Or you can run the script file directly.
|
||||
```bash
|
||||
|
|
|
@ -0,0 +1,142 @@
|
|||
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import os
|
||||
from scipy.io.wavfile import write
|
||||
from parakeet.g2p.en import text_to_sequence
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import csv
|
||||
from tqdm import tqdm
|
||||
from ruamel import yaml
|
||||
import pickle
|
||||
from pathlib import Path
|
||||
import argparse
|
||||
from pprint import pprint
|
||||
from collections import OrderedDict
|
||||
import paddle.fluid as fluid
|
||||
import paddle.fluid.dygraph as dg
|
||||
from parakeet.models.transformer_tts.utils import *
|
||||
from parakeet import audio
|
||||
from parakeet.models.transformer_tts import TransformerTTS
|
||||
from parakeet.models.fastspeech.utils import get_alignment
|
||||
from parakeet.utils import io
|
||||
|
||||
|
||||
def add_config_options_to_parser(parser):
|
||||
parser.add_argument("--config", type=str, help="path of the config file")
|
||||
parser.add_argument("--use_gpu", type=int, default=0, help="device to use")
|
||||
parser.add_argument("--data", type=str, help="path of LJspeech dataset")
|
||||
|
||||
parser.add_argument(
|
||||
"--checkpoint_transformer",
|
||||
type=str,
|
||||
help="transformer_tts checkpoint to synthesis")
|
||||
|
||||
parser.add_argument(
|
||||
"--output",
|
||||
type=str,
|
||||
default="./alignments",
|
||||
help="path to save experiment results")
|
||||
|
||||
|
||||
def alignments(args):
|
||||
local_rank = dg.parallel.Env().local_rank
|
||||
place = (fluid.CUDAPlace(local_rank) if args.use_gpu else fluid.CPUPlace())
|
||||
|
||||
with open(args.config) as f:
|
||||
cfg = yaml.load(f, Loader=yaml.Loader)
|
||||
|
||||
with dg.guard(place):
|
||||
network_cfg = cfg['network']
|
||||
model = TransformerTTS(
|
||||
network_cfg['embedding_size'], network_cfg['hidden_size'],
|
||||
network_cfg['encoder_num_head'], network_cfg['encoder_n_layers'],
|
||||
cfg['audio']['num_mels'], network_cfg['outputs_per_step'],
|
||||
network_cfg['decoder_num_head'], network_cfg['decoder_n_layers'])
|
||||
# Load parameters.
|
||||
global_step = io.load_parameters(
|
||||
model=model, checkpoint_path=args.checkpoint_transformer)
|
||||
model.eval()
|
||||
|
||||
# get text data
|
||||
root = Path(args.data)
|
||||
csv_path = root.joinpath("metadata.csv")
|
||||
table = pd.read_csv(
|
||||
csv_path,
|
||||
sep="|",
|
||||
header=None,
|
||||
quoting=csv.QUOTE_NONE,
|
||||
names=["fname", "raw_text", "normalized_text"])
|
||||
ljspeech_processor = audio.AudioProcessor(
|
||||
sample_rate=cfg['audio']['sr'],
|
||||
num_mels=cfg['audio']['num_mels'],
|
||||
min_level_db=cfg['audio']['min_level_db'],
|
||||
ref_level_db=cfg['audio']['ref_level_db'],
|
||||
n_fft=cfg['audio']['n_fft'],
|
||||
win_length=cfg['audio']['win_length'],
|
||||
hop_length=cfg['audio']['hop_length'],
|
||||
power=cfg['audio']['power'],
|
||||
preemphasis=cfg['audio']['preemphasis'],
|
||||
signal_norm=True,
|
||||
symmetric_norm=False,
|
||||
max_norm=1.,
|
||||
mel_fmin=0,
|
||||
mel_fmax=None,
|
||||
clip_norm=True,
|
||||
griffin_lim_iters=60,
|
||||
do_trim_silence=False,
|
||||
sound_norm=False)
|
||||
|
||||
pbar = tqdm(range(len(table)))
|
||||
alignments = OrderedDict()
|
||||
for i in pbar:
|
||||
fname, raw_text, normalized_text = table.iloc[i]
|
||||
# init input
|
||||
text = np.asarray(text_to_sequence(normalized_text))
|
||||
text = fluid.layers.unsqueeze(dg.to_variable(text), [0])
|
||||
pos_text = np.arange(1, text.shape[1] + 1)
|
||||
pos_text = fluid.layers.unsqueeze(dg.to_variable(pos_text), [0])
|
||||
wav = ljspeech_processor.load_wav(
|
||||
os.path.join(args.data, 'wavs', fname + ".wav"))
|
||||
mel_input = ljspeech_processor.melspectrogram(wav).astype(
|
||||
np.float32)
|
||||
mel_input = np.transpose(mel_input, axes=(1, 0))
|
||||
mel_input = fluid.layers.unsqueeze(dg.to_variable(mel_input), [0])
|
||||
mel_lens = mel_input.shape[1]
|
||||
|
||||
dec_slf_mask = get_triu_tensor(mel_input,
|
||||
mel_input).astype(np.float32)
|
||||
dec_slf_mask = np.expand_dims(dec_slf_mask, axis=0)
|
||||
dec_slf_mask = fluid.layers.cast(
|
||||
dg.to_variable(dec_slf_mask != 0), np.float32) * (-2**32 + 1)
|
||||
pos_mel = np.arange(1, mel_input.shape[1] + 1)
|
||||
pos_mel = fluid.layers.unsqueeze(dg.to_variable(pos_mel), [0])
|
||||
mel_pred, postnet_pred, attn_probs, stop_preds, attn_enc, attn_dec = model(
|
||||
text, mel_input, pos_text, pos_mel, dec_slf_mask)
|
||||
mel_input = fluid.layers.concat(
|
||||
[mel_input, postnet_pred[:, -1:, :]], axis=1)
|
||||
|
||||
alignment, _ = get_alignment(attn_probs, mel_lens,
|
||||
network_cfg['decoder_num_head'])
|
||||
alignments[fname] = alignment
|
||||
with open(args.output + '.txt', "wb") as f:
|
||||
pickle.dump(alignments, f)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Get alignments from TransformerTTS model")
|
||||
add_config_options_to_parser(parser)
|
||||
args = parser.parse_args()
|
||||
alignments(args)
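Once the script has run, the result can be sanity-checked with a few lines. This is a sketch; note that `get_alignments.py` appends `.txt` to the `--output` value, so the default output file is `./alignments.txt`, and the exact shape of each alignment array depends on `get_alignment`.

```python
import pickle

with open("alignments.txt", "rb") as f:
    alignments = pickle.load(f)

print(len(alignments))                            # number of LJSpeech utterances processed
fname, alignment = next(iter(alignments.items()))
print(fname, getattr(alignment, "shape", None))   # e.g. 'LJ001-0001' and its alignment shape
```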
@ -0,0 +1,14 @@
|
|||
|
||||
CUDA_VISIBLE_DEVICES=0 \
python -u get_alignments.py \
--use_gpu=1 \
--output='./alignments' \
--data='../../../dataset/LJSpeech-1.1' \
--config='../../transformer_tts/configs/ljspeech.yaml' \
--checkpoint_transformer='../../transformer_tts/checkpoint/transformer/step-120000' \

if [ $? -ne 0 ]; then
    echo "Failed in getting alignments!"
    exit 1
fi
exit 0
|
|
@ -1,32 +0,0 @@
|
|||
audio:
|
||||
num_mels: 80 #the number of mel bands when calculating mel spectrograms.
|
||||
n_fft: 2048 #the number of fft components.
|
||||
sr: 22050 #the sampling rate of audio data file.
|
||||
preemphasis: 0.97 #the preemphasis coefficient.
|
||||
hop_length: 256 #the number of samples to advance between frames.
|
||||
win_length: 1024 #the length (width) of the window function.
|
||||
power: 1.2 #the power to raise before griffin-lim.
|
||||
min_level_db: -100 #the minimum level db.
|
||||
ref_level_db: 20 #the reference level db.
|
||||
outputs_per_step: 1 #the outputs per step.
|
||||
|
||||
encoder_n_layer: 6 #the number of FFT Block in encoder.
|
||||
encoder_head: 2 #the attention head number in encoder.
|
||||
encoder_conv1d_filter_size: 1536 #the filter size of conv1d in encoder.
|
||||
max_seq_len: 2048 #the max length of sequence.
|
||||
decoder_n_layer: 6 #the number of FFT Block in decoder.
|
||||
decoder_head: 2 #the attention head number in decoder.
|
||||
decoder_conv1d_filter_size: 1536 #the filter size of conv1d in decoder.
|
||||
fs_hidden_size: 384 #the hidden size in model of fastspeech.
|
||||
duration_predictor_output_size: 256 #the output size of duration predictior.
|
||||
duration_predictor_filter_size: 3 #the filter size of conv1d in duration prediction.
|
||||
fft_conv1d_filter: 3 #the filter size of conv1d in fft.
|
||||
fft_conv1d_padding: 1 #the padding size of conv1d in fft.
|
||||
dropout: 0.1 #the dropout in network.
|
||||
transformer_head: 4 #the attention head num of transformerTTS.
|
||||
|
||||
embedding_size: 512 #the dim size of embedding of transformerTTS.
|
||||
hidden_size: 256 #the hidden size in model of transformerTTS.
|
||||
warm_up_step: 4000 #the warm up step of learning rate.
|
||||
grad_clip_thresh: 0.1 #the threshold of grad clip.
|
||||
|
|
@ -0,0 +1,33 @@
|
|||
audio:
  num_mels: 80                          #the number of mel bands when calculating mel spectrograms.
  n_fft: 2048                           #the number of fft components.
  sr: 22050                             #the sampling rate of audio data file.
  hop_length: 256                       #the number of samples to advance between frames.
  win_length: 1024                      #the length (width) of the window function.
  power: 1.2                            #the power to raise before griffin-lim.

network:
  encoder_n_layer: 6                    #the number of FFT Block in encoder.
  encoder_head: 2                       #the attention head number in encoder.
  encoder_conv1d_filter_size: 1536      #the filter size of conv1d in encoder.
  max_seq_len: 2048                     #the max length of sequence.
  decoder_n_layer: 6                    #the number of FFT Block in decoder.
  decoder_head: 2                       #the attention head number in decoder.
  decoder_conv1d_filter_size: 1536      #the filter size of conv1d in decoder.
  hidden_size: 384                      #the hidden size in model of fastspeech.
  duration_predictor_output_size: 256   #the output size of duration predictor.
  duration_predictor_filter_size: 3     #the filter size of conv1d in duration prediction.
  fft_conv1d_filter: 3                  #the filter size of conv1d in fft.
  fft_conv1d_padding: 1                 #the padding size of conv1d in fft.
  dropout: 0.1                          #the dropout in network.
  outputs_per_step: 1

train:
  batch_size: 32
  learning_rate: 0.001
  warm_up_step: 4000                    #the warm up step of learning rate.
  grad_clip_thresh: 0.1                 #the threshold of grad clip.

  checkpoint_interval: 1000
  max_epochs: 10000
@ -1,26 +0,0 @@
|
|||
audio:
|
||||
num_mels: 80
|
||||
n_fft: 2048
|
||||
sr: 22050
|
||||
preemphasis: 0.97
|
||||
hop_length: 256
|
||||
win_length: 1024
|
||||
power: 1.2
|
||||
min_level_db: -100
|
||||
ref_level_db: 20
|
||||
outputs_per_step: 1
|
||||
|
||||
encoder_n_layer: 6
|
||||
encoder_head: 2
|
||||
encoder_conv1d_filter_size: 1536
|
||||
max_seq_len: 2048
|
||||
decoder_n_layer: 6
|
||||
decoder_head: 2
|
||||
decoder_conv1d_filter_size: 1536
|
||||
fs_hidden_size: 384
|
||||
duration_predictor_output_size: 256
|
||||
duration_predictor_filter_size: 3
|
||||
fft_conv1d_filter: 3
|
||||
fft_conv1d_padding: 1
|
||||
dropout: 0.1
|
||||
transformer_head: 4
|
|
@ -0,0 +1,195 @@
|
|||
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from pathlib import Path
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import librosa
|
||||
import csv
|
||||
import pickle
|
||||
|
||||
from paddle import fluid
|
||||
from parakeet import g2p
|
||||
from parakeet import audio
|
||||
from parakeet.data.sampler import *
|
||||
from parakeet.data.datacargo import DataCargo
|
||||
from parakeet.data.batch import TextIDBatcher, SpecBatcher
|
||||
from parakeet.data.dataset import DatasetMixin, TransformDataset, CacheDataset, SliceDataset
|
||||
from parakeet.models.transformer_tts.utils import *
|
||||
|
||||
|
||||
class LJSpeechLoader:
|
||||
def __init__(self,
|
||||
config,
|
||||
place,
|
||||
data_path,
|
||||
alignments_path,
|
||||
batch_size,
|
||||
nranks,
|
||||
rank,
|
||||
is_vocoder=False,
|
||||
shuffle=True):
|
||||
|
||||
LJSPEECH_ROOT = Path(data_path)
|
||||
metadata = LJSpeechMetaData(LJSPEECH_ROOT, alignments_path)
|
||||
transformer = LJSpeech(
|
||||
sr=config['sr'],
|
||||
n_fft=config['n_fft'],
|
||||
num_mels=config['num_mels'],
|
||||
win_length=config['win_length'],
|
||||
hop_length=config['hop_length'])
|
||||
dataset = TransformDataset(metadata, transformer)
|
||||
dataset = CacheDataset(dataset)
|
||||
|
||||
sampler = DistributedSampler(
|
||||
len(dataset), nranks, rank, shuffle=shuffle)
|
||||
|
||||
assert batch_size % nranks == 0
|
||||
each_bs = batch_size // nranks
|
||||
dataloader = DataCargo(
|
||||
dataset,
|
||||
sampler=sampler,
|
||||
batch_size=each_bs,
|
||||
shuffle=shuffle,
|
||||
batch_fn=batch_examples,
|
||||
drop_last=True)
|
||||
self.reader = fluid.io.DataLoader.from_generator(
|
||||
capacity=32,
|
||||
iterable=True,
|
||||
use_double_buffer=True,
|
||||
return_list=True)
|
||||
self.reader.set_batch_generator(dataloader, place)
|
||||
|
||||
|
||||
class LJSpeechMetaData(DatasetMixin):
|
||||
def __init__(self, root, alignments_path):
|
||||
self.root = Path(root)
|
||||
self._wav_dir = self.root.joinpath("wavs")
|
||||
csv_path = self.root.joinpath("metadata.csv")
|
||||
self._table = pd.read_csv(
|
||||
csv_path,
|
||||
sep="|",
|
||||
header=None,
|
||||
quoting=csv.QUOTE_NONE,
|
||||
names=["fname", "raw_text", "normalized_text"])
|
||||
with open(alignments_path, "rb") as f:
|
||||
self._alignments = pickle.load(f)
|
||||
|
||||
def get_example(self, i):
|
||||
fname, raw_text, normalized_text = self._table.iloc[i]
|
||||
alignment = self._alignments[fname]
|
||||
fname = str(self._wav_dir.joinpath(fname + ".wav"))
|
||||
return fname, normalized_text, alignment
|
||||
|
||||
def __len__(self):
|
||||
return len(self._table)
|
||||
|
||||
|
||||
class LJSpeech(object):
|
||||
def __init__(self,
|
||||
sr=22050,
|
||||
n_fft=2048,
|
||||
num_mels=80,
|
||||
win_length=1024,
|
||||
hop_length=256):
|
||||
super(LJSpeech, self).__init__()
|
||||
self.sr = sr
|
||||
self.n_fft = n_fft
|
||||
self.num_mels = num_mels
|
||||
self.win_length = win_length
|
||||
self.hop_length = hop_length
|
||||
|
||||
def __call__(self, metadatum):
|
||||
"""All the code for generating an Example from a metadatum. If you want a
|
||||
different preprocessing pipeline, you can override this method.
|
||||
This method may require several processor, each of which has a lot of options.
|
||||
In this case, you'd better pass a composed transform and pass it to the init
|
||||
method.
|
||||
"""
|
||||
fname, normalized_text, alignment = metadatum
|
||||
|
||||
wav, _ = librosa.load(str(fname))
|
||||
spec = librosa.stft(
|
||||
y=wav,
|
||||
n_fft=self.n_fft,
|
||||
win_length=self.win_length,
|
||||
hop_length=self.hop_length)
|
||||
mag = np.abs(spec)
|
||||
mel = librosa.filters.mel(self.sr, self.n_fft, n_mels=self.num_mels)
|
||||
mel = np.matmul(mel, mag)
|
||||
mel = np.log(np.maximum(mel, 1e-5))
|
||||
phonemes = np.array(
|
||||
g2p.en.text_to_sequence(normalized_text), dtype=np.int64)
|
||||
return (mel, phonemes, alignment
|
||||
) # maybe we need to implement it as a map in the future
|
||||
|
||||
|
||||
def batch_examples(batch):
|
||||
texts = []
|
||||
mels = []
|
||||
text_lens = []
|
||||
pos_texts = []
|
||||
pos_mels = []
|
||||
alignments = []
|
||||
for data in batch:
|
||||
mel, text, alignment = data
|
||||
text_lens.append(len(text))
|
||||
pos_texts.append(np.arange(1, len(text) + 1))
|
||||
pos_mels.append(np.arange(1, mel.shape[1] + 1))
|
||||
mels.append(mel)
|
||||
texts.append(text)
|
||||
alignments.append(alignment)
|
||||
|
||||
# Sort by text_len in descending order
|
||||
texts = [
|
||||
i
|
||||
for i, _ in sorted(
|
||||
zip(texts, text_lens), key=lambda x: x[1], reverse=True)
|
||||
]
|
||||
mels = [
|
||||
i
|
||||
for i, _ in sorted(
|
||||
zip(mels, text_lens), key=lambda x: x[1], reverse=True)
|
||||
]
|
||||
pos_texts = [
|
||||
i
|
||||
for i, _ in sorted(
|
||||
zip(pos_texts, text_lens), key=lambda x: x[1], reverse=True)
|
||||
]
|
||||
pos_mels = [
|
||||
i
|
||||
for i, _ in sorted(
|
||||
zip(pos_mels, text_lens), key=lambda x: x[1], reverse=True)
|
||||
]
|
||||
alignments = [
|
||||
i
|
||||
for i, _ in sorted(
|
||||
zip(alignments, text_lens), key=lambda x: x[1], reverse=True)
|
||||
]
|
||||
#text_lens = sorted(text_lens, reverse=True)
|
||||
|
||||
# Pad sequence with largest len of the batch
|
||||
texts = TextIDBatcher(pad_id=0)(texts) #(B, T)
|
||||
pos_texts = TextIDBatcher(pad_id=0)(pos_texts) #(B,T)
|
||||
pos_mels = TextIDBatcher(pad_id=0)(pos_mels) #(B,T)
|
||||
alignments = TextIDBatcher(pad_id=0)(alignments).astype(np.float32)
|
||||
mels = np.transpose(
|
||||
SpecBatcher(pad_value=0.)(mels), axes=(0, 2, 1)) #(B,T,num_mels)
|
||||
|
||||
enc_slf_mask = get_attn_key_pad_mask(pos_texts).astype(np.float32)
|
||||
enc_query_mask = get_non_pad_mask(pos_texts).astype(np.float32)
|
||||
dec_slf_mask = get_dec_attn_key_pad_mask(pos_mels, mels).astype(np.float32)
|
||||
dec_query_slf_mask = get_non_pad_mask(pos_mels).astype(np.float32)
|
||||
|
||||
return (texts, mels, pos_texts, pos_mels, enc_slf_mask, enc_query_mask,
|
||||
dec_slf_mask, dec_query_slf_mask, alignments)
|
@ -1,96 +0,0 @@
|
|||
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import argparse
|
||||
|
||||
|
||||
def add_config_options_to_parser(parser):
|
||||
parser.add_argument(
|
||||
'--config_path',
|
||||
type=str,
|
||||
default='configs/fastspeech.yaml',
|
||||
help="the yaml config file path.")
|
||||
parser.add_argument(
|
||||
'--batch_size', type=int, default=32, help="batch size for training.")
|
||||
parser.add_argument(
|
||||
'--epochs',
|
||||
type=int,
|
||||
default=10000,
|
||||
help="the number of epoch for training.")
|
||||
parser.add_argument(
|
||||
'--lr',
|
||||
type=float,
|
||||
default=0.001,
|
||||
help="the learning rate for training.")
|
||||
parser.add_argument(
|
||||
'--save_step',
|
||||
type=int,
|
||||
default=500,
|
||||
help="checkpointing interval during training.")
|
||||
parser.add_argument(
|
||||
'--fastspeech_step',
|
||||
type=int,
|
||||
default=70000,
|
||||
help="Global step to restore checkpoint of fastspeech.")
|
||||
parser.add_argument(
|
||||
'--use_gpu',
|
||||
type=int,
|
||||
default=1,
|
||||
help="use gpu or not during training.")
|
||||
parser.add_argument(
|
||||
'--use_data_parallel',
|
||||
type=int,
|
||||
default=0,
|
||||
help="use data parallel or not during training.")
|
||||
parser.add_argument(
|
||||
'--alpha',
|
||||
type=float,
|
||||
default=1.0,
|
||||
help="The hyperparameter to determine the length of the expanded sequence \
|
||||
mel, thereby controlling the voice speed.")
|
||||
|
||||
parser.add_argument(
|
||||
'--data_path',
|
||||
type=str,
|
||||
default='./dataset/LJSpeech-1.1',
|
||||
help="the path of dataset.")
|
||||
parser.add_argument(
|
||||
'--checkpoint_path',
|
||||
type=str,
|
||||
default=None,
|
||||
help="the path to load checkpoint or pretrain model.")
|
||||
parser.add_argument(
|
||||
'--save_path',
|
||||
type=str,
|
||||
default='./checkpoint',
|
||||
help="the path to save checkpoint.")
|
||||
parser.add_argument(
|
||||
'--log_dir',
|
||||
type=str,
|
||||
default='./log',
|
||||
help="the directory to save tensorboard log.")
|
||||
parser.add_argument(
|
||||
'--sample_path',
|
||||
type=str,
|
||||
default='./sample',
|
||||
help="the directory to save audio sample in synthesis.")
|
||||
parser.add_argument(
|
||||
'--transtts_path',
|
||||
type=str,
|
||||
default='../transformer_tts/checkpoint',
|
||||
help="the directory to load pretrain transformerTTS model.")
|
||||
parser.add_argument(
|
||||
'--transformer_step',
|
||||
type=int,
|
||||
default=160000,
|
||||
help="the step to load transformerTTS model.")
|
|
@ -13,9 +13,9 @@
|
|||
# limitations under the License.
|
||||
import os
|
||||
from tensorboardX import SummaryWriter
|
||||
from scipy.io.wavfile import write
|
||||
from collections import OrderedDict
|
||||
import argparse
|
||||
from parse import add_config_options_to_parser
|
||||
from pprint import pprint
|
||||
from ruamel import yaml
|
||||
from matplotlib import cm
|
||||
|
@ -26,38 +26,56 @@ from parakeet.g2p.en import text_to_sequence
|
|||
from parakeet import audio
|
||||
from parakeet.models.fastspeech.fastspeech import FastSpeech
|
||||
from parakeet.models.transformer_tts.utils import *
|
||||
from parakeet.models.wavenet import WaveNet, UpsampleNet
|
||||
from parakeet.models.clarinet import STFT, Clarinet, ParallelWaveNet
|
||||
from parakeet.utils.layer_tools import summary, freeze
|
||||
from parakeet.utils import io
|
||||
|
||||
|
||||
def load_checkpoint(step, model_path):
|
||||
model_dict, _ = fluid.dygraph.load_dygraph(os.path.join(model_path, step))
|
||||
new_state_dict = OrderedDict()
|
||||
for param in model_dict:
|
||||
if param.startswith('_layers.'):
|
||||
new_state_dict[param[8:]] = model_dict[param]
|
||||
else:
|
||||
new_state_dict[param] = model_dict[param]
|
||||
return new_state_dict
|
||||
def add_config_options_to_parser(parser):
|
||||
parser.add_argument("--config", type=str, help="path of the config file")
|
||||
parser.add_argument(
|
||||
"--config_clarinet", type=str, help="path of the clarinet config file")
|
||||
parser.add_argument("--use_gpu", type=int, default=0, help="device to use")
|
||||
parser.add_argument(
|
||||
"--alpha",
|
||||
type=float,
|
||||
default=1,
|
||||
help="determine the length of the expanded sequence mel, controlling the voice speed."
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--checkpoint", type=str, help="fastspeech checkpoint to synthesis")
|
||||
parser.add_argument(
|
||||
"--checkpoint_clarinet",
|
||||
type=str,
|
||||
help="clarinet checkpoint to synthesis")
|
||||
|
||||
parser.add_argument(
|
||||
"--output",
|
||||
type=str,
|
||||
default="synthesis",
|
||||
help="path to save experiment results")
|
||||
|
||||
|
||||
def synthesis(text_input, args):
|
||||
place = (fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace())
|
||||
local_rank = dg.parallel.Env().local_rank
|
||||
place = (fluid.CUDAPlace(local_rank) if args.use_gpu else fluid.CPUPlace())
|
||||
|
||||
# tensorboard
|
||||
if not os.path.exists(args.log_dir):
|
||||
os.mkdir(args.log_dir)
|
||||
path = os.path.join(args.log_dir, 'synthesis')
|
||||
|
||||
with open(args.config_path) as f:
|
||||
with open(args.config) as f:
|
||||
cfg = yaml.load(f, Loader=yaml.Loader)
|
||||
|
||||
writer = SummaryWriter(path)
|
||||
# tensorboard
|
||||
if not os.path.exists(args.output):
|
||||
os.mkdir(args.output)
|
||||
|
||||
writer = SummaryWriter(os.path.join(args.output, 'log'))
|
||||
|
||||
with dg.guard(place):
|
||||
model = FastSpeech(cfg)
|
||||
model.set_dict(
|
||||
load_checkpoint(
|
||||
str(args.fastspeech_step),
|
||||
os.path.join(args.checkpoint_path, "fastspeech")))
|
||||
model = FastSpeech(cfg['network'], num_mels=cfg['audio']['num_mels'])
|
||||
# Load parameters.
|
||||
global_step = io.load_parameters(
|
||||
model=model, checkpoint_path=args.checkpoint)
|
||||
model.eval()
|
||||
|
||||
text = np.asarray(text_to_sequence(text_input))
|
||||
|
@ -72,7 +90,7 @@ def synthesis(text_input, args):
|
|||
enc_non_pad_mask = dg.to_variable(enc_non_pad_mask)
|
||||
enc_slf_attn_mask = dg.to_variable(enc_slf_attn_mask)
|
||||
|
||||
mel_output, mel_output_postnet = model(
|
||||
_, mel_output_postnet = model(
|
||||
text,
|
||||
pos_text,
|
||||
alpha=args.alpha,
|
||||
|
@ -81,47 +99,119 @@ def synthesis(text_input, args):
|
|||
dec_non_pad_mask=None,
|
||||
dec_slf_attn_mask=None)
|
||||
|
||||
_ljspeech_processor = audio.AudioProcessor(
|
||||
sample_rate=cfg['audio']['sr'],
|
||||
num_mels=cfg['audio']['num_mels'],
|
||||
min_level_db=cfg['audio']['min_level_db'],
|
||||
ref_level_db=cfg['audio']['ref_level_db'],
|
||||
n_fft=cfg['audio']['n_fft'],
|
||||
win_length=cfg['audio']['win_length'],
|
||||
hop_length=cfg['audio']['hop_length'],
|
||||
power=cfg['audio']['power'],
|
||||
preemphasis=cfg['audio']['preemphasis'],
|
||||
signal_norm=True,
|
||||
symmetric_norm=False,
|
||||
max_norm=1.,
|
||||
mel_fmin=0,
|
||||
mel_fmax=None,
|
||||
clip_norm=True,
|
||||
griffin_lim_iters=60,
|
||||
do_trim_silence=False,
|
||||
sound_norm=False)
|
||||
|
||||
np.save('mel_output', mel_output_postnet.numpy())
|
||||
result = np.exp(mel_output_postnet.numpy())
|
||||
mel_output_postnet = fluid.layers.transpose(
|
||||
fluid.layers.squeeze(mel_output_postnet, [0]), [1, 0])
|
||||
x = np.uint8(cm.viridis(mel_output_postnet.numpy()) * 255)
|
||||
writer.add_image('mel_0_0', x, 0, dataformats="HWC")
|
||||
ground_truth = _ljspeech_processor.load_wav(
|
||||
str('/paddle/Parakeet/dataset/LJSpeech-1.1/wavs/LJ001-0175.wav'))
|
||||
ground_truth = _ljspeech_processor.melspectrogram(ground_truth).astype(
|
||||
np.float32)
|
||||
x = np.uint8(cm.viridis(ground_truth) * 255)
|
||||
writer.add_image('mel_gt_0', x, 0, dataformats="HWC")
|
||||
wav = _ljspeech_processor.inv_melspectrogram(mel_output_postnet.numpy(
|
||||
))
|
||||
writer.add_audio(text_input, wav, 0, cfg['audio']['sr'])
|
||||
mel_output_postnet = np.exp(mel_output_postnet.numpy())
|
||||
basis = librosa.filters.mel(cfg['audio']['sr'], cfg['audio']['n_fft'],
|
||||
cfg['audio']['num_mels'])
|
||||
inv_basis = np.linalg.pinv(basis)
|
||||
spec = np.maximum(1e-10, np.dot(inv_basis, mel_output_postnet))
|
||||
|
||||
# synthesis use clarinet
|
||||
wav_clarinet = synthesis_with_clarinet(
|
||||
args.config_clarinet, args.checkpoint_clarinet, result, place)
|
||||
writer.add_audio(text_input + '(clarinet)', wav_clarinet, 0,
|
||||
cfg['audio']['sr'])
|
||||
if not os.path.exists(os.path.join(args.output, 'samples')):
|
||||
os.mkdir(os.path.join(args.output, 'samples'))
|
||||
write(
|
||||
os.path.join(
|
||||
os.path.join(args.output, 'samples'), 'clarinet.wav'),
|
||||
cfg['audio']['sr'], wav_clarinet)
|
||||
|
||||
#synthesis use griffin-lim
|
||||
wav = librosa.core.griffinlim(
|
||||
spec**cfg['audio']['power'],
|
||||
hop_length=cfg['audio']['hop_length'],
|
||||
win_length=cfg['audio']['win_length'])
|
||||
writer.add_audio(text_input + '(griffin-lim)', wav, 0,
|
||||
cfg['audio']['sr'])
|
||||
write(
|
||||
os.path.join(
|
||||
os.path.join(args.output, 'samples'), 'griffin-lim.wav'),
|
||||
cfg['audio']['sr'], wav)
|
||||
print("Synthesis completed !!!")
|
||||
writer.close()
|
||||
|
||||
|
||||
def synthesis_with_clarinet(config_path, checkpoint, mel_spectrogram, place):
|
||||
with open(config_path, 'rt') as f:
|
||||
config = yaml.safe_load(f)
|
||||
|
||||
data_config = config["data"]
|
||||
n_mels = data_config["n_mels"]
|
||||
|
||||
teacher_config = config["teacher"]
|
||||
n_loop = teacher_config["n_loop"]
|
||||
n_layer = teacher_config["n_layer"]
|
||||
filter_size = teacher_config["filter_size"]
|
||||
|
||||
# only batch=1 for validation is enabled
|
||||
|
||||
with dg.guard(place):
|
||||
# conditioner(upsampling net)
|
||||
conditioner_config = config["conditioner"]
|
||||
upsampling_factors = conditioner_config["upsampling_factors"]
|
||||
upsample_net = UpsampleNet(upscale_factors=upsampling_factors)
|
||||
freeze(upsample_net)
|
||||
|
||||
residual_channels = teacher_config["residual_channels"]
|
||||
loss_type = teacher_config["loss_type"]
|
||||
output_dim = teacher_config["output_dim"]
|
||||
log_scale_min = teacher_config["log_scale_min"]
|
||||
assert loss_type == "mog" and output_dim == 3, \
|
||||
"the teacher wavenet should be a wavenet with single gaussian output"
|
||||
|
||||
teacher = WaveNet(n_loop, n_layer, residual_channels, output_dim,
|
||||
n_mels, filter_size, loss_type, log_scale_min)
|
||||
# load & freeze upsample_net & teacher
|
||||
freeze(teacher)
|
||||
|
||||
student_config = config["student"]
|
||||
n_loops = student_config["n_loops"]
|
||||
n_layers = student_config["n_layers"]
|
||||
student_residual_channels = student_config["residual_channels"]
|
||||
student_filter_size = student_config["filter_size"]
|
||||
student_log_scale_min = student_config["log_scale_min"]
|
||||
student = ParallelWaveNet(n_loops, n_layers, student_residual_channels,
|
||||
n_mels, student_filter_size)
|
||||
|
||||
stft_config = config["stft"]
|
||||
stft = STFT(
|
||||
n_fft=stft_config["n_fft"],
|
||||
hop_length=stft_config["hop_length"],
|
||||
win_length=stft_config["win_length"])
|
||||
|
||||
lmd = config["loss"]["lmd"]
|
||||
model = Clarinet(upsample_net, teacher, student, stft,
|
||||
student_log_scale_min, lmd)
|
||||
summary(model)
|
||||
io.load_parameters(model=model, checkpoint_path=checkpoint)
|
||||
|
||||
if not os.path.exists(args.output):
|
||||
os.makedirs(args.output)
|
||||
model.eval()
|
||||
|
||||
# Rescale mel_spectrogram.
|
||||
min_level, ref_level = 1e-5, 20 # hard code it
|
||||
mel_spectrogram = 20 * np.log10(np.maximum(min_level, mel_spectrogram))
|
||||
mel_spectrogram = mel_spectrogram - ref_level
|
||||
mel_spectrogram = np.clip((mel_spectrogram + 100) / 100, 0, 1)
|
||||
|
||||
mel_spectrogram = dg.to_variable(mel_spectrogram)
|
||||
mel_spectrogram = fluid.layers.transpose(mel_spectrogram, [0, 2, 1])
|
||||
|
||||
wav_var = model.synthesis(mel_spectrogram)
|
||||
wav_np = wav_var.numpy()[0]
|
||||
|
||||
return wav_np
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser(description="Train Fastspeech model")
|
||||
parser = argparse.ArgumentParser(description="Synthesis model")
|
||||
add_config_options_to_parser(parser)
|
||||
args = parser.parse_args()
|
||||
pprint(vars(args))
|
||||
synthesis("Simple as this proposition is, it is necessary to be stated,",
|
||||
args)
|
||||
|
|
|
@ -3,10 +3,11 @@
|
|||
python -u synthesis.py \
--use_gpu=1 \
--alpha=1.0 \
--checkpoint_path='checkpoint/' \
--fastspeech_step=89000 \
--log_dir='./log' \
--config_path='configs/synthesis.yaml' \
--checkpoint='./checkpoint/fastspeech/step-120000' \
--config='configs/ljspeech.yaml' \
--config_clarinet='../clarinet/configs/config.yaml' \
--checkpoint_clarinet='../clarinet/checkpoint/step-500000' \
--output='./synthesis' \

if [ $? -ne 0 ]; then
    echo "Failed in synthesis!"
|
||||
|
|
|
@ -17,7 +17,6 @@ import os
|
|||
import time
|
||||
import math
|
||||
from pathlib import Path
|
||||
from parse import add_config_options_to_parser
|
||||
from pprint import pprint
|
||||
from ruamel import yaml
|
||||
from tqdm import tqdm
|
||||
|
@ -27,107 +26,91 @@ from tensorboardX import SummaryWriter
|
|||
import paddle.fluid.dygraph as dg
|
||||
import paddle.fluid.layers as layers
|
||||
import paddle.fluid as fluid
|
||||
from parakeet.models.transformer_tts.transformer_tts import TransformerTTS
|
||||
from parakeet.models.fastspeech.fastspeech import FastSpeech
|
||||
from parakeet.models.fastspeech.utils import get_alignment
|
||||
import sys
|
||||
sys.path.append("../transformer_tts")
|
||||
from data import LJSpeechLoader
|
||||
from parakeet.utils import io
|
||||
|
||||
|
||||
def load_checkpoint(step, model_path):
|
||||
model_dict, opti_dict = fluid.dygraph.load_dygraph(
|
||||
os.path.join(model_path, step))
|
||||
new_state_dict = OrderedDict()
|
||||
for param in model_dict:
|
||||
if param.startswith('_layers.'):
|
||||
new_state_dict[param[8:]] = model_dict[param]
|
||||
else:
|
||||
new_state_dict[param] = model_dict[param]
|
||||
return new_state_dict, opti_dict
|
||||
def add_config_options_to_parser(parser):
|
||||
parser.add_argument("--config", type=str, help="path of the config file")
|
||||
parser.add_argument("--use_gpu", type=int, default=0, help="device to use")
|
||||
parser.add_argument("--data", type=str, help="path of LJspeech dataset")
|
||||
parser.add_argument(
|
||||
"--alignments_path", type=str, help="path of alignments")
|
||||
|
||||
g = parser.add_mutually_exclusive_group()
|
||||
g.add_argument("--checkpoint", type=str, help="checkpoint to resume from")
|
||||
g.add_argument(
|
||||
"--iteration",
|
||||
type=int,
|
||||
help="the iteration of the checkpoint to load from output directory")
|
||||
|
||||
parser.add_argument(
|
||||
"--output",
|
||||
type=str,
|
||||
default="experiment",
|
||||
help="path to save experiment results")
|
||||
|
||||
|
||||
def main(args):
|
||||
local_rank = dg.parallel.Env().local_rank if args.use_data_parallel else 0
|
||||
nranks = dg.parallel.Env().nranks if args.use_data_parallel else 1
|
||||
local_rank = dg.parallel.Env().local_rank
|
||||
nranks = dg.parallel.Env().nranks
|
||||
parallel = nranks > 1
|
||||
|
||||
with open(args.config_path) as f:
|
||||
with open(args.config) as f:
|
||||
cfg = yaml.load(f, Loader=yaml.Loader)
|
||||
|
||||
global_step = 0
|
||||
place = (fluid.CUDAPlace(dg.parallel.Env().dev_id)
|
||||
if args.use_data_parallel else fluid.CUDAPlace(0)
|
||||
if args.use_gpu else fluid.CPUPlace())
|
||||
place = fluid.CUDAPlace(local_rank) if args.use_gpu else fluid.CPUPlace()
|
||||
|
||||
if not os.path.exists(args.log_dir):
|
||||
os.mkdir(args.log_dir)
|
||||
path = os.path.join(args.log_dir, 'fastspeech')
|
||||
if not os.path.exists(args.output):
|
||||
os.mkdir(args.output)
|
||||
|
||||
writer = SummaryWriter(path) if local_rank == 0 else None
|
||||
writer = SummaryWriter(os.path.join(args.output,
|
||||
'log')) if local_rank == 0 else None
|
||||
|
||||
with dg.guard(place):
|
||||
with fluid.unique_name.guard():
|
||||
transformer_tts = TransformerTTS(cfg)
|
||||
model_dict, _ = load_checkpoint(
|
||||
str(args.transformer_step),
|
||||
os.path.join(args.transtts_path, "transformer"))
|
||||
transformer_tts.set_dict(model_dict)
|
||||
transformer_tts.eval()
|
||||
|
||||
model = FastSpeech(cfg)
|
||||
model = FastSpeech(cfg['network'], num_mels=cfg['audio']['num_mels'])
|
||||
model.train()
|
||||
optimizer = fluid.optimizer.AdamOptimizer(
|
||||
learning_rate=dg.NoamDecay(1 / (
|
||||
cfg['warm_up_step'] * (args.lr**2)), cfg['warm_up_step']),
|
||||
learning_rate=dg.NoamDecay(1 /
|
||||
(cfg['train']['warm_up_step'] *
|
||||
(cfg['train']['learning_rate']**2)),
|
||||
cfg['train']['warm_up_step']),
|
||||
parameter_list=model.parameters())
|
||||
reader = LJSpeechLoader(
|
||||
cfg, args, nranks, local_rank, shuffle=True).reader()
|
||||
cfg['audio'],
|
||||
place,
|
||||
args.data,
|
||||
args.alignments_path,
|
||||
cfg['train']['batch_size'],
|
||||
nranks,
|
||||
local_rank,
|
||||
shuffle=True).reader()
|
||||
|
||||
if args.checkpoint_path is not None:
|
||||
model_dict, opti_dict = load_checkpoint(
|
||||
str(args.fastspeech_step),
|
||||
os.path.join(args.checkpoint_path, "fastspeech"))
|
||||
model.set_dict(model_dict)
|
||||
optimizer.set_dict(opti_dict)
|
||||
global_step = args.fastspeech_step
|
||||
print("load checkpoint!!!")
|
||||
# Load parameters.
|
||||
global_step = io.load_parameters(
|
||||
model=model,
|
||||
optimizer=optimizer,
|
||||
checkpoint_dir=os.path.join(args.output, 'checkpoints'),
|
||||
iteration=args.iteration,
|
||||
checkpoint_path=args.checkpoint)
|
||||
print("Rank {}: checkpoint loaded.".format(local_rank))
|
||||
|
||||
if args.use_data_parallel:
|
||||
if parallel:
|
||||
strategy = dg.parallel.prepare_context()
|
||||
model = fluid.dygraph.parallel.DataParallel(model, strategy)
|
||||
|
||||
for epoch in range(args.epochs):
|
||||
for epoch in range(cfg['train']['max_epochs']):
|
||||
pbar = tqdm(reader)
|
||||
|
||||
for i, data in enumerate(pbar):
|
||||
pbar.set_description('Processing at epoch %d' % epoch)
|
||||
(character, mel, mel_input, pos_text, pos_mel, text_length,
|
||||
mel_lens, enc_slf_mask, enc_query_mask, dec_slf_mask,
|
||||
enc_dec_mask, dec_query_slf_mask, dec_query_mask) = data
|
||||
|
||||
_, _, attn_probs, _, _, _ = transformer_tts(
|
||||
character,
|
||||
mel_input,
|
||||
pos_text,
|
||||
pos_mel,
|
||||
dec_slf_mask=dec_slf_mask,
|
||||
enc_slf_mask=enc_slf_mask,
|
||||
enc_query_mask=enc_query_mask,
|
||||
enc_dec_mask=enc_dec_mask,
|
||||
dec_query_slf_mask=dec_query_slf_mask,
|
||||
dec_query_mask=dec_query_mask)
|
||||
alignment, max_attn = get_alignment(attn_probs, mel_lens,
|
||||
cfg['transformer_head'])
|
||||
alignment = dg.to_variable(alignment).astype(np.float32)
|
||||
|
||||
if local_rank == 0 and global_step % 5 == 1:
|
||||
x = np.uint8(
|
||||
cm.viridis(max_attn[8, :mel_lens.numpy()[8]]) * 255)
|
||||
writer.add_image(
|
||||
'Attention_%d_0' % global_step,
|
||||
x,
|
||||
0,
|
||||
dataformats="HWC")
|
||||
(character, mel, pos_text, pos_mel, enc_slf_mask,
|
||||
enc_query_mask, dec_slf_mask, dec_query_slf_mask,
|
||||
alignment) = data
|
||||
|
||||
global_step += 1
|
||||
|
||||
|
@ -161,7 +144,7 @@ def main(args):
|
|||
optimizer._learning_rate.step().numpy(),
|
||||
global_step)
|
||||
|
||||
if args.use_data_parallel:
|
||||
if parallel:
|
||||
total_loss = model.scale_loss(total_loss)
|
||||
total_loss.backward()
|
||||
model.apply_collective_grads()
|
||||
|
@ -170,17 +153,16 @@ def main(args):
|
|||
optimizer.minimize(
|
||||
total_loss,
|
||||
grad_clip=fluid.dygraph_grad_clip.GradClipByGlobalNorm(cfg[
|
||||
'grad_clip_thresh']))
|
||||
'train']['grad_clip_thresh']))
|
||||
model.clear_gradients()
|
||||
|
||||
# save checkpoint
|
||||
if local_rank == 0 and global_step % args.save_step == 0:
|
||||
if not os.path.exists(args.save_path):
|
||||
os.mkdir(args.save_path)
|
||||
save_path = os.path.join(args.save_path,
|
||||
'fastspeech/%d' % global_step)
|
||||
dg.save_dygraph(model.state_dict(), save_path)
|
||||
dg.save_dygraph(optimizer.state_dict(), save_path)
|
||||
if local_rank == 0 and global_step % cfg['train'][
|
||||
'checkpoint_interval'] == 0:
|
||||
io.save_parameters(
|
||||
os.path.join(args.output, 'checkpoints'), global_step,
|
||||
model, optimizer)
|
||||
|
||||
if local_rank == 0:
|
||||
writer.close()
|
||||
|
||||
|
@ -190,5 +172,5 @@ if __name__ == '__main__':
|
|||
add_config_options_to_parser(parser)
|
||||
args = parser.parse_args()
|
||||
# Print the whole config setting.
|
||||
pprint(args)
|
||||
pprint(vars(args))
|
||||
main(args)
|
||||
|
|
|
@ -1,21 +1,12 @@
|
|||
# train model
|
||||
# if you wish to resume from an existing model, uncomment --checkpoint_path and --fastspeech_step
|
||||
export CUDA_VISIBLE_DEVICES=0
|
||||
python -u train.py \
|
||||
--batch_size=32 \
|
||||
--epochs=10000 \
|
||||
--lr=0.001 \
|
||||
--save_step=500 \
|
||||
--use_gpu=1 \
|
||||
--use_data_parallel=0 \
|
||||
--data_path='../../dataset/LJSpeech-1.1' \
|
||||
--transtts_path='../transformer_tts/checkpoint' \
|
||||
--transformer_step=120000 \
|
||||
--save_path='./checkpoint' \
|
||||
--log_dir='./log' \
|
||||
--config_path='configs/fastspeech.yaml' \
|
||||
#--checkpoint_path='./checkpoint' \
|
||||
#--fastspeech_step=97000 \
|
||||
--data='../../dataset/LJSpeech-1.1' \
|
||||
--alignments_path='./alignments/alignments.txt' \
|
||||
--output='./experiment' \
|
||||
--config='configs/ljspeech.yaml' \
|
||||
#--checkpoint='./checkpoint/fastspeech/step-120000' \
|
||||
|
||||
if [ $? -ne 0 ]; then
|
||||
echo "Failed in training!"
|
||||
|
|
|
@ -27,6 +27,16 @@ The model adopts the multi-head attention mechanism to replace the RNN structure
|
|||
├── train_transformer.py # script for transformer model training
|
||||
├── train_vocoder.py # script for vocoder model training
|
||||
```
|
||||
## Saving & Loading
|
||||
`train_transformer.py` and `train_vocoder.py` have 3 arguments in common: `--checkpoint`, `--iteration` and `--output`.

1. `--output` is the directory for saving results.
During training, checkpoints are saved in `checkpoints/` in `output` and the tensorboard log is saved in `log/` in `output`.
During synthesis, results are saved in `samples/` in `output` and the tensorboard log is saved in `log/` in `output`.

2. `--checkpoint` and `--iteration` are used to load from an existing checkpoint. Loading an existing checkpoint follows these rules:
If `--checkpoint` is provided, the checkpoint specified by `--checkpoint` is loaded.
If `--checkpoint` is not provided, we try to load the checkpoint specified by `--iteration` from the checkpoint directory. If `--iteration` is not provided either, we try to load the latest checkpoint from the checkpoint directory.
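On the saving side, the scripts share the same pattern. The sketch below is condensed from the FastSpeech training loop in this commit (the transformer and vocoder scripts are expected to follow it as well); `io` is `parakeet.utils.io`.

```python
import os
from parakeet.utils import io

def maybe_save(model, optimizer, output_dir, global_step, interval, local_rank=0):
    """Sketch of the checkpointing step used by the training loops in this commit."""
    if local_rank == 0 and global_step % interval == 0:
        # writes checkpoints named like step-<global_step> under <output>/checkpoints/
        # (cf. the 'step-120000' checkpoints referenced above)
        io.save_parameters(
            os.path.join(output_dir, 'checkpoints'), global_step, model, optimizer)
```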
## Train Transformer
|
||||
|
||||
|
@ -34,26 +44,26 @@ TransformerTTS model can be trained with ``train_transformer.py``.
|
|||
```bash
|
||||
python train_transformer.py \
|
||||
--use_gpu=1 \
|
||||
--use_data_parallel=0 \
|
||||
--data_path=${DATAPATH} \
|
||||
--config_path='config/train_transformer.yaml' \
|
||||
--data=${DATAPATH} \
|
||||
--output='./experiment' \
|
||||
--config='configs/ljspeech.yaml' \
|
||||
```
|
||||
Or you can run the script file directly.
|
||||
```bash
|
||||
sh train_transformer.sh
|
||||
```
|
||||
If you want to train on multiple GPUs, you must set ``--use_data_parallel=1``, and then start training as follows:
|
||||
If you want to train on multiple GPUs, start training as follows:
|
||||
|
||||
```bash
|
||||
export CUDA_VISIBLE_DEVICES=0,1,2,3
|
||||
python -m paddle.distributed.launch --selected_gpus=0,1,2,3 --log_dir ./mylog train_transformer.py \
|
||||
--use_gpu=1 \
|
||||
--use_data_parallel=1 \
|
||||
--data_path=${DATAPATH} \
|
||||
--config_path='config/train_transformer.yaml' \
|
||||
--data=${DATAPATH} \
|
||||
--output='./experiment' \
|
||||
--config='configs/ljspeech.yaml' \
|
||||
```
|
||||
|
||||
If you wish to resume from an existing model, please set ``--checkpoint_path`` and ``--transformer_step``.
|
||||
If you wish to resume from an existing model, see [Saving-&-Loading](#Saving-&-Loading) for details of checkpoint loading.
|
||||
|
||||
**Note: In order to ensure the training effect, we recommend using multi-GPU training to enlarge the batch size, and at least 16 samples in single batch per GPU.**
|
||||
|
||||
|
@ -65,25 +75,25 @@ Vocoder model can be trained with ``train_vocoder.py``.
|
|||
```bash
|
||||
python train_vocoder.py \
|
||||
--use_gpu=1 \
|
||||
--use_data_parallel=0 \
|
||||
--data_path=${DATAPATH} \
|
||||
--config_path='config/train_vocoder.yaml' \
|
||||
--data=${DATAPATH} \
|
||||
--output='./vocoder' \
|
||||
--config='configs/ljspeech.yaml' \
|
||||
```
|
||||
Or you can run the script file directly.
|
||||
```bash
|
||||
sh train_vocoder.sh
|
||||
```
|
||||
If you want to train on multiple GPUs, you must set ``--use_data_parallel=1``, and then start training as follows:
|
||||
If you want to train on multiple GPUs, start training as follows:
|
||||
|
||||
```bash
|
||||
export CUDA_VISIBLE_DEVICES=0,1,2,3
|
||||
python -m paddle.distributed.launch --selected_gpus=0,1,2,3 --log_dir ./mylog train_vocoder.py \
|
||||
--use_gpu=1 \
|
||||
--use_data_parallel=1 \
|
||||
--data_path=${DATAPATH} \
|
||||
--config_path='config/train_vocoder.yaml' \
|
||||
--data=${DATAPATH} \
|
||||
--output='./vocoder' \
|
||||
--config='configs/ljspeech.yaml' \
|
||||
```
|
||||
If you wish to resume from an existing model, please set ``--checkpoint_path`` and ``--vocoder_step``.
|
||||
If you wish to resume from an existing model, see [Saving-&-Loading](#Saving-&-Loading) for details of checkpoint loading.
|
||||
|
||||
For more help on arguments:
|
||||
``python train_vocoder.py --help``.
|
||||
|
@ -92,13 +102,12 @@ For more help on arguments:
|
|||
After training the TransformerTTS and vocoder model, audio can be synthesized with ``synthesis.py``.
|
||||
```bash
|
||||
python synthesis.py \
|
||||
--max_len=50 \
|
||||
--transformer_step=160000 \
|
||||
--vocoder_step=70000 \
|
||||
--use_gpu=1
|
||||
--checkpoint_path='./checkpoint' \
|
||||
--sample_path='./sample' \
|
||||
--config_path='config/synthesis.yaml' \
|
||||
--max_len=300 \
|
||||
--use_gpu=1 \
|
||||
--output='./synthesis' \
|
||||
--config='configs/ljspeech.yaml' \
|
||||
--checkpoint_transformer='./checkpoint/transformer/step-120000' \
|
||||
--checkpoint_vocoder='./checkpoint/vocoder/step-100000' \
|
||||
```
|
||||
|
||||
Or you can run the script file directly.
|
||||
|
@ -106,7 +115,5 @@ Or you can run the script file directly.
|
|||
sh synthesis.sh
|
||||
```
|
||||
|
||||
And the audio file will be saved in ``--sample_path``.
|
||||
|
||||
For more help on arguments:
|
||||
``python synthesis.py --help``.
|
||||
|
|
|
@ -0,0 +1,38 @@
|
|||
audio:
|
||||
num_mels: 80
|
||||
n_fft: 2048
|
||||
sr: 22050
|
||||
preemphasis: 0.97
|
||||
hop_length: 256 #275
|
||||
win_length: 1024 #1102
|
||||
power: 1.2
|
||||
min_level_db: -100
|
||||
ref_level_db: 20
|
||||
|
||||
network:
|
||||
hidden_size: 256
|
||||
embedding_size: 512
|
||||
encoder_num_head: 4
|
||||
encoder_n_layers: 3
|
||||
decoder_num_head: 4
|
||||
decoder_n_layers: 3
|
||||
outputs_per_step: 1
|
||||
stop_token: False
|
||||
|
||||
vocoder:
|
||||
hidden_size: 256
|
||||
|
||||
train:
|
||||
batch_size: 32
|
||||
learning_rate: 0.001
|
||||
warm_up_step: 4000
|
||||
grad_clip_thresh: 1.0
|
||||
|
||||
checkpoint_interval: 1000
|
||||
image_interval: 2000
|
||||
|
||||
max_epochs: 10000
|
||||
|
||||
|
||||
|
||||
|
|
@ -1,14 +0,0 @@
|
|||
audio:
|
||||
num_mels: 80
|
||||
n_fft: 2048
|
||||
sr: 22050
|
||||
preemphasis: 0.97
|
||||
hop_length: 275
|
||||
win_length: 1102
|
||||
power: 1.2
|
||||
min_level_db: -100
|
||||
ref_level_db: 20
|
||||
outputs_per_step: 1
|
||||
|
||||
hidden_size: 256
|
||||
embedding_size: 512
|
|
@ -1,20 +0,0 @@
|
|||
audio:
|
||||
num_mels: 80
|
||||
n_fft: 2048
|
||||
sr: 22050
|
||||
preemphasis: 0.97
|
||||
hop_length: 275
|
||||
win_length: 1102
|
||||
power: 1.2
|
||||
min_level_db: -100
|
||||
ref_level_db: 20
|
||||
outputs_per_step: 1
|
||||
|
||||
|
||||
hidden_size: 256
|
||||
embedding_size: 512
|
||||
warm_up_step: 4000
|
||||
grad_clip_thresh: 1.0
|
||||
|
||||
|
||||
|
|
@ -1,16 +0,0 @@
|
|||
audio:
|
||||
num_mels: 80
|
||||
n_fft: 2048
|
||||
sr: 22050
|
||||
preemphasis: 0.97
|
||||
hop_length: 275
|
||||
win_length: 1102
|
||||
power: 1.2
|
||||
min_level_db: -100
|
||||
ref_level_db: 20
|
||||
outputs_per_step: 1
|
||||
|
||||
hidden_size: 256
|
||||
embedding_size: 512
|
||||
warm_up_step: 4000
|
||||
grad_clip_thresh: 1.0
|
|
@ -30,14 +30,15 @@ from parakeet.models.transformer_tts.utils import *
|
|||
class LJSpeechLoader:
|
||||
def __init__(self,
|
||||
config,
|
||||
args,
|
||||
place,
|
||||
data_path,
|
||||
batch_size,
|
||||
nranks,
|
||||
rank,
|
||||
is_vocoder=False,
|
||||
shuffle=True):
|
||||
place = fluid.CUDAPlace(rank) if args.use_gpu else fluid.CPUPlace()
|
||||
|
||||
LJSPEECH_ROOT = Path(args.data_path)
|
||||
LJSPEECH_ROOT = Path(data_path)
|
||||
metadata = LJSpeechMetaData(LJSPEECH_ROOT)
|
||||
transformer = LJSpeech(config)
|
||||
dataset = TransformDataset(metadata, transformer)
|
||||
|
@ -46,8 +47,8 @@ class LJSpeechLoader:
|
|||
sampler = DistributedSampler(
|
||||
len(dataset), nranks, rank, shuffle=shuffle)
|
||||
|
||||
assert args.batch_size % nranks == 0
|
||||
each_bs = args.batch_size // nranks
|
||||
assert batch_size % nranks == 0
|
||||
each_bs = batch_size // nranks
|
||||
if is_vocoder:
|
||||
dataloader = DataCargo(
|
||||
dataset,
|
||||
|
@ -98,15 +99,15 @@ class LJSpeech(object):
|
|||
super(LJSpeech, self).__init__()
|
||||
self.config = config
|
||||
self._ljspeech_processor = audio.AudioProcessor(
|
||||
sample_rate=config['audio']['sr'],
|
||||
num_mels=config['audio']['num_mels'],
|
||||
min_level_db=config['audio']['min_level_db'],
|
||||
ref_level_db=config['audio']['ref_level_db'],
|
||||
n_fft=config['audio']['n_fft'],
|
||||
win_length=config['audio']['win_length'],
|
||||
hop_length=config['audio']['hop_length'],
|
||||
power=config['audio']['power'],
|
||||
preemphasis=config['audio']['preemphasis'],
|
||||
sample_rate=config['sr'],
|
||||
num_mels=config['num_mels'],
|
||||
min_level_db=config['min_level_db'],
|
||||
ref_level_db=config['ref_level_db'],
|
||||
n_fft=config['n_fft'],
|
||||
win_length=config['win_length'],
|
||||
hop_length=config['hop_length'],
|
||||
power=config['power'],
|
||||
preemphasis=config['preemphasis'],
|
||||
signal_norm=True,
|
||||
symmetric_norm=False,
|
||||
max_norm=1.,
|
||||
|
|
|
@ -1,100 +0,0 @@
|
|||
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import argparse
|
||||
|
||||
|
||||
def add_config_options_to_parser(parser):
|
||||
parser.add_argument(
|
||||
'--config_path',
|
||||
type=str,
|
||||
default='configs/train_transformer.yaml',
|
||||
help="the yaml config file path.")
|
||||
parser.add_argument(
|
||||
'--batch_size', type=int, default=32, help="batch size for training.")
|
||||
parser.add_argument(
|
||||
'--epochs',
|
||||
type=int,
|
||||
default=10000,
|
||||
help="the number of epoch for training.")
|
||||
parser.add_argument(
|
||||
'--lr',
|
||||
type=float,
|
||||
default=0.001,
|
||||
help="the learning rate for training.")
|
||||
parser.add_argument(
|
||||
'--save_step',
|
||||
type=int,
|
||||
default=500,
|
||||
help="checkpointing interval during training.")
|
||||
parser.add_argument(
|
||||
'--image_step',
|
||||
type=int,
|
||||
default=2000,
|
||||
help="attention image interval during training.")
|
||||
parser.add_argument(
|
||||
'--max_len',
|
||||
type=int,
|
||||
default=400,
|
||||
help="The max length of audio when synthsis.")
|
||||
parser.add_argument(
|
||||
'--transformer_step',
|
||||
type=int,
|
||||
default=160000,
|
||||
help="Global step to restore checkpoint of transformer.")
|
||||
parser.add_argument(
|
||||
'--vocoder_step',
|
||||
type=int,
|
||||
default=90000,
|
||||
help="Global step to restore checkpoint of postnet.")
|
||||
parser.add_argument(
|
||||
'--use_gpu',
|
||||
type=int,
|
||||
default=1,
|
||||
help="use gpu or not during training.")
|
||||
parser.add_argument(
|
||||
'--use_data_parallel',
|
||||
type=int,
|
||||
default=0,
|
||||
help="use data parallel or not during training.")
|
||||
parser.add_argument(
|
||||
'--stop_token',
|
||||
type=int,
|
||||
default=0,
|
||||
help="use stop token loss in network or not.")
|
||||
|
||||
parser.add_argument(
|
||||
'--data_path',
|
||||
type=str,
|
||||
default='./dataset/LJSpeech-1.1',
|
||||
help="the path of dataset.")
|
||||
parser.add_argument(
|
||||
'--checkpoint_path',
|
||||
type=str,
|
||||
default=None,
|
||||
help="the path to load checkpoint or pretrain model.")
|
||||
parser.add_argument(
|
||||
'--save_path',
|
||||
type=str,
|
||||
default='./checkpoint',
|
||||
help="the path to save checkpoint.")
|
||||
parser.add_argument(
|
||||
'--log_dir',
|
||||
type=str,
|
||||
default='./log',
|
||||
help="the directory to save tensorboard log.")
|
||||
parser.add_argument(
|
||||
'--sample_path',
|
||||
type=str,
|
||||
default='./sample',
|
||||
help="the directory to save audio sample in synthesis.")
|
|
@@ -13,64 +13,84 @@
|
|||
# limitations under the License.
|
||||
import os
|
||||
from scipy.io.wavfile import write
|
||||
from parakeet.g2p.en import text_to_sequence
|
||||
import numpy as np
|
||||
from tqdm import tqdm
|
||||
from matplotlib import cm
|
||||
from tensorboardX import SummaryWriter
|
||||
from ruamel import yaml
|
||||
import paddle.fluid as fluid
|
||||
import paddle.fluid.dygraph as dg
|
||||
from pathlib import Path
|
||||
import argparse
|
||||
from parse import add_config_options_to_parser
|
||||
from pprint import pprint
|
||||
from collections import OrderedDict
|
||||
import paddle.fluid as fluid
|
||||
import paddle.fluid.dygraph as dg
|
||||
from parakeet.g2p.en import text_to_sequence
|
||||
from parakeet.models.transformer_tts.utils import *
|
||||
from parakeet import audio
|
||||
from parakeet.models.transformer_tts.vocoder import Vocoder
|
||||
from parakeet.models.transformer_tts.transformer_tts import TransformerTTS
|
||||
from parakeet.models.transformer_tts import Vocoder
|
||||
from parakeet.models.transformer_tts import TransformerTTS
|
||||
from parakeet.utils import io
|
||||
|
||||
|
||||
def load_checkpoint(step, model_path):
|
||||
model_dict, _ = fluid.dygraph.load_dygraph(os.path.join(model_path, step))
|
||||
new_state_dict = OrderedDict()
|
||||
for param in model_dict:
|
||||
if param.startswith('_layers.'):
|
||||
new_state_dict[param[8:]] = model_dict[param]
|
||||
else:
|
||||
new_state_dict[param] = model_dict[param]
|
||||
return new_state_dict
|
||||
def add_config_options_to_parser(parser):
|
||||
parser.add_argument("--config", type=str, help="path of the config file")
|
||||
parser.add_argument("--use_gpu", type=int, default=0, help="device to use")
|
||||
parser.add_argument(
|
||||
"--max_len",
|
||||
type=int,
|
||||
default=200,
|
||||
help="The max length of audio when synthsis.")
|
||||
|
||||
parser.add_argument(
|
||||
"--checkpoint_transformer",
|
||||
type=str,
|
||||
help="transformer_tts checkpoint to synthesis")
|
||||
parser.add_argument(
|
||||
"--checkpoint_vocoder",
|
||||
type=str,
|
||||
help="vocoder checkpoint to synthesis")
|
||||
|
||||
parser.add_argument(
|
||||
"--output",
|
||||
type=str,
|
||||
default="synthesis",
|
||||
help="path to save experiment results")
|
||||
|
||||
|
||||
def synthesis(text_input, args):
|
||||
place = (fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace())
|
||||
local_rank = dg.parallel.Env().local_rank
|
||||
place = (fluid.CUDAPlace(local_rank) if args.use_gpu else fluid.CPUPlace())
|
||||
|
||||
with open(args.config_path) as f:
|
||||
with open(args.config) as f:
|
||||
cfg = yaml.load(f, Loader=yaml.Loader)
|
||||
|
||||
# tensorboard
|
||||
if not os.path.exists(args.log_dir):
|
||||
os.mkdir(args.log_dir)
|
||||
path = os.path.join(args.log_dir, 'synthesis')
|
||||
if not os.path.exists(args.output):
|
||||
os.mkdir(args.output)
|
||||
|
||||
writer = SummaryWriter(path)
|
||||
writer = SummaryWriter(os.path.join(args.output, 'log'))
|
||||
|
||||
with dg.guard(place):
|
||||
with fluid.unique_name.guard():
|
||||
model = TransformerTTS(cfg)
|
||||
model.set_dict(
|
||||
load_checkpoint(
|
||||
str(args.transformer_step),
|
||||
os.path.join(args.checkpoint_path, "transformer")))
|
||||
network_cfg = cfg['network']
|
||||
model = TransformerTTS(
|
||||
network_cfg['embedding_size'], network_cfg['hidden_size'],
|
||||
network_cfg['encoder_num_head'],
|
||||
network_cfg['encoder_n_layers'], cfg['audio']['num_mels'],
|
||||
network_cfg['outputs_per_step'],
|
||||
network_cfg['decoder_num_head'],
|
||||
network_cfg['decoder_n_layers'])
|
||||
# Load parameters.
|
||||
global_step = io.load_parameters(
|
||||
model=model, checkpoint_path=args.checkpoint_transformer)
|
||||
model.eval()
|
||||
|
||||
with fluid.unique_name.guard():
|
||||
model_vocoder = Vocoder(cfg, args.batch_size)
|
||||
model_vocoder.set_dict(
|
||||
load_checkpoint(
|
||||
str(args.vocoder_step),
|
||||
os.path.join(args.checkpoint_path, "vocoder")))
|
||||
model_vocoder = Vocoder(
|
||||
cfg['train']['batch_size'], cfg['vocoder']['hidden_size'],
|
||||
cfg['audio']['num_mels'], cfg['audio']['n_fft'])
|
||||
# Load parameters.
|
||||
global_step = io.load_parameters(
|
||||
model=model_vocoder, checkpoint_path=args.checkpoint_vocoder)
|
||||
model_vocoder.eval()
|
||||
# init input
|
||||
text = np.asarray(text_to_sequence(text_input))
|
||||
|
@@ -83,6 +103,7 @@ def synthesis(text_input, args):
|
|||
for i in pbar:
|
||||
dec_slf_mask = get_triu_tensor(
|
||||
mel_input.numpy(), mel_input.numpy()).astype(np.float32)
|
||||
dec_slf_mask = np.expand_dims(dec_slf_mask, axis=0)
|
||||
dec_slf_mask = fluid.layers.cast(
|
||||
dg.to_variable(dec_slf_mask != 0), np.float32) * (-2**32 + 1)
|
||||
pos_mel = np.arange(1, mel_input.shape[1] + 1)
|
||||
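The hunk above rebuilds the decoder's self-attention mask at synthesis time: positions above the diagonal (future mel frames) receive a large negative bias before the softmax, which keeps generation autoregressive. A minimal numpy sketch of the same mask, assuming `get_triu_tensor` returns the strict upper triangle; the toy length is illustrative:

```python
import numpy as np

# Toy causal mask for T generated mel frames (illustrative values only).
T = 5
triu = np.triu(np.ones((T, T), dtype=np.float32), k=1)        # 1s above the diagonal
dec_slf_mask = (triu != 0).astype(np.float32) * (-2**32 + 1)  # large negative bias
dec_slf_mask = np.expand_dims(dec_slf_mask, axis=0)           # add the batch dimension

# Added to the attention logits, this bias drives the softmax weight on
# future frames to (near) zero.
```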
|
@@ -114,6 +135,7 @@ def synthesis(text_input, args):
|
|||
do_trim_silence=False,
|
||||
sound_norm=False)
|
||||
|
||||
# synthesis with cbhg
|
||||
wav = _ljspeech_processor.inv_spectrogram(
|
||||
fluid.layers.transpose(
|
||||
fluid.layers.squeeze(mag_pred, [0]), [1, 0]).numpy())
|
||||
|
@@ -127,29 +149,24 @@ def synthesis(text_input, args):
|
|||
i * 4 + j,
|
||||
dataformats="HWC")
|
||||
|
||||
for i, prob in enumerate(attn_enc):
|
||||
for j in range(4):
|
||||
x = np.uint8(cm.viridis(prob.numpy()[j]) * 255)
|
||||
writer.add_image(
|
||||
'Attention_enc_%d_0' % global_step,
|
||||
x,
|
||||
i * 4 + j,
|
||||
dataformats="HWC")
|
||||
writer.add_audio(text_input + '(cbhg)', wav, 0, cfg['audio']['sr'])
|
||||
|
||||
for i, prob in enumerate(attn_dec):
|
||||
for j in range(4):
|
||||
x = np.uint8(cm.viridis(prob.numpy()[j]) * 255)
|
||||
writer.add_image(
|
||||
'Attention_dec_%d_0' % global_step,
|
||||
x,
|
||||
i * 4 + j,
|
||||
dataformats="HWC")
|
||||
writer.add_audio(text_input, wav, 0, cfg['audio']['sr'])
|
||||
if not os.path.exists(args.sample_path):
|
||||
os.mkdir(args.sample_path)
|
||||
if not os.path.exists(os.path.join(args.output, 'samples')):
|
||||
os.mkdir(os.path.join(args.output, 'samples'))
|
||||
write(
|
||||
os.path.join(args.sample_path, 'test.wav'), cfg['audio']['sr'],
|
||||
wav)
|
||||
os.path.join(os.path.join(args.output, 'samples'), 'cbhg.wav'),
|
||||
cfg['audio']['sr'], wav)
|
||||
|
||||
# synthesis with griffin-lim
|
||||
wav = _ljspeech_processor.inv_melspectrogram(
|
||||
fluid.layers.transpose(
|
||||
fluid.layers.squeeze(postnet_pred, [0]), [1, 0]).numpy())
|
||||
writer.add_audio(text_input + '(griffin)', wav, 0, cfg['audio']['sr'])
|
||||
|
||||
write(
|
||||
os.path.join(os.path.join(args.output, 'samples'), 'griffin.wav'),
|
||||
cfg['audio']['sr'], wav)
|
||||
print("Synthesis completed !!!")
|
||||
writer.close()
|
||||
|
||||
|
||||
|
@@ -157,5 +174,7 @@ if __name__ == '__main__':
|
|||
parser = argparse.ArgumentParser(description="Synthesis model")
|
||||
add_config_options_to_parser(parser)
|
||||
args = parser.parse_args()
|
||||
# Print the whole config setting.
|
||||
pprint(vars(args))
|
||||
synthesis("Parakeet stands for Paddle PARAllel text-to-speech toolkit.",
|
||||
args)
|
||||
|
|
|
@@ -3,13 +3,11 @@
|
|||
CUDA_VISIBLE_DEVICES=0 \
|
||||
python -u synthesis.py \
|
||||
--max_len=300 \
|
||||
--transformer_step=120000 \
|
||||
--vocoder_step=100000 \
|
||||
--use_gpu=1 \
|
||||
--checkpoint_path='./checkpoint' \
|
||||
--log_dir='./log' \
|
||||
--sample_path='./sample' \
|
||||
--config_path='configs/synthesis.yaml' \
|
||||
--output='./synthesis' \
|
||||
--config='configs/ljspeech.yaml' \
|
||||
--checkpoint_transformer='./checkpoint/transformer/step-120000' \
|
||||
--checkpoint_vocoder='./checkpoint/vocoder/step-100000' \
|
||||
|
||||
if [ $? -ne 0 ]; then
|
||||
echo "Failed in training!"
|
||||
|
|
|
@@ -16,7 +16,6 @@ from tqdm import tqdm
|
|||
from tensorboardX import SummaryWriter
|
||||
from collections import OrderedDict
|
||||
import argparse
|
||||
from parse import add_config_options_to_parser
|
||||
from pprint import pprint
|
||||
from ruamel import yaml
|
||||
from matplotlib import cm
|
||||
|
@@ -26,65 +25,85 @@ import paddle.fluid.dygraph as dg
|
|||
import paddle.fluid.layers as layers
|
||||
from parakeet.models.transformer_tts.utils import cross_entropy
|
||||
from data import LJSpeechLoader
|
||||
from parakeet.models.transformer_tts.transformer_tts import TransformerTTS
|
||||
from parakeet.models.transformer_tts import TransformerTTS
|
||||
from parakeet.utils import io
|
||||
|
||||
|
||||
def load_checkpoint(step, model_path):
|
||||
model_dict, opti_dict = fluid.dygraph.load_dygraph(
|
||||
os.path.join(model_path, step))
|
||||
new_state_dict = OrderedDict()
|
||||
for param in model_dict:
|
||||
if param.startswith('_layers.'):
|
||||
new_state_dict[param[8:]] = model_dict[param]
|
||||
else:
|
||||
new_state_dict[param] = model_dict[param]
|
||||
return new_state_dict, opti_dict
|
||||
def add_config_options_to_parser(parser):
|
||||
parser.add_argument("--config", type=str, help="path of the config file")
|
||||
parser.add_argument("--use_gpu", type=int, default=0, help="device to use")
|
||||
parser.add_argument("--data", type=str, help="path of LJspeech dataset")
|
||||
|
||||
g = parser.add_mutually_exclusive_group()
|
||||
g.add_argument("--checkpoint", type=str, help="checkpoint to resume from")
|
||||
g.add_argument(
|
||||
"--iteration",
|
||||
type=int,
|
||||
help="the iteration of the checkpoint to load from output directory")
|
||||
|
||||
parser.add_argument(
|
||||
"--output",
|
||||
type=str,
|
||||
default="experiment",
|
||||
help="path to save experiment results")
|
||||
|
||||
|
||||
def main(args):
|
||||
local_rank = dg.parallel.Env().local_rank if args.use_data_parallel else 0
|
||||
nranks = dg.parallel.Env().nranks if args.use_data_parallel else 1
|
||||
local_rank = dg.parallel.Env().local_rank
|
||||
nranks = dg.parallel.Env().nranks
|
||||
parallel = nranks > 1
|
||||
|
||||
with open(args.config_path) as f:
|
||||
with open(args.config) as f:
|
||||
cfg = yaml.load(f, Loader=yaml.Loader)
|
||||
|
||||
global_step = 0
|
||||
place = (fluid.CUDAPlace(dg.parallel.Env().dev_id)
|
||||
if args.use_data_parallel else fluid.CUDAPlace(0)
|
||||
if args.use_gpu else fluid.CPUPlace())
|
||||
place = fluid.CUDAPlace(local_rank) if args.use_gpu else fluid.CPUPlace()
|
||||
|
||||
if not os.path.exists(args.log_dir):
|
||||
os.mkdir(args.log_dir)
|
||||
path = os.path.join(args.log_dir, 'transformer')
|
||||
if not os.path.exists(args.output):
|
||||
os.mkdir(args.output)
|
||||
|
||||
writer = SummaryWriter(path) if local_rank == 0 else None
|
||||
writer = SummaryWriter(os.path.join(args.output,
|
||||
'log')) if local_rank == 0 else None
|
||||
|
||||
with dg.guard(place):
|
||||
model = TransformerTTS(cfg)
|
||||
network_cfg = cfg['network']
|
||||
model = TransformerTTS(
|
||||
network_cfg['embedding_size'], network_cfg['hidden_size'],
|
||||
network_cfg['encoder_num_head'], network_cfg['encoder_n_layers'],
|
||||
cfg['audio']['num_mels'], network_cfg['outputs_per_step'],
|
||||
network_cfg['decoder_num_head'], network_cfg['decoder_n_layers'])
|
||||
|
||||
model.train()
|
||||
optimizer = fluid.optimizer.AdamOptimizer(
|
||||
learning_rate=dg.NoamDecay(1 / (
|
||||
cfg['warm_up_step'] * (args.lr**2)), cfg['warm_up_step']),
|
||||
learning_rate=dg.NoamDecay(1 /
|
||||
(cfg['train']['warm_up_step'] *
|
||||
(cfg['train']['learning_rate']**2)),
|
||||
cfg['train']['warm_up_step']),
|
||||
parameter_list=model.parameters())
|
||||
|
||||
if args.checkpoint_path is not None:
|
||||
model_dict, opti_dict = load_checkpoint(
|
||||
str(args.transformer_step),
|
||||
os.path.join(args.checkpoint_path, "transformer"))
|
||||
model.set_dict(model_dict)
|
||||
optimizer.set_dict(opti_dict)
|
||||
global_step = args.transformer_step
|
||||
print("load checkpoint!!!")
|
||||
# Load parameters.
|
||||
global_step = io.load_parameters(
|
||||
model=model,
|
||||
optimizer=optimizer,
|
||||
checkpoint_dir=os.path.join(args.output, 'checkpoints'),
|
||||
iteration=args.iteration,
|
||||
checkpoint_path=args.checkpoint)
|
||||
print("Rank {}: checkpoint loaded.".format(local_rank))
|
||||
|
||||
if args.use_data_parallel:
|
||||
if parallel:
|
||||
strategy = dg.parallel.prepare_context()
|
||||
model = fluid.dygraph.parallel.DataParallel(model, strategy)
|
||||
|
||||
reader = LJSpeechLoader(
|
||||
cfg, args, nranks, local_rank, shuffle=True).reader()
|
||||
cfg['audio'],
|
||||
place,
|
||||
args.data,
|
||||
cfg['train']['batch_size'],
|
||||
nranks,
|
||||
local_rank,
|
||||
shuffle=True).reader()
|
||||
|
||||
for epoch in range(args.epochs):
|
||||
for epoch in range(cfg['train']['max_epochs']):
|
||||
pbar = tqdm(reader)
|
||||
for i, data in enumerate(pbar):
|
||||
pbar.set_description('Processing at epoch %d' % epoch)
|
||||
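The optimizer in the hunk above keeps the Noam schedule but now reads its settings from `cfg['train']` instead of command-line flags. Under the usual Noam formula `lr(step) = d_model**-0.5 * min(step**-0.5, step * warmup**-1.5)`, passing `1 / (warm_up_step * learning_rate**2)` as the first argument makes the schedule peak at exactly `learning_rate` at the warm-up step. A small sketch with placeholder values (not the shipped config):

```python
# Placeholder settings; only the relationship between them matters here.
learning_rate, warm_up_step = 0.001, 4000
d_model = 1.0 / (warm_up_step * learning_rate**2)

def noam_lr(step):
    # Standard Noam learning-rate schedule.
    return d_model**-0.5 * min(step**-0.5, step * warm_up_step**-1.5)

# Linear warm-up, a peak of `learning_rate` at the warm-up step, then decay.
assert abs(noam_lr(warm_up_step) - learning_rate) < 1e-9
print(noam_lr(1), noam_lr(warm_up_step), noam_lr(10 * warm_up_step))
```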
|
@@ -111,7 +130,7 @@ def main(args):
|
|||
loss = mel_loss + post_mel_loss
|
||||
|
||||
# Note: when the stop token loss was used, learning did not work.
|
||||
if args.stop_token:
|
||||
if cfg['network']['stop_token']:
|
||||
label = (pos_mel == 0).astype(np.float32)
|
||||
stop_loss = cross_entropy(stop_preds, label)
|
||||
loss = loss + stop_loss
|
||||
|
@@ -122,11 +141,11 @@ def main(args):
|
|||
'post_mel_loss': post_mel_loss.numpy()
|
||||
}, global_step)
|
||||
|
||||
if args.stop_token:
|
||||
if cfg['network']['stop_token']:
|
||||
writer.add_scalar('stop_loss',
|
||||
stop_loss.numpy(), global_step)
|
||||
|
||||
if args.use_data_parallel:
|
||||
if parallel:
|
||||
writer.add_scalars('alphas', {
|
||||
'encoder_alpha':
|
||||
model._layers.encoder.alpha.numpy(),
|
||||
|
@@ -143,12 +162,12 @@ def main(args):
|
|||
optimizer._learning_rate.step().numpy(),
|
||||
global_step)
|
||||
|
||||
if global_step % args.image_step == 1:
|
||||
if global_step % cfg['train']['image_interval'] == 1:
|
||||
for i, prob in enumerate(attn_probs):
|
||||
for j in range(4):
|
||||
for j in range(cfg['network']['decoder_num_head']):
|
||||
x = np.uint8(
|
||||
cm.viridis(prob.numpy()[j * args.batch_size
|
||||
// 2]) * 255)
|
||||
cm.viridis(prob.numpy()[j * cfg['train'][
|
||||
'batch_size'] // 2]) * 255)
|
||||
writer.add_image(
|
||||
'Attention_%d_0' % global_step,
|
||||
x,
|
||||
|
@@ -156,10 +175,10 @@ def main(args):
|
|||
dataformats="HWC")
|
||||
|
||||
for i, prob in enumerate(attn_enc):
|
||||
for j in range(4):
|
||||
for j in range(cfg['network']['encoder_num_head']):
|
||||
x = np.uint8(
|
||||
cm.viridis(prob.numpy()[j * args.batch_size
|
||||
// 2]) * 255)
|
||||
cm.viridis(prob.numpy()[j * cfg['train'][
|
||||
'batch_size'] // 2]) * 255)
|
||||
writer.add_image(
|
||||
'Attention_enc_%d_0' % global_step,
|
||||
x,
|
||||
|
@@ -167,17 +186,17 @@ def main(args):
|
|||
dataformats="HWC")
|
||||
|
||||
for i, prob in enumerate(attn_dec):
|
||||
for j in range(4):
|
||||
for j in range(cfg['network']['decoder_num_head']):
|
||||
x = np.uint8(
|
||||
cm.viridis(prob.numpy()[j * args.batch_size
|
||||
// 2]) * 255)
|
||||
cm.viridis(prob.numpy()[j * cfg['train'][
|
||||
'batch_size'] // 2]) * 255)
|
||||
writer.add_image(
|
||||
'Attention_dec_%d_0' % global_step,
|
||||
x,
|
||||
i * 4 + j,
|
||||
dataformats="HWC")
|
||||
|
||||
if args.use_data_parallel:
|
||||
if parallel:
|
||||
loss = model.scale_loss(loss)
|
||||
loss.backward()
|
||||
model.apply_collective_grads()
|
||||
|
@@ -186,17 +205,16 @@ def main(args):
|
|||
optimizer.minimize(
|
||||
loss,
|
||||
grad_clip=fluid.dygraph_grad_clip.GradClipByGlobalNorm(cfg[
|
||||
'grad_clip_thresh']))
|
||||
'train']['grad_clip_thresh']))
|
||||
model.clear_gradients()
|
||||
|
||||
# save checkpoint
|
||||
if local_rank == 0 and global_step % args.save_step == 0:
|
||||
if not os.path.exists(args.save_path):
|
||||
os.mkdir(args.save_path)
|
||||
save_path = os.path.join(args.save_path,
|
||||
'transformer/%d' % global_step)
|
||||
dg.save_dygraph(model.state_dict(), save_path)
|
||||
dg.save_dygraph(optimizer.state_dict(), save_path)
|
||||
if local_rank == 0 and global_step % cfg['train'][
|
||||
'checkpoint_interval'] == 0:
|
||||
io.save_parameters(
|
||||
os.path.join(args.output, 'checkpoints'), global_step,
|
||||
model, optimizer)
|
||||
|
||||
if local_rank == 0:
|
||||
writer.close()
|
||||
|
||||
|
@@ -204,8 +222,7 @@ def main(args):
|
|||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser(description="Train TransformerTTS model")
|
||||
add_config_options_to_parser(parser)
|
||||
|
||||
args = parser.parse_args()
|
||||
# Print the whole config setting.
|
||||
pprint(args)
|
||||
pprint(vars(args))
|
||||
main(args)
|
||||
|
|
|
@@ -1,22 +1,12 @@
|
|||
|
||||
# train model
|
||||
# if you wish to resume from an existing model, uncomment --checkpoint_path and --transformer_step
|
||||
export CUDA_VISIBLE_DEVICES=2
|
||||
export CUDA_VISIBLE_DEVICES=0
|
||||
python -u train_transformer.py \
|
||||
--batch_size=32 \
|
||||
--epochs=10000 \
|
||||
--lr=0.001 \
|
||||
--save_step=1000 \
|
||||
--image_step=2000 \
|
||||
--use_gpu=1 \
|
||||
--use_data_parallel=0 \
|
||||
--stop_token=0 \
|
||||
--data_path='../../dataset/LJSpeech-1.1' \
|
||||
--save_path='./checkpoint' \
|
||||
--log_dir='./log' \
|
||||
--config_path='configs/train_transformer.yaml' \
|
||||
#--checkpoint_path='./checkpoint' \
|
||||
#--transformer_step=160000 \
|
||||
--data='../../dataset/LJSpeech-1.1' \
|
||||
--output='./experiment' \
|
||||
--config='configs/ljspeech.yaml' \
|
||||
#--checkpoint='./checkpoint/transformer/step-120000' \
|
||||
|
||||
if [ $? -ne 0 ]; then
|
||||
echo "Failed in training!"
|
||||
|
|
|
@@ -18,71 +18,87 @@ from pathlib import Path
|
|||
from collections import OrderedDict
|
||||
import argparse
|
||||
from ruamel import yaml
|
||||
from parse import add_config_options_to_parser
|
||||
from pprint import pprint
|
||||
import paddle.fluid as fluid
|
||||
import paddle.fluid.dygraph as dg
|
||||
import paddle.fluid.layers as layers
|
||||
from data import LJSpeechLoader
|
||||
from parakeet.models.transformer_tts.vocoder import Vocoder
|
||||
from parakeet.models.transformer_tts import Vocoder
|
||||
from parakeet.utils import io
|
||||
|
||||
|
||||
def load_checkpoint(step, model_path):
|
||||
model_dict, opti_dict = dg.load_dygraph(os.path.join(model_path, step))
|
||||
new_state_dict = OrderedDict()
|
||||
for param in model_dict:
|
||||
if param.startswith('_layers.'):
|
||||
new_state_dict[param[8:]] = model_dict[param]
|
||||
else:
|
||||
new_state_dict[param] = model_dict[param]
|
||||
return new_state_dict, opti_dict
|
||||
def add_config_options_to_parser(parser):
|
||||
parser.add_argument("--config", type=str, help="path of the config file")
|
||||
parser.add_argument("--use_gpu", type=int, default=0, help="device to use")
|
||||
parser.add_argument("--data", type=str, help="path of LJspeech dataset")
|
||||
|
||||
g = parser.add_mutually_exclusive_group()
|
||||
g.add_argument("--checkpoint", type=str, help="checkpoint to resume from")
|
||||
g.add_argument(
|
||||
"--iteration",
|
||||
type=int,
|
||||
help="the iteration of the checkpoint to load from output directory")
|
||||
|
||||
parser.add_argument(
|
||||
"--output",
|
||||
type=str,
|
||||
default="vocoder",
|
||||
help="path to save experiment results")
|
||||
|
||||
|
||||
def main(args):
|
||||
local_rank = dg.parallel.Env().local_rank
|
||||
nranks = dg.parallel.Env().nranks
|
||||
parallel = nranks > 1
|
||||
|
||||
local_rank = dg.parallel.Env().local_rank if args.use_data_parallel else 0
|
||||
nranks = dg.parallel.Env().nranks if args.use_data_parallel else 1
|
||||
|
||||
with open(args.config_path) as f:
|
||||
with open(args.config) as f:
|
||||
cfg = yaml.load(f, Loader=yaml.Loader)
|
||||
|
||||
global_step = 0
|
||||
place = (fluid.CUDAPlace(dg.parallel.Env().dev_id)
|
||||
if args.use_data_parallel else fluid.CUDAPlace(0)
|
||||
if args.use_gpu else fluid.CPUPlace())
|
||||
place = fluid.CUDAPlace(local_rank) if args.use_gpu else fluid.CPUPlace()
|
||||
|
||||
if not os.path.exists(args.log_dir):
|
||||
os.mkdir(args.log_dir)
|
||||
path = os.path.join(args.log_dir, 'vocoder')
|
||||
if not os.path.exists(args.output):
|
||||
os.mkdir(args.output)
|
||||
|
||||
writer = SummaryWriter(path) if local_rank == 0 else None
|
||||
writer = SummaryWriter(os.path.join(args.output,
|
||||
'log')) if local_rank == 0 else None
|
||||
|
||||
with dg.guard(place):
|
||||
model = Vocoder(cfg, args.batch_size)
|
||||
model = Vocoder(cfg['train']['batch_size'],
|
||||
cfg['vocoder']['hidden_size'],
|
||||
cfg['audio']['num_mels'], cfg['audio']['n_fft'])
|
||||
|
||||
model.train()
|
||||
optimizer = fluid.optimizer.AdamOptimizer(
|
||||
learning_rate=dg.NoamDecay(1 / (
|
||||
cfg['warm_up_step'] * (args.lr**2)), cfg['warm_up_step']),
|
||||
learning_rate=dg.NoamDecay(1 /
|
||||
(cfg['train']['warm_up_step'] *
|
||||
(cfg['train']['learning_rate']**2)),
|
||||
cfg['train']['warm_up_step']),
|
||||
parameter_list=model.parameters())
|
||||
|
||||
if args.checkpoint_path is not None:
|
||||
model_dict, opti_dict = load_checkpoint(
|
||||
str(args.vocoder_step),
|
||||
os.path.join(args.checkpoint_path, "vocoder"))
|
||||
model.set_dict(model_dict)
|
||||
optimizer.set_dict(opti_dict)
|
||||
global_step = args.vocoder_step
|
||||
print("load checkpoint!!!")
|
||||
# Load parameters.
|
||||
global_step = io.load_parameters(
|
||||
model=model,
|
||||
optimizer=optimizer,
|
||||
checkpoint_dir=os.path.join(args.output, 'checkpoints'),
|
||||
iteration=args.iteration,
|
||||
checkpoint_path=args.checkpoint)
|
||||
print("Rank {}: checkpoint loaded.".format(local_rank))
|
||||
|
||||
if args.use_data_parallel:
|
||||
if parallel:
|
||||
strategy = dg.parallel.prepare_context()
|
||||
model = fluid.dygraph.parallel.DataParallel(model, strategy)
|
||||
|
||||
reader = LJSpeechLoader(
|
||||
cfg, args, nranks, local_rank, is_vocoder=True).reader()
|
||||
cfg['audio'],
|
||||
place,
|
||||
args.data,
|
||||
cfg['train']['batch_size'],
|
||||
nranks,
|
||||
local_rank,
|
||||
is_vocoder=True).reader()
|
||||
|
||||
for epoch in range(args.epochs):
|
||||
for epoch in range(cfg['train']['max_epochs']):
|
||||
pbar = tqdm(reader)
|
||||
for i, data in enumerate(pbar):
|
||||
pbar.set_description('Processing at epoch %d' % epoch)
|
||||
|
@@ -95,7 +111,7 @@ def main(args):
|
|||
loss = layers.mean(
|
||||
layers.abs(layers.elementwise_sub(mag_pred, mag)))
|
||||
|
||||
if args.use_data_parallel:
|
||||
if parallel:
|
||||
loss = model.scale_loss(loss)
|
||||
loss.backward()
|
||||
model.apply_collective_grads()
|
||||
|
@@ -104,7 +120,7 @@ def main(args):
|
|||
optimizer.minimize(
|
||||
loss,
|
||||
grad_clip=fluid.dygraph_grad_clip.GradClipByGlobalNorm(cfg[
|
||||
'grad_clip_thresh']))
|
||||
'train']['grad_clip_thresh']))
|
||||
model.clear_gradients()
|
||||
|
||||
if local_rank == 0:
|
||||
|
@@ -112,13 +128,12 @@ def main(args):
|
|||
'loss': loss.numpy(),
|
||||
}, global_step)
|
||||
|
||||
if global_step % args.save_step == 0:
|
||||
if not os.path.exists(args.save_path):
|
||||
os.mkdir(args.save_path)
|
||||
save_path = os.path.join(args.save_path,
|
||||
'vocoder/%d' % global_step)
|
||||
dg.save_dygraph(model.state_dict(), save_path)
|
||||
dg.save_dygraph(optimizer.state_dict(), save_path)
|
||||
# save checkpoint
|
||||
if local_rank == 0 and global_step % cfg['train'][
|
||||
'checkpoint_interval'] == 0:
|
||||
io.save_parameters(
|
||||
os.path.join(args.output, 'checkpoints'), global_step,
|
||||
model, optimizer)
|
||||
|
||||
if local_rank == 0:
|
||||
writer.close()
|
||||
|
|
|
@@ -1,20 +1,12 @@
|
|||
|
||||
# train model
|
||||
# if you wish to resume from an existing model, uncomment --checkpoint_path and --vocoder_step
|
||||
CUDA_VISIBLE_DEVICES=0 \
|
||||
python -u train_vocoder.py \
|
||||
--batch_size=32 \
|
||||
--epochs=10000 \
|
||||
--lr=0.001 \
|
||||
--save_step=1000 \
|
||||
--use_gpu=1 \
|
||||
--use_data_parallel=0 \
|
||||
--data_path='../../dataset/LJSpeech-1.1' \
|
||||
--save_path='./checkpoint' \
|
||||
--log_dir='./log' \
|
||||
--config_path='configs/train_vocoder.yaml' \
|
||||
#--checkpoint_path='./checkpoint' \
|
||||
#--vocoder_step=27000 \
|
||||
--data='../../dataset/LJSpeech-1.1' \
|
||||
--output='./vocoder' \
|
||||
--config='configs/ljspeech.yaml' \
|
||||
#--checkpoint='./checkpoint/vocoder/step-100000' \
|
||||
|
||||
|
||||
if [ $? -ne 0 ]; then
|
||||
|
|
|
@@ -75,10 +75,10 @@ class Decoder(dg.Layer):
|
|||
Compute decoder outputs.
|
||||
|
||||
Args:
|
||||
enc_seq (Variable): shape(B, T_text, C), dtype float32,
|
||||
the output of length regulator, where T_text means the timesteps of input text,
|
||||
enc_seq (Variable): shape(B, T_mel, C), dtype float32,
|
||||
the output of length regulator, where T_mel means the timesteps of input spectrum.
|
||||
enc_pos (Variable): shape(B, T_mel), dtype int64,
|
||||
the spectrum position, where T_mel means the timesteps of input spectrum,
|
||||
the spectrum position.
|
||||
non_pad_mask (Variable): shape(B, T_mel, 1), dtype int64, the mask with non pad.
|
||||
slf_attn_mask (Variable, optional): shape(B, T_mel, T_mel), dtype int64,
|
||||
the mask of mel spectrum. Defaults to None.
|
||||
|
|
|
@@ -24,11 +24,13 @@ from parakeet.models.fastspeech.decoder import Decoder
|
|||
|
||||
|
||||
class FastSpeech(dg.Layer):
|
||||
def __init__(self, cfg):
|
||||
def __init__(self, cfg, num_mels=80):
|
||||
"""FastSpeech model.
|
||||
|
||||
Args:
|
||||
cfg: the yaml configs used in FastSpeech model.
|
||||
num_mels (int, optional): the number of mel bands when calculating mel spectrograms. Defaults to 80.
|
||||
|
||||
"""
|
||||
super(FastSpeech, self).__init__()
|
||||
|
||||
|
@@ -37,15 +39,15 @@ class FastSpeech(dg.Layer):
|
|||
len_max_seq=cfg['max_seq_len'],
|
||||
n_layers=cfg['encoder_n_layer'],
|
||||
n_head=cfg['encoder_head'],
|
||||
d_k=cfg['fs_hidden_size'] // cfg['encoder_head'],
|
||||
d_q=cfg['fs_hidden_size'] // cfg['encoder_head'],
|
||||
d_model=cfg['fs_hidden_size'],
|
||||
d_k=cfg['hidden_size'] // cfg['encoder_head'],
|
||||
d_q=cfg['hidden_size'] // cfg['encoder_head'],
|
||||
d_model=cfg['hidden_size'],
|
||||
d_inner=cfg['encoder_conv1d_filter_size'],
|
||||
fft_conv1d_kernel=cfg['fft_conv1d_filter'],
|
||||
fft_conv1d_padding=cfg['fft_conv1d_padding'],
|
||||
dropout=0.1)
|
||||
self.length_regulator = LengthRegulator(
|
||||
input_size=cfg['fs_hidden_size'],
|
||||
input_size=cfg['hidden_size'],
|
||||
out_channels=cfg['duration_predictor_output_size'],
|
||||
filter_size=cfg['duration_predictor_filter_size'],
|
||||
dropout=cfg['dropout'])
|
||||
|
@@ -53,30 +55,30 @@ class FastSpeech(dg.Layer):
|
|||
len_max_seq=cfg['max_seq_len'],
|
||||
n_layers=cfg['decoder_n_layer'],
|
||||
n_head=cfg['decoder_head'],
|
||||
d_k=cfg['fs_hidden_size'] // cfg['decoder_head'],
|
||||
d_q=cfg['fs_hidden_size'] // cfg['decoder_head'],
|
||||
d_model=cfg['fs_hidden_size'],
|
||||
d_k=cfg['hidden_size'] // cfg['decoder_head'],
|
||||
d_q=cfg['hidden_size'] // cfg['decoder_head'],
|
||||
d_model=cfg['hidden_size'],
|
||||
d_inner=cfg['decoder_conv1d_filter_size'],
|
||||
fft_conv1d_kernel=cfg['fft_conv1d_filter'],
|
||||
fft_conv1d_padding=cfg['fft_conv1d_padding'],
|
||||
dropout=0.1)
|
||||
self.weight = fluid.ParamAttr(
|
||||
initializer=fluid.initializer.XavierInitializer())
|
||||
k = math.sqrt(1.0 / cfg['fs_hidden_size'])
|
||||
k = math.sqrt(1.0 / cfg['hidden_size'])
|
||||
self.bias = fluid.ParamAttr(initializer=fluid.initializer.Uniform(
|
||||
low=-k, high=k))
|
||||
self.mel_linear = dg.Linear(
|
||||
cfg['fs_hidden_size'],
|
||||
cfg['audio']['num_mels'] * cfg['audio']['outputs_per_step'],
|
||||
cfg['hidden_size'],
|
||||
num_mels * cfg['outputs_per_step'],
|
||||
param_attr=self.weight,
|
||||
bias_attr=self.bias, )
|
||||
self.postnet = PostConvNet(
|
||||
n_mels=cfg['audio']['num_mels'],
|
||||
n_mels=num_mels,
|
||||
num_hidden=512,
|
||||
filter_size=5,
|
||||
padding=int(5 / 2),
|
||||
num_conv=5,
|
||||
outputs_per_step=cfg['audio']['outputs_per_step'],
|
||||
outputs_per_step=cfg['outputs_per_step'],
|
||||
use_cudnn=True,
|
||||
dropout=0.1,
|
||||
batchnorm_last=True)
|
||||
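With `fs_hidden_size` renamed to `hidden_size` and the mel settings moved out of `cfg['audio']` into the `num_mels` argument, the constructor above only reads flat, network-level keys. A hedged sketch of those keys as a Python dict; every value is a placeholder, not the shipped ljspeech.yaml:

```python
# Keys consumed by FastSpeech.__init__ above; values are illustrative only.
fastspeech_cfg = {
    'max_seq_len': 2048,
    'hidden_size': 256,
    'encoder_n_layer': 6,
    'encoder_head': 2,
    'encoder_conv1d_filter_size': 1024,
    'decoder_n_layer': 6,
    'decoder_head': 2,
    'decoder_conv1d_filter_size': 1024,
    'fft_conv1d_filter': 3,
    'fft_conv1d_padding': 1,
    'duration_predictor_output_size': 256,
    'duration_predictor_filter_size': 3,
    'dropout': 0.1,
    'outputs_per_step': 1,
}
# model = FastSpeech(fastspeech_cfg, num_mels=80)  # num_mels is now a separate argument
```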
|
@@ -144,7 +146,7 @@ class FastSpeech(dg.Layer):
|
|||
decoder_pos.numpy(), decoder_pos.numpy()).astype(np.float32)
|
||||
slf_attn_mask = np.expand_dims(slf_attn_mask, axis=0)
|
||||
slf_attn_mask = fluid.layers.cast(
|
||||
dg.to_variable(slf_attn_mask == 0), np.float32)
|
||||
dg.to_variable(slf_attn_mask != 0), np.float32) * (-2**32 + 1)
|
||||
slf_attn_mask = dg.to_variable(slf_attn_mask)
|
||||
dec_non_pad_mask = fluid.layers.unsqueeze(
|
||||
(decoder_pos != 0).astype(np.float32), [-1])
|
||||
|
|
|
@@ -37,11 +37,10 @@ def score_F(attn):
|
|||
|
||||
|
||||
def compute_duration(attn, mel_lens):
|
||||
alignment = np.zeros([attn.shape[0], attn.shape[2]])
|
||||
mel_lens = mel_lens.numpy()
|
||||
for i in range(attn.shape[0]):
|
||||
for j in range(mel_lens[i]):
|
||||
max_index = np.argmax(attn[i, j])
|
||||
alignment[i, max_index] += 1
|
||||
alignment = np.zeros([attn.shape[2]])
|
||||
#for i in range(attn.shape[0]):
|
||||
for j in range(mel_lens):
|
||||
max_index = np.argmax(attn[0, j])
|
||||
alignment[max_index] += 1
|
||||
|
||||
return alignment
|
||||
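`compute_duration` now works on a single utterance: for every valid mel frame it takes the argmax over text positions and counts how many frames each position wins, giving the duration targets for FastSpeech's length regulator. A self-contained sketch with made-up attention weights:

```python
import numpy as np

# Made-up attention for one utterance: batch of 1, 6 mel frames, 3 text positions.
T_mel, T_text = 6, 3
attn = np.random.rand(1, T_mel, T_text).astype(np.float32)
attn /= attn.sum(axis=-1, keepdims=True)   # rows sum to 1, like softmaxed attention
mel_len = T_mel                            # number of valid (non-padded) frames

alignment = np.zeros([attn.shape[2]])
for j in range(mel_len):
    alignment[np.argmax(attn[0, j])] += 1

# One duration per text position; the durations sum to the number of mel frames.
assert alignment.sum() == mel_len
```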
|
|
|
@@ -10,4 +10,6 @@
|
|||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# limitations under the License.
|
||||
from .transformer_tts import TransformerTTS
|
||||
from .vocoder import Vocoder
|
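With these package-level re-exports in place, the training and synthesis scripts above import the classes from `parakeet.models.transformer_tts` directly rather than from the submodules:

```python
# Shorter imports enabled by the re-exports above.
from parakeet.models.transformer_tts import TransformerTTS, Vocoder
```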
|
@@ -22,14 +22,20 @@ from parakeet.models.transformer_tts.post_convnet import PostConvNet
|
|||
|
||||
|
||||
class Decoder(dg.Layer):
|
||||
def __init__(self, num_hidden, config, num_head=4, n_layers=3):
|
||||
def __init__(self,
|
||||
num_hidden,
|
||||
num_mels=80,
|
||||
outputs_per_step=1,
|
||||
num_head=4,
|
||||
n_layers=3):
|
||||
"""Decoder layer of TransformerTTS.
|
||||
|
||||
Args:
|
||||
num_hidden (int): the size of hidden layer in network.
|
||||
config: the yaml configs used in decoder.
|
||||
n_layers (int, optional): the layers number of multihead attention. Defaults to 4.
|
||||
num_head (int, optional): the head number of multihead attention. Defaults to 3.
|
||||
num_mels (int, optional): the number of mel bands when calculating mel spectrograms. Defaults to 80.
|
||||
outputs_per_step (int, optional): the number of output frames per step. Defaults to 1.
|
||||
num_head (int, optional): the head number of multihead attention. Defaults to 4.
|
||||
n_layers (int, optional): the layers number of multihead attention. Defaults to 3.
|
||||
"""
|
||||
super(Decoder, self).__init__()
|
||||
self.num_hidden = num_hidden
|
||||
|
@@ -51,7 +57,7 @@ class Decoder(dg.Layer):
|
|||
self.pos_inp),
|
||||
trainable=False))
|
||||
self.decoder_prenet = PreNet(
|
||||
input_size=config['audio']['num_mels'],
|
||||
input_size=num_mels,
|
||||
hidden_size=num_hidden * 2,
|
||||
output_size=num_hidden,
|
||||
dropout_rate=0.2)
|
||||
|
@@ -85,7 +91,7 @@ class Decoder(dg.Layer):
|
|||
self.add_sublayer("ffns_{}".format(i), layer)
|
||||
self.mel_linear = dg.Linear(
|
||||
num_hidden,
|
||||
config['audio']['num_mels'] * config['audio']['outputs_per_step'],
|
||||
num_mels * outputs_per_step,
|
||||
param_attr=fluid.ParamAttr(
|
||||
initializer=fluid.initializer.XavierInitializer()),
|
||||
bias_attr=fluid.ParamAttr(initializer=fluid.initializer.Uniform(
|
||||
|
@@ -99,12 +105,12 @@ class Decoder(dg.Layer):
|
|||
low=-k, high=k)))
|
||||
|
||||
self.postconvnet = PostConvNet(
|
||||
config['audio']['num_mels'],
|
||||
config['hidden_size'],
|
||||
num_mels,
|
||||
num_hidden,
|
||||
filter_size=5,
|
||||
padding=4,
|
||||
num_conv=5,
|
||||
outputs_per_step=config['audio']['outputs_per_step'],
|
||||
outputs_per_step=outputs_per_step,
|
||||
use_cudnn=True)
|
||||
|
||||
def forward(self,
|
||||
|
|
|
@@ -26,8 +26,8 @@ class Encoder(dg.Layer):
|
|||
Args:
|
||||
embedding_size (int): the size of position embedding.
|
||||
num_hidden (int): the size of hidden layer in network.
|
||||
n_layers (int, optional): the layers number of multihead attention. Defaults to 4.
|
||||
num_head (int, optional): the head number of multihead attention. Defaults to 3.
|
||||
num_head (int, optional): the head number of multihead attention. Defaults to 4.
|
||||
n_layers (int, optional): the layers number of multihead attention. Defaults to 3.
|
||||
"""
|
||||
super(Encoder, self).__init__()
|
||||
self.num_hidden = num_hidden
|
||||
|
|
|
@@ -18,16 +18,32 @@ from parakeet.models.transformer_tts.decoder import Decoder
|
|||
|
||||
|
||||
class TransformerTTS(dg.Layer):
|
||||
def __init__(self, config):
|
||||
def __init__(self,
|
||||
embedding_size,
|
||||
num_hidden,
|
||||
encoder_num_head=4,
|
||||
encoder_n_layers=3,
|
||||
n_mels=80,
|
||||
outputs_per_step=1,
|
||||
decoder_num_head=4,
|
||||
decoder_n_layers=3):
|
||||
"""TransformerTTS model.
|
||||
|
||||
Args:
|
||||
config: the yaml configs used in TransformerTTS model.
|
||||
embedding_size (int): the size of position embedding.
|
||||
num_hidden (int): the size of hidden layer in network.
|
||||
encoder_num_head (int, optional): the head number of multihead attention in encoder. Defaults to 4.
|
||||
encoder_n_layers (int, optional): the layers number of multihead attention in encoder. Defaults to 3.
|
||||
n_mels (int, optional): the number of mel bands when calculating mel spectrograms. Defaults to 80.
|
||||
outputs_per_step (int, optional): the number of output frames per step. Defaults to 1.
|
||||
decoder_num_head (int, optional): the head number of multihead attention in decoder. Defaults to 4.
|
||||
decoder_n_layers (int, optional): the layers number of multihead attention in decoder. Defaults to 3.
|
||||
"""
|
||||
super(TransformerTTS, self).__init__()
|
||||
self.encoder = Encoder(config['embedding_size'], config['hidden_size'])
|
||||
self.decoder = Decoder(config['hidden_size'], config)
|
||||
self.config = config
|
||||
self.encoder = Encoder(embedding_size, num_hidden, encoder_num_head,
|
||||
encoder_n_layers)
|
||||
self.decoder = Decoder(num_hidden, n_mels, outputs_per_step,
|
||||
decoder_num_head, decoder_n_layers)
|
||||
|
||||
def forward(self,
|
||||
characters,
|
||||
|
|
|
@@ -19,22 +19,22 @@ from parakeet.models.transformer_tts.cbhg import CBHG
|
|||
|
||||
|
||||
class Vocoder(dg.Layer):
|
||||
def __init__(self, config, batch_size):
|
||||
def __init__(self, batch_size, hidden_size, num_mels=80, n_fft=2048):
|
||||
"""CBHG Network (mel -> linear)
|
||||
|
||||
Args:
|
||||
config: the yaml configs used in Vocoder model.
|
||||
batch_size (int): the batch size of input.
|
||||
hidden_size (int): the size of hidden layer in network.
|
||||
num_mels (int, optional): the number of mel bands when calculating mel spectrograms. Defaults to 80.
|
||||
n_fft (int, optional): length of the windowed signal after padding with zeros. Defaults to 2048.
|
||||
"""
|
||||
super(Vocoder, self).__init__()
|
||||
self.pre_proj = Conv1D(
|
||||
num_channels=config['audio']['num_mels'],
|
||||
num_filters=config['hidden_size'],
|
||||
filter_size=1)
|
||||
self.cbhg = CBHG(config['hidden_size'], batch_size)
|
||||
num_channels=num_mels, num_filters=hidden_size, filter_size=1)
|
||||
self.cbhg = CBHG(hidden_size, batch_size)
|
||||
self.post_proj = Conv1D(
|
||||
num_channels=config['hidden_size'],
|
||||
num_filters=(config['audio']['n_fft'] // 2) + 1,
|
||||
num_channels=hidden_size,
|
||||
num_filters=(n_fft // 2) + 1,
|
||||
filter_size=1)
|
||||
|
||||
def forward(self, mel):
|
||||
|
|
|
@@ -125,13 +125,20 @@ def load_parameters(model,
|
|||
model_dict, optimizer_dict = dg.load_dygraph(checkpoint_path)
|
||||
|
||||
state_dict = model.state_dict()
|
||||
dict_new = {}
|
||||
|
||||
# cast to desired data type, for mixed-precision training/inference.
|
||||
for k, v in model_dict.items():
|
||||
if k in state_dict and convert_np_dtype(v.dtype) != state_dict[
|
||||
k].dtype:
|
||||
model_dict[k] = v.astype(state_dict[k].numpy().dtype)
|
||||
|
||||
model.set_dict(model_dict)
|
||||
if k.startswith('_layers.'):
|
||||
k = k[8:]
|
||||
|
||||
if k in state_dict:
|
||||
if convert_np_dtype(v.dtype) != state_dict[k].dtype:
|
||||
v = v.astype(state_dict[k].numpy().dtype)
|
||||
dict_new[k] = v
|
||||
|
||||
model.set_dict(dict_new)
|
||||
print("[checkpoint] Rank {}: loaded model from {}.pdparams".format(
|
||||
local_rank, checkpoint_path))
|
||||
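The reworked `load_parameters` above now also accepts checkpoints saved from a `DataParallel`-wrapped model (parameter names prefixed with `_layers.`) and casts each array to the dtype of the matching live parameter before calling `set_dict`. A stand-alone sketch of that normalization step; the helper name and the plain-dict interface are assumptions for illustration, not the Parakeet API:

```python
import numpy as np

def normalize_state_dict(saved, live_dtypes):
    # `saved`: checkpoint name -> numpy array; `live_dtypes`: live model
    # name -> numpy dtype. Both are illustrative stand-ins.
    cleaned = {}
    for name, value in saved.items():
        if name.startswith('_layers.'):            # strip the DataParallel prefix
            name = name[len('_layers.'):]
        if name not in live_dtypes:                # ignore parameters the model lacks
            continue
        if value.dtype != live_dtypes[name]:       # mixed-precision friendly cast
            value = value.astype(live_dtypes[name])
        cleaned[name] = value
    return cleaned

# A DataParallel checkpoint key lines up with the bare model's key and dtype:
state = normalize_state_dict(
    {'_layers.linear.weight': np.zeros((2, 2), dtype=np.float64)},
    {'linear.weight': np.dtype('float32')})
assert state['linear.weight'].dtype == np.dtype('float32')
```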
|
||||
|
|