Merge branch 'master' into 'master'
completed fastspeech and modified save/load See merge request !50
This commit is contained in:
commit
72e51b0f64
|
@ -0,0 +1,52 @@
|
|||
data:
|
||||
batch_size: 8
|
||||
train_clip_seconds: 0.5
|
||||
sample_rate: 22050
|
||||
hop_length: 256
|
||||
win_length: 1024
|
||||
n_fft: 2048
|
||||
|
||||
n_mels: 80
|
||||
valid_size: 16
|
||||
|
||||
|
||||
conditioner:
|
||||
upsampling_factors: [16, 16]
|
||||
|
||||
teacher:
|
||||
n_loop: 10
|
||||
n_layer: 3
|
||||
filter_size: 2
|
||||
residual_channels: 128
|
||||
loss_type: "mog"
|
||||
output_dim: 3
|
||||
log_scale_min: -9
|
||||
|
||||
student:
|
||||
n_loops: [10, 10, 10, 10, 10, 10]
|
||||
n_layers: [1, 1, 1, 1, 1, 1]
|
||||
filter_size: 3
|
||||
residual_channels: 64
|
||||
log_scale_min: -7
|
||||
|
||||
stft:
|
||||
n_fft: 2048
|
||||
win_length: 1024
|
||||
hop_length: 256
|
||||
|
||||
loss:
|
||||
lmd: 4
|
||||
|
||||
train:
|
||||
learning_rate: 0.0005
|
||||
anneal_rate: 0.5
|
||||
anneal_interval: 200000
|
||||
gradient_max_norm: 100.0
|
||||
|
||||
checkpoint_interval: 1000
|
||||
eval_interval: 1000
|
||||
|
||||
max_iterations: 2000000
|
||||
|
||||
|
||||
|
|
@ -1,4 +1,5 @@
|
|||
# Fastspeech
|
||||
|
||||
PaddlePaddle dynamic graph implementation of Fastspeech, a feed-forward network based on Transformer. The implementation is based on [FastSpeech: Fast, Robust and Controllable Text to Speech](https://arxiv.org/abs/1905.09263).
|
||||
|
||||
## Dataset
|
||||
|
@ -20,60 +21,123 @@ mel-spectrogram sequence for parallel mel-spectrogram generation. We use the Tra
|
|||
The model consists of encoder, decoder and length regulator three parts.
|
||||
|
||||
## Project Structure
|
||||
|
||||
```text
|
||||
├── config # yaml configuration files
|
||||
├── synthesis.py # script to synthesize waveform from text
|
||||
├── train.py # script for model training
|
||||
```
|
||||
|
||||
## Train Transformer
|
||||
## Saving & Loading
|
||||
|
||||
`train_transformer.py` and `train_vocoer.py` have 3 arguments in common, `--checkpoint`, `--iteration` and `--output`.
|
||||
|
||||
1. `--output` is the directory for saving results.
|
||||
During training, checkpoints are saved in `${output}/checkpoints` and tensorboard logs are saved in `${output}/log`.
|
||||
During synthesis, results are saved in `${output}/samples` and tensorboard log is save in `${output}/log`.
|
||||
|
||||
2. `--checkpoint` is the path of a checkpoint and `--iteration` is the target step. They are used to load checkpoints in the following way.
|
||||
|
||||
- If `--checkpoint` is provided, the checkpoint specified by `--checkpoint` is loaded.
|
||||
|
||||
- If `--checkpoint` is not provided, we try to load the checkpoint of the target step specified by `--iteration` from the `${output}/checkpoints/` directory, e.g. if given `--iteration 120000`, the checkpoint `${output}/checkpoints/step-120000.*` will be load.
|
||||
|
||||
- If both `--checkpoint` and `--iteration` are not provided, we try to load the latest checkpoint from `${output}/checkpoints/` directory.
|
||||
|
||||
## Compute Phoneme Duration
|
||||
|
||||
A ground truth duration of each phoneme (number of frames in the spectrogram that correspond to that phoneme) should be provided when training a FastSpeech model.
|
||||
|
||||
We compute the ground truth duration of each phomemes in the following way.
|
||||
We extract the encoder-decoder attention alignment from a trained Transformer TTS model;
|
||||
Each frame is considered corresponding to the phoneme that receive the most attention;
|
||||
|
||||
You can run alignments/get_alignments.py to get it.
|
||||
|
||||
```bash
|
||||
cd alignments
|
||||
python get_alignments.py \
|
||||
--use_gpu=1 \
|
||||
--output='./alignments' \
|
||||
--data=${DATAPATH} \
|
||||
--config=${CONFIG} \
|
||||
--checkpoint_transformer=${CHECKPOINT} \
|
||||
```
|
||||
|
||||
where `${DATAPATH}` is the path saved LJSpeech data, `${CHECKPOINT}` is the pretrain model path of TransformerTTS, `${CONFIG}` is the config yaml file of TransformerTTS checkpoint. It is necessary for you to prepare a pre-trained TranformerTTS checkpoint.
|
||||
|
||||
For more help on arguments
|
||||
|
||||
``python alignments.py --help``.
|
||||
|
||||
Or you can use your own phoneme duration, you just need to process the data into the following format.
|
||||
|
||||
```bash
|
||||
{'fname1': alignment1,
|
||||
'fname2': alignment2,
|
||||
...}
|
||||
```
|
||||
|
||||
## Train FastSpeech
|
||||
|
||||
FastSpeech model can be trained by running ``train.py``.
|
||||
|
||||
FastSpeech model can be trained with ``train.py``.
|
||||
```bash
|
||||
python train.py \
|
||||
--use_gpu=1 \
|
||||
--use_data_parallel=0 \
|
||||
--data_path=${DATAPATH} \
|
||||
--transtts_path='../transformer_tts/checkpoint' \
|
||||
--transformer_step=160000 \
|
||||
--config_path='config/fastspeech.yaml' \
|
||||
--data=${DATAPATH} \
|
||||
--alignments_path=${ALIGNMENTS_PATH} \
|
||||
--output='./experiment' \
|
||||
--config='configs/ljspeech.yaml' \
|
||||
```
|
||||
|
||||
Or you can run the script file directly.
|
||||
|
||||
```bash
|
||||
sh train.sh
|
||||
```
|
||||
If you want to train on multiple GPUs, you must set ``--use_data_parallel=1``, and then start training as follows:
|
||||
|
||||
If you want to train on multiple GPUs, start training in the following way.
|
||||
|
||||
```bash
|
||||
CUDA_VISIBLE_DEVICES=0,1,2,3
|
||||
python -m paddle.distributed.launch --selected_gpus=0,1,2,3 --log_dir ./mylog train.py \
|
||||
--use_gpu=1 \
|
||||
--use_data_parallel=1 \
|
||||
--data_path=${DATAPATH} \
|
||||
--transtts_path='../transformer_tts/checkpoint' \
|
||||
--transformer_step=160000 \
|
||||
--config_path='config/fastspeech.yaml' \
|
||||
--data=${DATAPATH} \
|
||||
--alignments_path=${ALIGNMENTS_PATH} \
|
||||
--output='./experiment' \
|
||||
--config='configs/ljspeech.yaml' \
|
||||
```
|
||||
|
||||
If you wish to resume from an existing model, please set ``--checkpoint_path`` and ``--fastspeech_step``.
|
||||
If you wish to resume from an existing model, See [Saving-&-Loading](#Saving-&-Loading) for details of checkpoint loading.
|
||||
|
||||
For more help on arguments
|
||||
|
||||
For more help on arguments:
|
||||
``python train.py --help``.
|
||||
|
||||
## Synthesis
|
||||
After training the FastSpeech, audio can be synthesized with ``synthesis.py``.
|
||||
|
||||
After training the FastSpeech, audio can be synthesized by running ``synthesis.py``.
|
||||
|
||||
```bash
|
||||
python synthesis.py \
|
||||
--use_gpu=1 \
|
||||
--alpha=1.0 \
|
||||
--checkpoint_path='checkpoint/' \
|
||||
--fastspeech_step=112000 \
|
||||
--checkpoint='./checkpoint/fastspeech/step-120000' \
|
||||
--config='configs/ljspeech.yaml' \
|
||||
--config_clarine='../clarinet/configs/config.yaml' \
|
||||
--checkpoint_clarinet='../clarinet/checkpoint/step-500000' \
|
||||
--output='./synthesis' \
|
||||
```
|
||||
|
||||
We use Clarinet to synthesis wav, so it necessary for you to prepare a pre-trained [Clarinet checkpoint](https://paddlespeech.bj.bcebos.com/Parakeet/clarinet_ljspeech_ckpt_1.0.zip).
|
||||
|
||||
Or you can run the script file directly.
|
||||
|
||||
```bash
|
||||
sh synthesis.sh
|
||||
```
|
||||
|
||||
For more help on arguments:
|
||||
For more help on arguments
|
||||
|
||||
``python synthesis.py --help``.
|
||||
|
|
|
@ -0,0 +1,142 @@
|
|||
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import os
|
||||
from scipy.io.wavfile import write
|
||||
from parakeet.g2p.en import text_to_sequence
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import csv
|
||||
from tqdm import tqdm
|
||||
from ruamel import yaml
|
||||
import pickle
|
||||
from pathlib import Path
|
||||
import argparse
|
||||
from pprint import pprint
|
||||
from collections import OrderedDict
|
||||
import paddle.fluid as fluid
|
||||
import paddle.fluid.dygraph as dg
|
||||
from parakeet.models.transformer_tts.utils import *
|
||||
from parakeet import audio
|
||||
from parakeet.models.transformer_tts import TransformerTTS
|
||||
from parakeet.models.fastspeech.utils import get_alignment
|
||||
from parakeet.utils import io
|
||||
|
||||
|
||||
def add_config_options_to_parser(parser):
|
||||
parser.add_argument("--config", type=str, help="path of the config file")
|
||||
parser.add_argument("--use_gpu", type=int, default=0, help="device to use")
|
||||
parser.add_argument("--data", type=str, help="path of LJspeech dataset")
|
||||
|
||||
parser.add_argument(
|
||||
"--checkpoint_transformer",
|
||||
type=str,
|
||||
help="transformer_tts checkpoint to synthesis")
|
||||
|
||||
parser.add_argument(
|
||||
"--output",
|
||||
type=str,
|
||||
default="./alignments",
|
||||
help="path to save experiment results")
|
||||
|
||||
|
||||
def alignments(args):
|
||||
local_rank = dg.parallel.Env().local_rank
|
||||
place = (fluid.CUDAPlace(local_rank) if args.use_gpu else fluid.CPUPlace())
|
||||
|
||||
with open(args.config) as f:
|
||||
cfg = yaml.load(f, Loader=yaml.Loader)
|
||||
|
||||
with dg.guard(place):
|
||||
network_cfg = cfg['network']
|
||||
model = TransformerTTS(
|
||||
network_cfg['embedding_size'], network_cfg['hidden_size'],
|
||||
network_cfg['encoder_num_head'], network_cfg['encoder_n_layers'],
|
||||
cfg['audio']['num_mels'], network_cfg['outputs_per_step'],
|
||||
network_cfg['decoder_num_head'], network_cfg['decoder_n_layers'])
|
||||
# Load parameters.
|
||||
global_step = io.load_parameters(
|
||||
model=model, checkpoint_path=args.checkpoint_transformer)
|
||||
model.eval()
|
||||
|
||||
# get text data
|
||||
root = Path(args.data)
|
||||
csv_path = root.joinpath("metadata.csv")
|
||||
table = pd.read_csv(
|
||||
csv_path,
|
||||
sep="|",
|
||||
header=None,
|
||||
quoting=csv.QUOTE_NONE,
|
||||
names=["fname", "raw_text", "normalized_text"])
|
||||
ljspeech_processor = audio.AudioProcessor(
|
||||
sample_rate=cfg['audio']['sr'],
|
||||
num_mels=cfg['audio']['num_mels'],
|
||||
min_level_db=cfg['audio']['min_level_db'],
|
||||
ref_level_db=cfg['audio']['ref_level_db'],
|
||||
n_fft=cfg['audio']['n_fft'],
|
||||
win_length=cfg['audio']['win_length'],
|
||||
hop_length=cfg['audio']['hop_length'],
|
||||
power=cfg['audio']['power'],
|
||||
preemphasis=cfg['audio']['preemphasis'],
|
||||
signal_norm=True,
|
||||
symmetric_norm=False,
|
||||
max_norm=1.,
|
||||
mel_fmin=0,
|
||||
mel_fmax=None,
|
||||
clip_norm=True,
|
||||
griffin_lim_iters=60,
|
||||
do_trim_silence=False,
|
||||
sound_norm=False)
|
||||
|
||||
pbar = tqdm(range(len(table)))
|
||||
alignments = OrderedDict()
|
||||
for i in pbar:
|
||||
fname, raw_text, normalized_text = table.iloc[i]
|
||||
# init input
|
||||
text = np.asarray(text_to_sequence(normalized_text))
|
||||
text = fluid.layers.unsqueeze(dg.to_variable(text), [0])
|
||||
pos_text = np.arange(1, text.shape[1] + 1)
|
||||
pos_text = fluid.layers.unsqueeze(dg.to_variable(pos_text), [0])
|
||||
wav = ljspeech_processor.load_wav(
|
||||
os.path.join(args.data, 'wavs', fname + ".wav"))
|
||||
mel_input = ljspeech_processor.melspectrogram(wav).astype(
|
||||
np.float32)
|
||||
mel_input = np.transpose(mel_input, axes=(1, 0))
|
||||
mel_input = fluid.layers.unsqueeze(dg.to_variable(mel_input), [0])
|
||||
mel_lens = mel_input.shape[1]
|
||||
|
||||
dec_slf_mask = get_triu_tensor(mel_input,
|
||||
mel_input).astype(np.float32)
|
||||
dec_slf_mask = np.expand_dims(dec_slf_mask, axis=0)
|
||||
dec_slf_mask = fluid.layers.cast(
|
||||
dg.to_variable(dec_slf_mask != 0), np.float32) * (-2**32 + 1)
|
||||
pos_mel = np.arange(1, mel_input.shape[1] + 1)
|
||||
pos_mel = fluid.layers.unsqueeze(dg.to_variable(pos_mel), [0])
|
||||
mel_pred, postnet_pred, attn_probs, stop_preds, attn_enc, attn_dec = model(
|
||||
text, mel_input, pos_text, pos_mel, dec_slf_mask)
|
||||
mel_input = fluid.layers.concat(
|
||||
[mel_input, postnet_pred[:, -1:, :]], axis=1)
|
||||
|
||||
alignment, _ = get_alignment(attn_probs, mel_lens,
|
||||
network_cfg['decoder_num_head'])
|
||||
alignments[fname] = alignment
|
||||
with open(args.output + '.txt', "wb") as f:
|
||||
pickle.dump(alignments, f)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Get alignments from TransformerTTS model")
|
||||
add_config_options_to_parser(parser)
|
||||
args = parser.parse_args()
|
||||
alignments(args)
|
|
@ -0,0 +1,14 @@
|
|||
|
||||
CUDA_VISIBLE_DEVICES=0 \
|
||||
python -u get_alignments.py \
|
||||
--use_gpu=1 \
|
||||
--output='./alignments' \
|
||||
--data='../../../dataset/LJSpeech-1.1' \
|
||||
--config='../../transformer_tts/configs/ljspeech.yaml' \
|
||||
--checkpoint_transformer='../../transformer_tts/checkpoint/transformer/step-120000' \
|
||||
|
||||
if [ $? -ne 0 ]; then
|
||||
echo "Failed in training!"
|
||||
exit 1
|
||||
fi
|
||||
exit 0
|
|
@ -1,32 +0,0 @@
|
|||
audio:
|
||||
num_mels: 80 #the number of mel bands when calculating mel spectrograms.
|
||||
n_fft: 2048 #the number of fft components.
|
||||
sr: 22050 #the sampling rate of audio data file.
|
||||
preemphasis: 0.97 #the preemphasis coefficient.
|
||||
hop_length: 256 #the number of samples to advance between frames.
|
||||
win_length: 1024 #the length (width) of the window function.
|
||||
power: 1.2 #the power to raise before griffin-lim.
|
||||
min_level_db: -100 #the minimum level db.
|
||||
ref_level_db: 20 #the reference level db.
|
||||
outputs_per_step: 1 #the outputs per step.
|
||||
|
||||
encoder_n_layer: 6 #the number of FFT Block in encoder.
|
||||
encoder_head: 2 #the attention head number in encoder.
|
||||
encoder_conv1d_filter_size: 1536 #the filter size of conv1d in encoder.
|
||||
max_seq_len: 2048 #the max length of sequence.
|
||||
decoder_n_layer: 6 #the number of FFT Block in decoder.
|
||||
decoder_head: 2 #the attention head number in decoder.
|
||||
decoder_conv1d_filter_size: 1536 #the filter size of conv1d in decoder.
|
||||
fs_hidden_size: 384 #the hidden size in model of fastspeech.
|
||||
duration_predictor_output_size: 256 #the output size of duration predictior.
|
||||
duration_predictor_filter_size: 3 #the filter size of conv1d in duration prediction.
|
||||
fft_conv1d_filter: 3 #the filter size of conv1d in fft.
|
||||
fft_conv1d_padding: 1 #the padding size of conv1d in fft.
|
||||
dropout: 0.1 #the dropout in network.
|
||||
transformer_head: 4 #the attention head num of transformerTTS.
|
||||
|
||||
embedding_size: 512 #the dim size of embedding of transformerTTS.
|
||||
hidden_size: 256 #the hidden size in model of transformerTTS.
|
||||
warm_up_step: 4000 #the warm up step of learning rate.
|
||||
grad_clip_thresh: 0.1 #the threshold of grad clip.
|
||||
|
|
@ -0,0 +1,33 @@
|
|||
audio:
|
||||
num_mels: 80 #the number of mel bands when calculating mel spectrograms.
|
||||
n_fft: 2048 #the number of fft components.
|
||||
sr: 22050 #the sampling rate of audio data file.
|
||||
hop_length: 256 #the number of samples to advance between frames.
|
||||
win_length: 1024 #the length (width) of the window function.
|
||||
power: 1.2 #the power to raise before griffin-lim.
|
||||
|
||||
network:
|
||||
encoder_n_layer: 6 #the number of FFT Block in encoder.
|
||||
encoder_head: 2 #the attention head number in encoder.
|
||||
encoder_conv1d_filter_size: 1536 #the filter size of conv1d in encoder.
|
||||
max_seq_len: 2048 #the max length of sequence.
|
||||
decoder_n_layer: 6 #the number of FFT Block in decoder.
|
||||
decoder_head: 2 #the attention head number in decoder.
|
||||
decoder_conv1d_filter_size: 1536 #the filter size of conv1d in decoder.
|
||||
hidden_size: 384 #the hidden size in model of fastspeech.
|
||||
duration_predictor_output_size: 256 #the output size of duration predictior.
|
||||
duration_predictor_filter_size: 3 #the filter size of conv1d in duration prediction.
|
||||
fft_conv1d_filter: 3 #the filter size of conv1d in fft.
|
||||
fft_conv1d_padding: 1 #the padding size of conv1d in fft.
|
||||
dropout: 0.1 #the dropout in network.
|
||||
outputs_per_step: 1
|
||||
|
||||
train:
|
||||
batch_size: 32
|
||||
learning_rate: 0.001
|
||||
warm_up_step: 4000 #the warm up step of learning rate.
|
||||
grad_clip_thresh: 0.1 #the threshold of grad clip.
|
||||
|
||||
checkpoint_interval: 1000
|
||||
max_epochs: 10000
|
||||
|
|
@ -1,26 +0,0 @@
|
|||
audio:
|
||||
num_mels: 80
|
||||
n_fft: 2048
|
||||
sr: 22050
|
||||
preemphasis: 0.97
|
||||
hop_length: 256
|
||||
win_length: 1024
|
||||
power: 1.2
|
||||
min_level_db: -100
|
||||
ref_level_db: 20
|
||||
outputs_per_step: 1
|
||||
|
||||
encoder_n_layer: 6
|
||||
encoder_head: 2
|
||||
encoder_conv1d_filter_size: 1536
|
||||
max_seq_len: 2048
|
||||
decoder_n_layer: 6
|
||||
decoder_head: 2
|
||||
decoder_conv1d_filter_size: 1536
|
||||
fs_hidden_size: 384
|
||||
duration_predictor_output_size: 256
|
||||
duration_predictor_filter_size: 3
|
||||
fft_conv1d_filter: 3
|
||||
fft_conv1d_padding: 1
|
||||
dropout: 0.1
|
||||
transformer_head: 4
|
|
@ -0,0 +1,189 @@
|
|||
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from pathlib import Path
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import librosa
|
||||
import csv
|
||||
import pickle
|
||||
|
||||
from paddle import fluid
|
||||
from parakeet import g2p
|
||||
from parakeet import audio
|
||||
from parakeet.data.sampler import *
|
||||
from parakeet.data.datacargo import DataCargo
|
||||
from parakeet.data.batch import TextIDBatcher, SpecBatcher
|
||||
from parakeet.data.dataset import DatasetMixin, TransformDataset, CacheDataset, SliceDataset
|
||||
from parakeet.models.transformer_tts.utils import *
|
||||
|
||||
|
||||
class LJSpeechLoader:
|
||||
def __init__(self,
|
||||
config,
|
||||
place,
|
||||
data_path,
|
||||
alignments_path,
|
||||
batch_size,
|
||||
nranks,
|
||||
rank,
|
||||
is_vocoder=False,
|
||||
shuffle=True):
|
||||
|
||||
LJSPEECH_ROOT = Path(data_path)
|
||||
metadata = LJSpeechMetaData(LJSPEECH_ROOT, alignments_path)
|
||||
transformer = LJSpeech(
|
||||
sr=config['sr'],
|
||||
n_fft=config['n_fft'],
|
||||
num_mels=config['num_mels'],
|
||||
win_length=config['win_length'],
|
||||
hop_length=config['hop_length'])
|
||||
dataset = TransformDataset(metadata, transformer)
|
||||
dataset = CacheDataset(dataset)
|
||||
|
||||
sampler = DistributedSampler(
|
||||
len(dataset), nranks, rank, shuffle=shuffle)
|
||||
|
||||
assert batch_size % nranks == 0
|
||||
each_bs = batch_size // nranks
|
||||
dataloader = DataCargo(
|
||||
dataset,
|
||||
sampler=sampler,
|
||||
batch_size=each_bs,
|
||||
shuffle=shuffle,
|
||||
batch_fn=batch_examples,
|
||||
drop_last=True)
|
||||
self.reader = fluid.io.DataLoader.from_generator(
|
||||
capacity=32,
|
||||
iterable=True,
|
||||
use_double_buffer=True,
|
||||
return_list=True)
|
||||
self.reader.set_batch_generator(dataloader, place)
|
||||
|
||||
|
||||
class LJSpeechMetaData(DatasetMixin):
|
||||
def __init__(self, root, alignments_path):
|
||||
self.root = Path(root)
|
||||
self._wav_dir = self.root.joinpath("wavs")
|
||||
csv_path = self.root.joinpath("metadata.csv")
|
||||
self._table = pd.read_csv(
|
||||
csv_path,
|
||||
sep="|",
|
||||
header=None,
|
||||
quoting=csv.QUOTE_NONE,
|
||||
names=["fname", "raw_text", "normalized_text"])
|
||||
with open(alignments_path, "rb") as f:
|
||||
self._alignments = pickle.load(f)
|
||||
|
||||
def get_example(self, i):
|
||||
fname, raw_text, normalized_text = self._table.iloc[i]
|
||||
alignment = self._alignments[fname]
|
||||
fname = str(self._wav_dir.joinpath(fname + ".wav"))
|
||||
return fname, normalized_text, alignment
|
||||
|
||||
def __len__(self):
|
||||
return len(self._table)
|
||||
|
||||
|
||||
class LJSpeech(object):
|
||||
def __init__(self,
|
||||
sr=22050,
|
||||
n_fft=2048,
|
||||
num_mels=80,
|
||||
win_length=1024,
|
||||
hop_length=256):
|
||||
super(LJSpeech, self).__init__()
|
||||
self.sr = sr
|
||||
self.n_fft = n_fft
|
||||
self.num_mels = num_mels
|
||||
self.win_length = win_length
|
||||
self.hop_length = hop_length
|
||||
|
||||
def __call__(self, metadatum):
|
||||
"""All the code for generating an Example from a metadatum. If you want a
|
||||
different preprocessing pipeline, you can override this method.
|
||||
This method may require several processor, each of which has a lot of options.
|
||||
In this case, you'd better pass a composed transform and pass it to the init
|
||||
method.
|
||||
"""
|
||||
fname, normalized_text, alignment = metadatum
|
||||
|
||||
wav, _ = librosa.load(str(fname))
|
||||
spec = librosa.stft(
|
||||
y=wav,
|
||||
n_fft=self.n_fft,
|
||||
win_length=self.win_length,
|
||||
hop_length=self.hop_length)
|
||||
mag = np.abs(spec)
|
||||
mel = librosa.filters.mel(self.sr, self.n_fft, n_mels=self.num_mels)
|
||||
mel = np.matmul(mel, mag)
|
||||
mel = np.log(np.maximum(mel, 1e-5))
|
||||
phonemes = np.array(
|
||||
g2p.en.text_to_sequence(normalized_text), dtype=np.int64)
|
||||
return (mel, phonemes, alignment
|
||||
) # maybe we need to implement it as a map in the future
|
||||
|
||||
|
||||
def batch_examples(batch):
|
||||
texts = []
|
||||
mels = []
|
||||
text_lens = []
|
||||
pos_texts = []
|
||||
pos_mels = []
|
||||
alignments = []
|
||||
for data in batch:
|
||||
mel, text, alignment = data
|
||||
text_lens.append(len(text))
|
||||
pos_texts.append(np.arange(1, len(text) + 1))
|
||||
pos_mels.append(np.arange(1, mel.shape[1] + 1))
|
||||
mels.append(mel)
|
||||
texts.append(text)
|
||||
alignments.append(alignment)
|
||||
|
||||
# Sort by text_len in descending order
|
||||
texts = [
|
||||
i
|
||||
for i, _ in sorted(
|
||||
zip(texts, text_lens), key=lambda x: x[1], reverse=True)
|
||||
]
|
||||
mels = [
|
||||
i
|
||||
for i, _ in sorted(
|
||||
zip(mels, text_lens), key=lambda x: x[1], reverse=True)
|
||||
]
|
||||
pos_texts = [
|
||||
i
|
||||
for i, _ in sorted(
|
||||
zip(pos_texts, text_lens), key=lambda x: x[1], reverse=True)
|
||||
]
|
||||
pos_mels = [
|
||||
i
|
||||
for i, _ in sorted(
|
||||
zip(pos_mels, text_lens), key=lambda x: x[1], reverse=True)
|
||||
]
|
||||
alignments = [
|
||||
i
|
||||
for i, _ in sorted(
|
||||
zip(alignments, text_lens), key=lambda x: x[1], reverse=True)
|
||||
]
|
||||
#text_lens = sorted(text_lens, reverse=True)
|
||||
|
||||
# Pad sequence with largest len of the batch
|
||||
texts = TextIDBatcher(pad_id=0)(texts) #(B, T)
|
||||
pos_texts = TextIDBatcher(pad_id=0)(pos_texts) #(B,T)
|
||||
pos_mels = TextIDBatcher(pad_id=0)(pos_mels) #(B,T)
|
||||
alignments = TextIDBatcher(pad_id=0)(alignments).astype(np.float32)
|
||||
mels = np.transpose(
|
||||
SpecBatcher(pad_value=0.)(mels), axes=(0, 2, 1)) #(B,T,num_mels)
|
||||
|
||||
return (texts, mels, pos_texts, pos_mels, alignments)
|
|
@ -1,96 +0,0 @@
|
|||
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import argparse
|
||||
|
||||
|
||||
def add_config_options_to_parser(parser):
|
||||
parser.add_argument(
|
||||
'--config_path',
|
||||
type=str,
|
||||
default='configs/fastspeech.yaml',
|
||||
help="the yaml config file path.")
|
||||
parser.add_argument(
|
||||
'--batch_size', type=int, default=32, help="batch size for training.")
|
||||
parser.add_argument(
|
||||
'--epochs',
|
||||
type=int,
|
||||
default=10000,
|
||||
help="the number of epoch for training.")
|
||||
parser.add_argument(
|
||||
'--lr',
|
||||
type=float,
|
||||
default=0.001,
|
||||
help="the learning rate for training.")
|
||||
parser.add_argument(
|
||||
'--save_step',
|
||||
type=int,
|
||||
default=500,
|
||||
help="checkpointing interval during training.")
|
||||
parser.add_argument(
|
||||
'--fastspeech_step',
|
||||
type=int,
|
||||
default=70000,
|
||||
help="Global step to restore checkpoint of fastspeech.")
|
||||
parser.add_argument(
|
||||
'--use_gpu',
|
||||
type=int,
|
||||
default=1,
|
||||
help="use gpu or not during training.")
|
||||
parser.add_argument(
|
||||
'--use_data_parallel',
|
||||
type=int,
|
||||
default=0,
|
||||
help="use data parallel or not during training.")
|
||||
parser.add_argument(
|
||||
'--alpha',
|
||||
type=float,
|
||||
default=1.0,
|
||||
help="The hyperparameter to determine the length of the expanded sequence \
|
||||
mel, thereby controlling the voice speed.")
|
||||
|
||||
parser.add_argument(
|
||||
'--data_path',
|
||||
type=str,
|
||||
default='./dataset/LJSpeech-1.1',
|
||||
help="the path of dataset.")
|
||||
parser.add_argument(
|
||||
'--checkpoint_path',
|
||||
type=str,
|
||||
default=None,
|
||||
help="the path to load checkpoint or pretrain model.")
|
||||
parser.add_argument(
|
||||
'--save_path',
|
||||
type=str,
|
||||
default='./checkpoint',
|
||||
help="the path to save checkpoint.")
|
||||
parser.add_argument(
|
||||
'--log_dir',
|
||||
type=str,
|
||||
default='./log',
|
||||
help="the directory to save tensorboard log.")
|
||||
parser.add_argument(
|
||||
'--sample_path',
|
||||
type=str,
|
||||
default='./sample',
|
||||
help="the directory to save audio sample in synthesis.")
|
||||
parser.add_argument(
|
||||
'--transtts_path',
|
||||
type=str,
|
||||
default='../transformer_tts/checkpoint',
|
||||
help="the directory to load pretrain transformerTTS model.")
|
||||
parser.add_argument(
|
||||
'--transformer_step',
|
||||
type=int,
|
||||
default=160000,
|
||||
help="the step to load transformerTTS model.")
|
|
@ -13,11 +13,12 @@
|
|||
# limitations under the License.
|
||||
import os
|
||||
from tensorboardX import SummaryWriter
|
||||
from scipy.io.wavfile import write
|
||||
from collections import OrderedDict
|
||||
import argparse
|
||||
from parse import add_config_options_to_parser
|
||||
from pprint import pprint
|
||||
from ruamel import yaml
|
||||
from matplotlib import cm
|
||||
import numpy as np
|
||||
import paddle.fluid as fluid
|
||||
import paddle.fluid.dygraph as dg
|
||||
|
@ -25,93 +26,178 @@ from parakeet.g2p.en import text_to_sequence
|
|||
from parakeet import audio
|
||||
from parakeet.models.fastspeech.fastspeech import FastSpeech
|
||||
from parakeet.models.transformer_tts.utils import *
|
||||
from parakeet.models.wavenet import WaveNet, UpsampleNet
|
||||
from parakeet.models.clarinet import STFT, Clarinet, ParallelWaveNet
|
||||
from parakeet.utils.layer_tools import freeze
|
||||
from parakeet.utils import io
|
||||
|
||||
|
||||
def load_checkpoint(step, model_path):
|
||||
model_dict, _ = fluid.dygraph.load_dygraph(os.path.join(model_path, step))
|
||||
new_state_dict = OrderedDict()
|
||||
for param in model_dict:
|
||||
if param.startswith('_layers.'):
|
||||
new_state_dict[param[8:]] = model_dict[param]
|
||||
else:
|
||||
new_state_dict[param] = model_dict[param]
|
||||
return new_state_dict
|
||||
def add_config_options_to_parser(parser):
|
||||
parser.add_argument("--config", type=str, help="path of the config file")
|
||||
parser.add_argument(
|
||||
"--config_clarinet", type=str, help="path of the clarinet config file")
|
||||
parser.add_argument("--use_gpu", type=int, default=0, help="device to use")
|
||||
parser.add_argument(
|
||||
"--alpha",
|
||||
type=float,
|
||||
default=1,
|
||||
help="determine the length of the expanded sequence mel, controlling the voice speed."
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--checkpoint", type=str, help="fastspeech checkpoint to synthesis")
|
||||
parser.add_argument(
|
||||
"--checkpoint_clarinet",
|
||||
type=str,
|
||||
help="clarinet checkpoint to synthesis")
|
||||
|
||||
parser.add_argument(
|
||||
"--output",
|
||||
type=str,
|
||||
default="synthesis",
|
||||
help="path to save experiment results")
|
||||
|
||||
|
||||
def synthesis(text_input, args):
|
||||
place = (fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace())
|
||||
local_rank = dg.parallel.Env().local_rank
|
||||
place = (fluid.CUDAPlace(local_rank) if args.use_gpu else fluid.CPUPlace())
|
||||
fluid.enable_dygraph(place)
|
||||
|
||||
# tensorboard
|
||||
if not os.path.exists(args.log_dir):
|
||||
os.mkdir(args.log_dir)
|
||||
path = os.path.join(args.log_dir, 'synthesis')
|
||||
|
||||
with open(args.config_path) as f:
|
||||
with open(args.config) as f:
|
||||
cfg = yaml.load(f, Loader=yaml.Loader)
|
||||
|
||||
writer = SummaryWriter(path)
|
||||
# tensorboard
|
||||
if not os.path.exists(args.output):
|
||||
os.mkdir(args.output)
|
||||
|
||||
with dg.guard(place):
|
||||
model = FastSpeech(cfg)
|
||||
model.set_dict(
|
||||
load_checkpoint(
|
||||
str(args.fastspeech_step),
|
||||
os.path.join(args.checkpoint_path, "fastspeech")))
|
||||
model.eval()
|
||||
writer = SummaryWriter(os.path.join(args.output, 'log'))
|
||||
|
||||
text = np.asarray(text_to_sequence(text_input))
|
||||
text = np.expand_dims(text, axis=0)
|
||||
pos_text = np.arange(1, text.shape[1] + 1)
|
||||
pos_text = np.expand_dims(pos_text, axis=0)
|
||||
enc_non_pad_mask = get_non_pad_mask(pos_text).astype(np.float32)
|
||||
enc_slf_attn_mask = get_attn_key_pad_mask(pos_text,
|
||||
text).astype(np.float32)
|
||||
model = FastSpeech(cfg['network'], num_mels=cfg['audio']['num_mels'])
|
||||
# Load parameters.
|
||||
global_step = io.load_parameters(
|
||||
model=model, checkpoint_path=args.checkpoint)
|
||||
model.eval()
|
||||
|
||||
text = dg.to_variable(text)
|
||||
pos_text = dg.to_variable(pos_text)
|
||||
enc_non_pad_mask = dg.to_variable(enc_non_pad_mask)
|
||||
enc_slf_attn_mask = dg.to_variable(enc_slf_attn_mask)
|
||||
text = np.asarray(text_to_sequence(text_input))
|
||||
text = np.expand_dims(text, axis=0)
|
||||
pos_text = np.arange(1, text.shape[1] + 1)
|
||||
pos_text = np.expand_dims(pos_text, axis=0)
|
||||
|
||||
mel_output, mel_output_postnet = model(
|
||||
text,
|
||||
pos_text,
|
||||
alpha=args.alpha,
|
||||
enc_non_pad_mask=enc_non_pad_mask,
|
||||
enc_slf_attn_mask=enc_slf_attn_mask,
|
||||
dec_non_pad_mask=None,
|
||||
dec_slf_attn_mask=None)
|
||||
text = dg.to_variable(text)
|
||||
pos_text = dg.to_variable(pos_text)
|
||||
|
||||
_ljspeech_processor = audio.AudioProcessor(
|
||||
sample_rate=cfg['audio']['sr'],
|
||||
num_mels=cfg['audio']['num_mels'],
|
||||
min_level_db=cfg['audio']['min_level_db'],
|
||||
ref_level_db=cfg['audio']['ref_level_db'],
|
||||
n_fft=cfg['audio']['n_fft'],
|
||||
win_length=cfg['audio']['win_length'],
|
||||
hop_length=cfg['audio']['hop_length'],
|
||||
power=cfg['audio']['power'],
|
||||
preemphasis=cfg['audio']['preemphasis'],
|
||||
signal_norm=True,
|
||||
symmetric_norm=False,
|
||||
max_norm=1.,
|
||||
mel_fmin=0,
|
||||
mel_fmax=None,
|
||||
clip_norm=True,
|
||||
griffin_lim_iters=60,
|
||||
do_trim_silence=False,
|
||||
sound_norm=False)
|
||||
_, mel_output_postnet = model(text, pos_text, alpha=args.alpha)
|
||||
|
||||
mel_output_postnet = fluid.layers.transpose(
|
||||
fluid.layers.squeeze(mel_output_postnet, [0]), [1, 0])
|
||||
wav = _ljspeech_processor.inv_melspectrogram(mel_output_postnet.numpy(
|
||||
))
|
||||
writer.add_audio(text_input, wav, 0, cfg['audio']['sr'])
|
||||
print("Synthesis completed !!!")
|
||||
result = np.exp(mel_output_postnet.numpy())
|
||||
mel_output_postnet = fluid.layers.transpose(
|
||||
fluid.layers.squeeze(mel_output_postnet, [0]), [1, 0])
|
||||
mel_output_postnet = np.exp(mel_output_postnet.numpy())
|
||||
basis = librosa.filters.mel(cfg['audio']['sr'], cfg['audio']['n_fft'],
|
||||
cfg['audio']['num_mels'])
|
||||
inv_basis = np.linalg.pinv(basis)
|
||||
spec = np.maximum(1e-10, np.dot(inv_basis, mel_output_postnet))
|
||||
|
||||
# synthesis use clarinet
|
||||
wav_clarinet = synthesis_with_clarinet(
|
||||
args.config_clarinet, args.checkpoint_clarinet, result, place)
|
||||
writer.add_audio(text_input + '(clarinet)', wav_clarinet, 0,
|
||||
cfg['audio']['sr'])
|
||||
if not os.path.exists(os.path.join(args.output, 'samples')):
|
||||
os.mkdir(os.path.join(args.output, 'samples'))
|
||||
write(
|
||||
os.path.join(os.path.join(args.output, 'samples'), 'clarinet.wav'),
|
||||
cfg['audio']['sr'], wav_clarinet)
|
||||
|
||||
#synthesis use griffin-lim
|
||||
wav = librosa.core.griffinlim(
|
||||
spec**cfg['audio']['power'],
|
||||
hop_length=cfg['audio']['hop_length'],
|
||||
win_length=cfg['audio']['win_length'])
|
||||
writer.add_audio(text_input + '(griffin-lim)', wav, 0, cfg['audio']['sr'])
|
||||
write(
|
||||
os.path.join(
|
||||
os.path.join(args.output, 'samples'), 'grinffin-lim.wav'),
|
||||
cfg['audio']['sr'], wav)
|
||||
print("Synthesis completed !!!")
|
||||
writer.close()
|
||||
|
||||
|
||||
def synthesis_with_clarinet(config_path, checkpoint, mel_spectrogram, place):
|
||||
with open(config_path, 'rt') as f:
|
||||
config = yaml.safe_load(f)
|
||||
|
||||
data_config = config["data"]
|
||||
n_mels = data_config["n_mels"]
|
||||
|
||||
teacher_config = config["teacher"]
|
||||
n_loop = teacher_config["n_loop"]
|
||||
n_layer = teacher_config["n_layer"]
|
||||
filter_size = teacher_config["filter_size"]
|
||||
|
||||
# only batch=1 for validation is enabled
|
||||
|
||||
with dg.guard(place):
|
||||
# conditioner(upsampling net)
|
||||
conditioner_config = config["conditioner"]
|
||||
upsampling_factors = conditioner_config["upsampling_factors"]
|
||||
upsample_net = UpsampleNet(upscale_factors=upsampling_factors)
|
||||
freeze(upsample_net)
|
||||
|
||||
residual_channels = teacher_config["residual_channels"]
|
||||
loss_type = teacher_config["loss_type"]
|
||||
output_dim = teacher_config["output_dim"]
|
||||
log_scale_min = teacher_config["log_scale_min"]
|
||||
assert loss_type == "mog" and output_dim == 3, \
|
||||
"the teacher wavenet should be a wavenet with single gaussian output"
|
||||
|
||||
teacher = WaveNet(n_loop, n_layer, residual_channels, output_dim,
|
||||
n_mels, filter_size, loss_type, log_scale_min)
|
||||
# load & freeze upsample_net & teacher
|
||||
freeze(teacher)
|
||||
|
||||
student_config = config["student"]
|
||||
n_loops = student_config["n_loops"]
|
||||
n_layers = student_config["n_layers"]
|
||||
student_residual_channels = student_config["residual_channels"]
|
||||
student_filter_size = student_config["filter_size"]
|
||||
student_log_scale_min = student_config["log_scale_min"]
|
||||
student = ParallelWaveNet(n_loops, n_layers, student_residual_channels,
|
||||
n_mels, student_filter_size)
|
||||
|
||||
stft_config = config["stft"]
|
||||
stft = STFT(
|
||||
n_fft=stft_config["n_fft"],
|
||||
hop_length=stft_config["hop_length"],
|
||||
win_length=stft_config["win_length"])
|
||||
|
||||
lmd = config["loss"]["lmd"]
|
||||
model = Clarinet(upsample_net, teacher, student, stft,
|
||||
student_log_scale_min, lmd)
|
||||
io.load_parameters(model=model, checkpoint_path=checkpoint)
|
||||
|
||||
if not os.path.exists(args.output):
|
||||
os.makedirs(args.output)
|
||||
model.eval()
|
||||
|
||||
# Rescale mel_spectrogram.
|
||||
min_level, ref_level = 1e-5, 20 # hard code it
|
||||
mel_spectrogram = 20 * np.log10(np.maximum(min_level, mel_spectrogram))
|
||||
mel_spectrogram = mel_spectrogram - ref_level
|
||||
mel_spectrogram = np.clip((mel_spectrogram + 100) / 100, 0, 1)
|
||||
|
||||
mel_spectrogram = dg.to_variable(mel_spectrogram)
|
||||
mel_spectrogram = fluid.layers.transpose(mel_spectrogram, [0, 2, 1])
|
||||
|
||||
wav_var = model.synthesis(mel_spectrogram)
|
||||
wav_np = wav_var.numpy()[0]
|
||||
|
||||
return wav_np
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser(description="Train Fastspeech model")
|
||||
parser = argparse.ArgumentParser(description="Synthesis model")
|
||||
add_config_options_to_parser(parser)
|
||||
args = parser.parse_args()
|
||||
synthesis("Transformer model is so fast!", args)
|
||||
pprint(vars(args))
|
||||
synthesis("Simple as this proposition is, it is necessary to be stated,",
|
||||
args)
|
||||
|
|
|
@ -3,10 +3,11 @@
|
|||
python -u synthesis.py \
|
||||
--use_gpu=1 \
|
||||
--alpha=1.0 \
|
||||
--checkpoint_path='checkpoint/' \
|
||||
--fastspeech_step=71000 \
|
||||
--log_dir='./log' \
|
||||
--config_path='configs/synthesis.yaml' \
|
||||
--checkpoint='./checkpoint/fastspeech/step-120000' \
|
||||
--config='configs/ljspeech.yaml' \
|
||||
--config_clarine='../clarinet/configs/config.yaml' \
|
||||
--checkpoint_clarinet='../clarinet/checkpoint/step-500000' \
|
||||
--output='./synthesis' \
|
||||
|
||||
if [ $? -ne 0 ]; then
|
||||
echo "Failed in synthesis!"
|
||||
|
|
|
@ -17,7 +17,6 @@ import os
|
|||
import time
|
||||
import math
|
||||
from pathlib import Path
|
||||
from parse import add_config_options_to_parser
|
||||
from pprint import pprint
|
||||
from ruamel import yaml
|
||||
from tqdm import tqdm
|
||||
|
@ -27,162 +26,132 @@ from tensorboardX import SummaryWriter
|
|||
import paddle.fluid.dygraph as dg
|
||||
import paddle.fluid.layers as layers
|
||||
import paddle.fluid as fluid
|
||||
from parakeet.models.transformer_tts.transformer_tts import TransformerTTS
|
||||
from parakeet.models.fastspeech.fastspeech import FastSpeech
|
||||
from parakeet.models.fastspeech.utils import get_alignment
|
||||
import sys
|
||||
sys.path.append("../transformer_tts")
|
||||
from data import LJSpeechLoader
|
||||
from parakeet.utils import io
|
||||
|
||||
|
||||
def load_checkpoint(step, model_path):
|
||||
model_dict, opti_dict = fluid.dygraph.load_dygraph(
|
||||
os.path.join(model_path, step))
|
||||
new_state_dict = OrderedDict()
|
||||
for param in model_dict:
|
||||
if param.startswith('_layers.'):
|
||||
new_state_dict[param[8:]] = model_dict[param]
|
||||
else:
|
||||
new_state_dict[param] = model_dict[param]
|
||||
return new_state_dict, opti_dict
|
||||
def add_config_options_to_parser(parser):
|
||||
parser.add_argument("--config", type=str, help="path of the config file")
|
||||
parser.add_argument("--use_gpu", type=int, default=0, help="device to use")
|
||||
parser.add_argument("--data", type=str, help="path of LJspeech dataset")
|
||||
parser.add_argument(
|
||||
"--alignments_path", type=str, help="path of alignments")
|
||||
|
||||
g = parser.add_mutually_exclusive_group()
|
||||
g.add_argument("--checkpoint", type=str, help="checkpoint to resume from")
|
||||
g.add_argument(
|
||||
"--iteration",
|
||||
type=int,
|
||||
help="the iteration of the checkpoint to load from output directory")
|
||||
|
||||
parser.add_argument(
|
||||
"--output",
|
||||
type=str,
|
||||
default="experiment",
|
||||
help="path to save experiment results")
|
||||
|
||||
|
||||
def main(args):
|
||||
local_rank = dg.parallel.Env().local_rank if args.use_data_parallel else 0
|
||||
nranks = dg.parallel.Env().nranks if args.use_data_parallel else 1
|
||||
local_rank = dg.parallel.Env().local_rank
|
||||
nranks = dg.parallel.Env().nranks
|
||||
parallel = nranks > 1
|
||||
|
||||
with open(args.config_path) as f:
|
||||
with open(args.config) as f:
|
||||
cfg = yaml.load(f, Loader=yaml.Loader)
|
||||
|
||||
global_step = 0
|
||||
place = (fluid.CUDAPlace(dg.parallel.Env().dev_id)
|
||||
if args.use_data_parallel else fluid.CUDAPlace(0)
|
||||
if args.use_gpu else fluid.CPUPlace())
|
||||
place = fluid.CUDAPlace(local_rank) if args.use_gpu else fluid.CPUPlace()
|
||||
fluid.enable_dygraph(place)
|
||||
|
||||
if not os.path.exists(args.log_dir):
|
||||
os.mkdir(args.log_dir)
|
||||
path = os.path.join(args.log_dir, 'fastspeech')
|
||||
if not os.path.exists(args.output):
|
||||
os.mkdir(args.output)
|
||||
|
||||
writer = SummaryWriter(path) if local_rank == 0 else None
|
||||
writer = SummaryWriter(os.path.join(args.output,
|
||||
'log')) if local_rank == 0 else None
|
||||
|
||||
with dg.guard(place):
|
||||
with fluid.unique_name.guard():
|
||||
transformer_tts = TransformerTTS(cfg)
|
||||
model_dict, _ = load_checkpoint(
|
||||
str(args.transformer_step),
|
||||
os.path.join(args.transtts_path, "transformer"))
|
||||
transformer_tts.set_dict(model_dict)
|
||||
transformer_tts.eval()
|
||||
model = FastSpeech(cfg['network'], num_mels=cfg['audio']['num_mels'])
|
||||
model.train()
|
||||
optimizer = fluid.optimizer.AdamOptimizer(
|
||||
learning_rate=dg.NoamDecay(1 / (cfg['train']['warm_up_step'] *
|
||||
(cfg['train']['learning_rate']**2)),
|
||||
cfg['train']['warm_up_step']),
|
||||
parameter_list=model.parameters(),
|
||||
grad_clip=fluid.clip.GradientClipByGlobalNorm(cfg['train'][
|
||||
'grad_clip_thresh']))
|
||||
reader = LJSpeechLoader(
|
||||
cfg['audio'],
|
||||
place,
|
||||
args.data,
|
||||
args.alignments_path,
|
||||
cfg['train']['batch_size'],
|
||||
nranks,
|
||||
local_rank,
|
||||
shuffle=True).reader()
|
||||
|
||||
model = FastSpeech(cfg)
|
||||
model.train()
|
||||
optimizer = fluid.optimizer.AdamOptimizer(
|
||||
learning_rate=dg.NoamDecay(1 / (
|
||||
cfg['warm_up_step'] * (args.lr**2)), cfg['warm_up_step']),
|
||||
parameter_list=model.parameters())
|
||||
reader = LJSpeechLoader(
|
||||
cfg, args, nranks, local_rank, shuffle=True).reader()
|
||||
# Load parameters.
|
||||
global_step = io.load_parameters(
|
||||
model=model,
|
||||
optimizer=optimizer,
|
||||
checkpoint_dir=os.path.join(args.output, 'checkpoints'),
|
||||
iteration=args.iteration,
|
||||
checkpoint_path=args.checkpoint)
|
||||
print("Rank {}: checkpoint loaded.".format(local_rank))
|
||||
|
||||
if args.checkpoint_path is not None:
|
||||
model_dict, opti_dict = load_checkpoint(
|
||||
str(args.fastspeech_step),
|
||||
os.path.join(args.checkpoint_path, "fastspeech"))
|
||||
model.set_dict(model_dict)
|
||||
optimizer.set_dict(opti_dict)
|
||||
global_step = args.fastspeech_step
|
||||
print("load checkpoint!!!")
|
||||
if parallel:
|
||||
strategy = dg.parallel.prepare_context()
|
||||
model = fluid.dygraph.parallel.DataParallel(model, strategy)
|
||||
|
||||
if args.use_data_parallel:
|
||||
strategy = dg.parallel.prepare_context()
|
||||
model = fluid.dygraph.parallel.DataParallel(model, strategy)
|
||||
for epoch in range(cfg['train']['max_epochs']):
|
||||
pbar = tqdm(reader)
|
||||
|
||||
for epoch in range(args.epochs):
|
||||
pbar = tqdm(reader)
|
||||
for i, data in enumerate(pbar):
|
||||
pbar.set_description('Processing at epoch %d' % epoch)
|
||||
(character, mel, pos_text, pos_mel, alignment) = data
|
||||
|
||||
for i, data in enumerate(pbar):
|
||||
pbar.set_description('Processing at epoch %d' % epoch)
|
||||
(character, mel, mel_input, pos_text, pos_mel, text_length,
|
||||
mel_lens, enc_slf_mask, enc_query_mask, dec_slf_mask,
|
||||
enc_dec_mask, dec_query_slf_mask, dec_query_mask) = data
|
||||
global_step += 1
|
||||
|
||||
_, _, attn_probs, _, _, _ = transformer_tts(
|
||||
character,
|
||||
mel_input,
|
||||
pos_text,
|
||||
pos_mel,
|
||||
dec_slf_mask=dec_slf_mask,
|
||||
enc_slf_mask=enc_slf_mask,
|
||||
enc_query_mask=enc_query_mask,
|
||||
enc_dec_mask=enc_dec_mask,
|
||||
dec_query_slf_mask=dec_query_slf_mask,
|
||||
dec_query_mask=dec_query_mask)
|
||||
alignment, max_attn = get_alignment(attn_probs, mel_lens,
|
||||
cfg['transformer_head'])
|
||||
alignment = dg.to_variable(alignment).astype(np.float32)
|
||||
#Forward
|
||||
result = model(
|
||||
character, pos_text, mel_pos=pos_mel, length_target=alignment)
|
||||
mel_output, mel_output_postnet, duration_predictor_output, _, _ = result
|
||||
mel_loss = layers.mse_loss(mel_output, mel)
|
||||
mel_postnet_loss = layers.mse_loss(mel_output_postnet, mel)
|
||||
duration_loss = layers.mean(
|
||||
layers.abs(
|
||||
layers.elementwise_sub(duration_predictor_output,
|
||||
alignment)))
|
||||
total_loss = mel_loss + mel_postnet_loss + duration_loss
|
||||
|
||||
if local_rank == 0 and global_step % 5 == 1:
|
||||
x = np.uint8(
|
||||
cm.viridis(max_attn[8, :mel_lens.numpy()[8]]) * 255)
|
||||
writer.add_image(
|
||||
'Attention_%d_0' % global_step,
|
||||
x,
|
||||
0,
|
||||
dataformats="HWC")
|
||||
if local_rank == 0:
|
||||
writer.add_scalar('mel_loss', mel_loss.numpy(), global_step)
|
||||
writer.add_scalar('post_mel_loss',
|
||||
mel_postnet_loss.numpy(), global_step)
|
||||
writer.add_scalar('duration_loss',
|
||||
duration_loss.numpy(), global_step)
|
||||
writer.add_scalar('learning_rate',
|
||||
optimizer._learning_rate.step().numpy(),
|
||||
global_step)
|
||||
|
||||
global_step += 1
|
||||
if parallel:
|
||||
total_loss = model.scale_loss(total_loss)
|
||||
total_loss.backward()
|
||||
model.apply_collective_grads()
|
||||
else:
|
||||
total_loss.backward()
|
||||
optimizer.minimize(total_loss)
|
||||
model.clear_gradients()
|
||||
|
||||
#Forward
|
||||
result = model(
|
||||
character,
|
||||
pos_text,
|
||||
mel_pos=pos_mel,
|
||||
length_target=alignment,
|
||||
enc_non_pad_mask=enc_query_mask,
|
||||
enc_slf_attn_mask=enc_slf_mask,
|
||||
dec_non_pad_mask=dec_query_slf_mask,
|
||||
dec_slf_attn_mask=dec_slf_mask)
|
||||
mel_output, mel_output_postnet, duration_predictor_output, _, _ = result
|
||||
mel_loss = layers.mse_loss(mel_output, mel)
|
||||
mel_postnet_loss = layers.mse_loss(mel_output_postnet, mel)
|
||||
duration_loss = layers.mean(
|
||||
layers.abs(
|
||||
layers.elementwise_sub(duration_predictor_output,
|
||||
alignment)))
|
||||
total_loss = mel_loss + mel_postnet_loss + duration_loss
|
||||
# save checkpoint
|
||||
if local_rank == 0 and global_step % cfg['train'][
|
||||
'checkpoint_interval'] == 0:
|
||||
io.save_parameters(
|
||||
os.path.join(args.output, 'checkpoints'), global_step,
|
||||
model, optimizer)
|
||||
|
||||
if local_rank == 0:
|
||||
writer.add_scalar('mel_loss',
|
||||
mel_loss.numpy(), global_step)
|
||||
writer.add_scalar('post_mel_loss',
|
||||
mel_postnet_loss.numpy(), global_step)
|
||||
writer.add_scalar('duration_loss',
|
||||
duration_loss.numpy(), global_step)
|
||||
writer.add_scalar('learning_rate',
|
||||
optimizer._learning_rate.step().numpy(),
|
||||
global_step)
|
||||
|
||||
if args.use_data_parallel:
|
||||
total_loss = model.scale_loss(total_loss)
|
||||
total_loss.backward()
|
||||
model.apply_collective_grads()
|
||||
else:
|
||||
total_loss.backward()
|
||||
optimizer.minimize(
|
||||
total_loss,
|
||||
grad_clip=fluid.dygraph_grad_clip.GradClipByGlobalNorm(cfg[
|
||||
'grad_clip_thresh']))
|
||||
model.clear_gradients()
|
||||
|
||||
# save checkpoint
|
||||
if local_rank == 0 and global_step % args.save_step == 0:
|
||||
if not os.path.exists(args.save_path):
|
||||
os.mkdir(args.save_path)
|
||||
save_path = os.path.join(args.save_path,
|
||||
'fastspeech/%d' % global_step)
|
||||
dg.save_dygraph(model.state_dict(), save_path)
|
||||
dg.save_dygraph(optimizer.state_dict(), save_path)
|
||||
if local_rank == 0:
|
||||
writer.close()
|
||||
if local_rank == 0:
|
||||
writer.close()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
@ -190,5 +159,5 @@ if __name__ == '__main__':
|
|||
add_config_options_to_parser(parser)
|
||||
args = parser.parse_args()
|
||||
# Print the whole config setting.
|
||||
pprint(args)
|
||||
pprint(vars(args))
|
||||
main(args)
|
||||
|
|
|
@ -1,21 +1,12 @@
|
|||
# train model
|
||||
# if you wish to resume from an exists model, uncomment --checkpoint_path and --fastspeech_step
|
||||
export CUDA_VISIBLE_DEVICES=0
|
||||
python -u train.py \
|
||||
--batch_size=32 \
|
||||
--epochs=10000 \
|
||||
--lr=0.001 \
|
||||
--save_step=500 \
|
||||
--use_gpu=1 \
|
||||
--use_data_parallel=0 \
|
||||
--data_path='../../dataset/LJSpeech-1.1' \
|
||||
--transtts_path='../transformer_tts/checkpoint' \
|
||||
--transformer_step=120000 \
|
||||
--save_path='./checkpoint' \
|
||||
--log_dir='./log' \
|
||||
--config_path='configs/fastspeech.yaml' \
|
||||
#--checkpoint_path='./checkpoint' \
|
||||
#--fastspeech_step=97000 \
|
||||
--data='../../dataset/LJSpeech-1.1' \
|
||||
--alignments_path='./alignments/alignments.txt' \
|
||||
--output='./experiment' \
|
||||
--config='configs/ljspeech.yaml' \
|
||||
#--checkpoint='./checkpoint/fastspeech/step-120000' \
|
||||
|
||||
if [ $? -ne 0 ]; then
|
||||
echo "Failed in training!"
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
# TransformerTTS
|
||||
|
||||
PaddlePaddle dynamic graph implementation of TransformerTTS, a neural TTS with Transformer. The implementation is based on [Neural Speech Synthesis with Transformer Network](https://arxiv.org/abs/1809.08895).
|
||||
|
||||
## Dataset
|
||||
|
@ -9,7 +10,9 @@ We experiment with the LJSpeech dataset. Download and unzip [LJSpeech](https://k
|
|||
wget https://data.keithito.com/data/speech/LJSpeech-1.1.tar.bz2
|
||||
tar xjvf LJSpeech-1.1.tar.bz2
|
||||
```
|
||||
|
||||
## Model Architecture
|
||||
|
||||
<div align="center" name="TransformerTTS model architecture">
|
||||
<img src="./images/model_architecture.jpg" width=400 height=600 /> <br>
|
||||
</div>
|
||||
|
@ -20,6 +23,7 @@ TransformerTTS model architecture
|
|||
The model adopts the multi-head attention mechanism to replace the RNN structures and also the original attention mechanism in [Tacotron2](https://arxiv.org/abs/1712.05884). The model consists of two main parts, encoder and decoder. We also implement the CBHG model of Tacotron as the vocoder part and convert the spectrogram into raw wave using Griffin-Lim algorithm.
|
||||
|
||||
## Project Structure
|
||||
|
||||
```text
|
||||
├── config # yaml configuration files
|
||||
├── data.py # dataset and dataloader settings for LJSpeech
|
||||
|
@ -28,85 +32,114 @@ The model adopts the multi-head attention mechanism to replace the RNN structure
|
|||
├── train_vocoder.py # script for vocoder model training
|
||||
```
|
||||
|
||||
## Saving & Loading
|
||||
|
||||
`train_transformer.py` and `train_vocoer.py` have 3 arguments in common, `--checkpoint`, `--iteration` and `--output`.
|
||||
|
||||
1. `--output` is the directory for saving results.
|
||||
During training, checkpoints are saved in `${output}/checkpoints` and tensorboard logs are saved in `${output}/log`.
|
||||
During synthesis, results are saved in `${output}/samples` and tensorboard log is save in `${output}/log`.
|
||||
|
||||
2. `--checkpoint` is the path of a checkpoint and `--iteration` is the target step. They are used to load checkpoints in the following way.
|
||||
|
||||
- If `--checkpoint` is provided, the checkpoint specified by `--checkpoint` is loaded.
|
||||
|
||||
- If `--checkpoint` is not provided, we try to load the checkpoint of the target step specified by `--iteration` from the `${output}/checkpoints/` directory, e.g. if given `--iteration 120000`, the checkpoint `${output}/checkpoints/step-120000.*` will be load.
|
||||
|
||||
- If both `--checkpoint` and `--iteration` are not provided, we try to load the latest checkpoint from `${output}/checkpoints/` directory.
|
||||
|
||||
## Train Transformer
|
||||
|
||||
TransformerTTS model can be trained with ``train_transformer.py``.
|
||||
TransformerTTS model can be trained by running ``train_transformer.py``.
|
||||
|
||||
```bash
|
||||
python train_trasformer.py \
|
||||
--use_gpu=1 \
|
||||
--use_data_parallel=0 \
|
||||
--data_path=${DATAPATH} \
|
||||
--config_path='config/train_transformer.yaml' \
|
||||
--data=${DATAPATH} \
|
||||
--output='./experiment' \
|
||||
--config='configs/ljspeech.yaml' \
|
||||
```
|
||||
|
||||
Or you can run the script file directly.
|
||||
|
||||
```bash
|
||||
sh train_transformer.sh
|
||||
```
|
||||
If you want to train on multiple GPUs, you must set ``--use_data_parallel=1``, and then start training as follows:
|
||||
|
||||
If you want to train on multiple GPUs, you must start training in the following way.
|
||||
|
||||
```bash
|
||||
CUDA_VISIBLE_DEVICES=0,1,2,3
|
||||
python -m paddle.distributed.launch --selected_gpus=0,1,2,3 --log_dir ./mylog train_transformer.py \
|
||||
--use_gpu=1 \
|
||||
--use_data_parallel=1 \
|
||||
--data_path=${DATAPATH} \
|
||||
--config_path='config/train_transformer.yaml' \
|
||||
--data=${DATAPATH} \
|
||||
--output='./experiment' \
|
||||
--config='configs/ljspeech.yaml' \
|
||||
```
|
||||
|
||||
If you wish to resume from an existing model, please set ``--checkpoint_path`` and ``--transformer_step``.
|
||||
If you wish to resume from an existing model, See [Saving-&-Loading](#Saving-&-Loading) for details of checkpoint loading.
|
||||
|
||||
**Note: In order to ensure the training effect, we recommend using multi-GPU training to enlarge the batch size, and at least 16 samples in single batch per GPU.**
|
||||
|
||||
For more help on arguments:
|
||||
For more help on arguments
|
||||
|
||||
``python train_transformer.py --help``.
|
||||
|
||||
## Train Vocoder
|
||||
Vocoder model can be trained with ``train_vocoder.py``.
|
||||
|
||||
Vocoder model can be trained by running ``train_vocoder.py``.
|
||||
|
||||
```bash
|
||||
python train_vocoder.py \
|
||||
--use_gpu=1 \
|
||||
--use_data_parallel=0 \
|
||||
--data_path=${DATAPATH} \
|
||||
--config_path='config/train_vocoder.yaml' \
|
||||
--data=${DATAPATH} \
|
||||
--output='./vocoder' \
|
||||
--config='configs/ljspeech.yaml' \
|
||||
```
|
||||
|
||||
Or you can run the script file directly.
|
||||
|
||||
```bash
|
||||
sh train_vocoder.sh
|
||||
```
|
||||
If you want to train on multiple GPUs, you must set ``--use_data_parallel=1``, and then start training as follows:
|
||||
|
||||
If you want to train on multiple GPUs, you must start training in the following way.
|
||||
|
||||
```bash
|
||||
CUDA_VISIBLE_DEVICES=0,1,2,3
|
||||
python -m paddle.distributed.launch --selected_gpus=0,1,2,3 --log_dir ./mylog train_vocoder.py \
|
||||
--use_gpu=1 \
|
||||
--use_data_parallel=1 \
|
||||
--data_path=${DATAPATH} \
|
||||
--config_path='config/train_vocoder.yaml' \
|
||||
--data=${DATAPATH} \
|
||||
--output='./vocoder' \
|
||||
--config='configs/ljspeech.yaml' \
|
||||
```
|
||||
If you wish to resume from an existing model, please set ``--checkpoint_path`` and ``--vocoder_step``.
|
||||
|
||||
For more help on arguments:
|
||||
If you wish to resume from an existing model, See [Saving-&-Loading](#Saving-&-Loading) for details of checkpoint loading.
|
||||
|
||||
For more help on arguments
|
||||
|
||||
``python train_vocoder.py --help``.
|
||||
|
||||
## Synthesis
|
||||
After training the TransformerTTS and vocoder model, audio can be synthesized with ``synthesis.py``.
|
||||
|
||||
After training the TransformerTTS and vocoder model, audio can be synthesized by running ``synthesis.py``.
|
||||
|
||||
```bash
|
||||
python synthesis.py \
|
||||
--max_len=50 \
|
||||
--transformer_step=160000 \
|
||||
--vocoder_step=70000 \
|
||||
--use_gpu=1
|
||||
--checkpoint_path='./checkpoint' \
|
||||
--sample_path='./sample' \
|
||||
--config_path='config/synthesis.yaml' \
|
||||
--max_len=300 \
|
||||
--use_gpu=1 \
|
||||
--output='./synthesis' \
|
||||
--config='configs/ljspeech.yaml' \
|
||||
--checkpoint_transformer='./checkpoint/transformer/step-120000' \
|
||||
--checkpoint_vocoder='./checkpoint/vocoder/step-100000' \
|
||||
```
|
||||
|
||||
Or you can run the script file directly.
|
||||
|
||||
```bash
|
||||
sh synthesis.sh
|
||||
```
|
||||
|
||||
And the audio file will be saved in ``--sample_path``.
|
||||
For more help on arguments
|
||||
|
||||
For more help on arguments:
|
||||
``python synthesis.py --help``.
|
||||
|
|
|
@ -0,0 +1,38 @@
|
|||
audio:
|
||||
num_mels: 80
|
||||
n_fft: 2048
|
||||
sr: 22050
|
||||
preemphasis: 0.97
|
||||
hop_length: 256 #275
|
||||
win_length: 1024 #1102
|
||||
power: 1.2
|
||||
min_level_db: -100
|
||||
ref_level_db: 20
|
||||
|
||||
network:
|
||||
hidden_size: 256
|
||||
embedding_size: 512
|
||||
encoder_num_head: 4
|
||||
encoder_n_layers: 3
|
||||
decoder_num_head: 4
|
||||
decoder_n_layers: 3
|
||||
outputs_per_step: 1
|
||||
stop_token: False
|
||||
|
||||
vocoder:
|
||||
hidden_size: 256
|
||||
|
||||
train:
|
||||
batch_size: 32
|
||||
learning_rate: 0.001
|
||||
warm_up_step: 4000
|
||||
grad_clip_thresh: 1.0
|
||||
|
||||
checkpoint_interval: 1000
|
||||
image_interval: 2000
|
||||
|
||||
max_epochs: 10000
|
||||
|
||||
|
||||
|
||||
|
|
@ -1,14 +0,0 @@
|
|||
audio:
|
||||
num_mels: 80
|
||||
n_fft: 2048
|
||||
sr: 22050
|
||||
preemphasis: 0.97
|
||||
hop_length: 275
|
||||
win_length: 1102
|
||||
power: 1.2
|
||||
min_level_db: -100
|
||||
ref_level_db: 20
|
||||
outputs_per_step: 1
|
||||
|
||||
hidden_size: 256
|
||||
embedding_size: 512
|
|
@ -1,20 +0,0 @@
|
|||
audio:
|
||||
num_mels: 80
|
||||
n_fft: 2048
|
||||
sr: 22050
|
||||
preemphasis: 0.97
|
||||
hop_length: 275
|
||||
win_length: 1102
|
||||
power: 1.2
|
||||
min_level_db: -100
|
||||
ref_level_db: 20
|
||||
outputs_per_step: 1
|
||||
|
||||
|
||||
hidden_size: 256
|
||||
embedding_size: 512
|
||||
warm_up_step: 4000
|
||||
grad_clip_thresh: 1.0
|
||||
|
||||
|
||||
|
|
@ -1,16 +0,0 @@
|
|||
audio:
|
||||
num_mels: 80
|
||||
n_fft: 2048
|
||||
sr: 22050
|
||||
preemphasis: 0.97
|
||||
hop_length: 275
|
||||
win_length: 1102
|
||||
power: 1.2
|
||||
min_level_db: -100
|
||||
ref_level_db: 20
|
||||
outputs_per_step: 1
|
||||
|
||||
hidden_size: 256
|
||||
embedding_size: 512
|
||||
warm_up_step: 4000
|
||||
grad_clip_thresh: 1.0
|
|
@ -30,14 +30,15 @@ from parakeet.models.transformer_tts.utils import *
|
|||
class LJSpeechLoader:
|
||||
def __init__(self,
|
||||
config,
|
||||
args,
|
||||
place,
|
||||
data_path,
|
||||
batch_size,
|
||||
nranks,
|
||||
rank,
|
||||
is_vocoder=False,
|
||||
shuffle=True):
|
||||
place = fluid.CUDAPlace(rank) if args.use_gpu else fluid.CPUPlace()
|
||||
|
||||
LJSPEECH_ROOT = Path(args.data_path)
|
||||
LJSPEECH_ROOT = Path(data_path)
|
||||
metadata = LJSpeechMetaData(LJSPEECH_ROOT)
|
||||
transformer = LJSpeech(config)
|
||||
dataset = TransformDataset(metadata, transformer)
|
||||
|
@ -46,8 +47,8 @@ class LJSpeechLoader:
|
|||
sampler = DistributedSampler(
|
||||
len(dataset), nranks, rank, shuffle=shuffle)
|
||||
|
||||
assert args.batch_size % nranks == 0
|
||||
each_bs = args.batch_size // nranks
|
||||
assert batch_size % nranks == 0
|
||||
each_bs = batch_size // nranks
|
||||
if is_vocoder:
|
||||
dataloader = DataCargo(
|
||||
dataset,
|
||||
|
@ -98,15 +99,15 @@ class LJSpeech(object):
|
|||
super(LJSpeech, self).__init__()
|
||||
self.config = config
|
||||
self._ljspeech_processor = audio.AudioProcessor(
|
||||
sample_rate=config['audio']['sr'],
|
||||
num_mels=config['audio']['num_mels'],
|
||||
min_level_db=config['audio']['min_level_db'],
|
||||
ref_level_db=config['audio']['ref_level_db'],
|
||||
n_fft=config['audio']['n_fft'],
|
||||
win_length=config['audio']['win_length'],
|
||||
hop_length=config['audio']['hop_length'],
|
||||
power=config['audio']['power'],
|
||||
preemphasis=config['audio']['preemphasis'],
|
||||
sample_rate=config['sr'],
|
||||
num_mels=config['num_mels'],
|
||||
min_level_db=config['min_level_db'],
|
||||
ref_level_db=config['ref_level_db'],
|
||||
n_fft=config['n_fft'],
|
||||
win_length=config['win_length'],
|
||||
hop_length=config['hop_length'],
|
||||
power=config['power'],
|
||||
preemphasis=config['preemphasis'],
|
||||
signal_norm=True,
|
||||
symmetric_norm=False,
|
||||
max_norm=1.,
|
||||
|
@ -140,7 +141,6 @@ def batch_examples(batch):
|
|||
texts = []
|
||||
mels = []
|
||||
mel_inputs = []
|
||||
mel_lens = []
|
||||
text_lens = []
|
||||
pos_texts = []
|
||||
pos_mels = []
|
||||
|
@ -150,7 +150,6 @@ def batch_examples(batch):
|
|||
np.concatenate(
|
||||
[np.zeros([mel.shape[0], 1], np.float32), mel[:, :-1]],
|
||||
axis=-1))
|
||||
mel_lens.append(mel.shape[1])
|
||||
text_lens.append(len(text))
|
||||
pos_texts.append(np.arange(1, len(text) + 1))
|
||||
pos_mels.append(np.arange(1, mel.shape[1] + 1))
|
||||
|
@ -173,11 +172,6 @@ def batch_examples(batch):
|
|||
for i, _ in sorted(
|
||||
zip(mel_inputs, text_lens), key=lambda x: x[1], reverse=True)
|
||||
]
|
||||
mel_lens = [
|
||||
i
|
||||
for i, _ in sorted(
|
||||
zip(mel_lens, text_lens), key=lambda x: x[1], reverse=True)
|
||||
]
|
||||
pos_texts = [
|
||||
i
|
||||
for i, _ in sorted(
|
||||
|
@ -199,18 +193,7 @@ def batch_examples(batch):
|
|||
mel_inputs = np.transpose(
|
||||
SpecBatcher(pad_value=0.)(mel_inputs), axes=(0, 2, 1)) #(B,T,num_mels)
|
||||
|
||||
enc_slf_mask = get_attn_key_pad_mask(pos_texts).astype(np.float32)
|
||||
enc_query_mask = get_non_pad_mask(pos_texts).astype(np.float32)
|
||||
dec_slf_mask = get_dec_attn_key_pad_mask(pos_mels,
|
||||
mel_inputs).astype(np.float32)
|
||||
enc_dec_mask = get_attn_key_pad_mask(enc_query_mask[:, :, 0]).astype(
|
||||
np.float32)
|
||||
dec_query_slf_mask = get_non_pad_mask(pos_mels).astype(np.float32)
|
||||
dec_query_mask = get_non_pad_mask(pos_mels).astype(np.float32)
|
||||
|
||||
return (texts, mels, mel_inputs, pos_texts, pos_mels, np.array(text_lens),
|
||||
np.array(mel_lens), enc_slf_mask, enc_query_mask, dec_slf_mask,
|
||||
enc_dec_mask, dec_query_slf_mask, dec_query_mask)
|
||||
return (texts, mels, mel_inputs, pos_texts, pos_mels)
|
||||
|
||||
|
||||
def batch_examples_vocoder(batch):
|
||||
|
|
|
@ -1,100 +0,0 @@
|
|||
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import argparse
|
||||
|
||||
|
||||
def add_config_options_to_parser(parser):
|
||||
parser.add_argument(
|
||||
'--config_path',
|
||||
type=str,
|
||||
default='configs/train_transformer.yaml',
|
||||
help="the yaml config file path.")
|
||||
parser.add_argument(
|
||||
'--batch_size', type=int, default=32, help="batch size for training.")
|
||||
parser.add_argument(
|
||||
'--epochs',
|
||||
type=int,
|
||||
default=10000,
|
||||
help="the number of epoch for training.")
|
||||
parser.add_argument(
|
||||
'--lr',
|
||||
type=float,
|
||||
default=0.001,
|
||||
help="the learning rate for training.")
|
||||
parser.add_argument(
|
||||
'--save_step',
|
||||
type=int,
|
||||
default=500,
|
||||
help="checkpointing interval during training.")
|
||||
parser.add_argument(
|
||||
'--image_step',
|
||||
type=int,
|
||||
default=2000,
|
||||
help="attention image interval during training.")
|
||||
parser.add_argument(
|
||||
'--max_len',
|
||||
type=int,
|
||||
default=400,
|
||||
help="The max length of audio when synthsis.")
|
||||
parser.add_argument(
|
||||
'--transformer_step',
|
||||
type=int,
|
||||
default=160000,
|
||||
help="Global step to restore checkpoint of transformer.")
|
||||
parser.add_argument(
|
||||
'--vocoder_step',
|
||||
type=int,
|
||||
default=90000,
|
||||
help="Global step to restore checkpoint of postnet.")
|
||||
parser.add_argument(
|
||||
'--use_gpu',
|
||||
type=int,
|
||||
default=1,
|
||||
help="use gpu or not during training.")
|
||||
parser.add_argument(
|
||||
'--use_data_parallel',
|
||||
type=int,
|
||||
default=0,
|
||||
help="use data parallel or not during training.")
|
||||
parser.add_argument(
|
||||
'--stop_token',
|
||||
type=int,
|
||||
default=0,
|
||||
help="use stop token loss in network or not.")
|
||||
|
||||
parser.add_argument(
|
||||
'--data_path',
|
||||
type=str,
|
||||
default='./dataset/LJSpeech-1.1',
|
||||
help="the path of dataset.")
|
||||
parser.add_argument(
|
||||
'--checkpoint_path',
|
||||
type=str,
|
||||
default=None,
|
||||
help="the path to load checkpoint or pretrain model.")
|
||||
parser.add_argument(
|
||||
'--save_path',
|
||||
type=str,
|
||||
default='./checkpoint',
|
||||
help="the path to save checkpoint.")
|
||||
parser.add_argument(
|
||||
'--log_dir',
|
||||
type=str,
|
||||
default='./log',
|
||||
help="the directory to save tensorboard log.")
|
||||
parser.add_argument(
|
||||
'--sample_path',
|
||||
type=str,
|
||||
default='./sample',
|
||||
help="the directory to save audio sample in synthesis.")
|
|
@ -13,143 +13,153 @@
|
|||
# limitations under the License.
|
||||
import os
|
||||
from scipy.io.wavfile import write
|
||||
from parakeet.g2p.en import text_to_sequence
|
||||
import numpy as np
|
||||
from tqdm import tqdm
|
||||
from matplotlib import cm
|
||||
from tensorboardX import SummaryWriter
|
||||
from ruamel import yaml
|
||||
import paddle.fluid as fluid
|
||||
import paddle.fluid.dygraph as dg
|
||||
from pathlib import Path
|
||||
import argparse
|
||||
from parse import add_config_options_to_parser
|
||||
from pprint import pprint
|
||||
from collections import OrderedDict
|
||||
import paddle.fluid as fluid
|
||||
import paddle.fluid.dygraph as dg
|
||||
from parakeet.g2p.en import text_to_sequence
|
||||
from parakeet.models.transformer_tts.utils import *
|
||||
from parakeet import audio
|
||||
from parakeet.models.transformer_tts.vocoder import Vocoder
|
||||
from parakeet.models.transformer_tts.transformer_tts import TransformerTTS
|
||||
from parakeet.models.transformer_tts import Vocoder
|
||||
from parakeet.models.transformer_tts import TransformerTTS
|
||||
from parakeet.utils import io
|
||||
|
||||
|
||||
def load_checkpoint(step, model_path):
|
||||
model_dict, _ = fluid.dygraph.load_dygraph(os.path.join(model_path, step))
|
||||
new_state_dict = OrderedDict()
|
||||
for param in model_dict:
|
||||
if param.startswith('_layers.'):
|
||||
new_state_dict[param[8:]] = model_dict[param]
|
||||
else:
|
||||
new_state_dict[param] = model_dict[param]
|
||||
return new_state_dict
|
||||
def add_config_options_to_parser(parser):
|
||||
parser.add_argument("--config", type=str, help="path of the config file")
|
||||
parser.add_argument("--use_gpu", type=int, default=0, help="device to use")
|
||||
parser.add_argument(
|
||||
"--max_len",
|
||||
type=int,
|
||||
default=200,
|
||||
help="The max length of audio when synthsis.")
|
||||
|
||||
parser.add_argument(
|
||||
"--checkpoint_transformer",
|
||||
type=str,
|
||||
help="transformer_tts checkpoint to synthesis")
|
||||
parser.add_argument(
|
||||
"--checkpoint_vocoder",
|
||||
type=str,
|
||||
help="vocoder checkpoint to synthesis")
|
||||
|
||||
parser.add_argument(
|
||||
"--output",
|
||||
type=str,
|
||||
default="synthesis",
|
||||
help="path to save experiment results")
|
||||
|
||||
|
||||
def synthesis(text_input, args):
|
||||
place = (fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace())
|
||||
local_rank = dg.parallel.Env().local_rank
|
||||
place = (fluid.CUDAPlace(local_rank) if args.use_gpu else fluid.CPUPlace())
|
||||
|
||||
with open(args.config_path) as f:
|
||||
with open(args.config) as f:
|
||||
cfg = yaml.load(f, Loader=yaml.Loader)
|
||||
|
||||
# tensorboard
|
||||
if not os.path.exists(args.log_dir):
|
||||
os.mkdir(args.log_dir)
|
||||
path = os.path.join(args.log_dir, 'synthesis')
|
||||
if not os.path.exists(args.output):
|
||||
os.mkdir(args.output)
|
||||
|
||||
writer = SummaryWriter(path)
|
||||
writer = SummaryWriter(os.path.join(args.output, 'log'))
|
||||
|
||||
with dg.guard(place):
|
||||
with fluid.unique_name.guard():
|
||||
model = TransformerTTS(cfg)
|
||||
model.set_dict(
|
||||
load_checkpoint(
|
||||
str(args.transformer_step),
|
||||
os.path.join(args.checkpoint_path, "transformer")))
|
||||
model.eval()
|
||||
fluid.enable_dygraph(place)
|
||||
with fluid.unique_name.guard():
|
||||
network_cfg = cfg['network']
|
||||
model = TransformerTTS(
|
||||
network_cfg['embedding_size'], network_cfg['hidden_size'],
|
||||
network_cfg['encoder_num_head'], network_cfg['encoder_n_layers'],
|
||||
cfg['audio']['num_mels'], network_cfg['outputs_per_step'],
|
||||
network_cfg['decoder_num_head'], network_cfg['decoder_n_layers'])
|
||||
# Load parameters.
|
||||
global_step = io.load_parameters(
|
||||
model=model, checkpoint_path=args.checkpoint_transformer)
|
||||
model.eval()
|
||||
|
||||
with fluid.unique_name.guard():
|
||||
model_vocoder = Vocoder(cfg, args.batch_size)
|
||||
model_vocoder.set_dict(
|
||||
load_checkpoint(
|
||||
str(args.vocoder_step),
|
||||
os.path.join(args.checkpoint_path, "vocoder")))
|
||||
model_vocoder.eval()
|
||||
# init input
|
||||
text = np.asarray(text_to_sequence(text_input))
|
||||
text = fluid.layers.unsqueeze(dg.to_variable(text), [0])
|
||||
mel_input = dg.to_variable(np.zeros([1, 1, 80])).astype(np.float32)
|
||||
pos_text = np.arange(1, text.shape[1] + 1)
|
||||
pos_text = fluid.layers.unsqueeze(dg.to_variable(pos_text), [0])
|
||||
with fluid.unique_name.guard():
|
||||
model_vocoder = Vocoder(
|
||||
cfg['train']['batch_size'], cfg['vocoder']['hidden_size'],
|
||||
cfg['audio']['num_mels'], cfg['audio']['n_fft'])
|
||||
# Load parameters.
|
||||
global_step = io.load_parameters(
|
||||
model=model_vocoder, checkpoint_path=args.checkpoint_vocoder)
|
||||
model_vocoder.eval()
|
||||
# init input
|
||||
text = np.asarray(text_to_sequence(text_input))
|
||||
text = fluid.layers.unsqueeze(dg.to_variable(text), [0])
|
||||
mel_input = dg.to_variable(np.zeros([1, 1, 80])).astype(np.float32)
|
||||
pos_text = np.arange(1, text.shape[1] + 1)
|
||||
pos_text = fluid.layers.unsqueeze(dg.to_variable(pos_text), [0])
|
||||
|
||||
pbar = tqdm(range(args.max_len))
|
||||
for i in pbar:
|
||||
dec_slf_mask = get_triu_tensor(
|
||||
mel_input.numpy(), mel_input.numpy()).astype(np.float32)
|
||||
dec_slf_mask = fluid.layers.cast(
|
||||
dg.to_variable(dec_slf_mask != 0), np.float32) * (-2**32 + 1)
|
||||
pos_mel = np.arange(1, mel_input.shape[1] + 1)
|
||||
pos_mel = fluid.layers.unsqueeze(dg.to_variable(pos_mel), [0])
|
||||
mel_pred, postnet_pred, attn_probs, stop_preds, attn_enc, attn_dec = model(
|
||||
text, mel_input, pos_text, pos_mel, dec_slf_mask)
|
||||
mel_input = fluid.layers.concat(
|
||||
[mel_input, postnet_pred[:, -1:, :]], axis=1)
|
||||
pbar = tqdm(range(args.max_len))
|
||||
for i in pbar:
|
||||
pos_mel = np.arange(1, mel_input.shape[1] + 1)
|
||||
pos_mel = fluid.layers.unsqueeze(dg.to_variable(pos_mel), [0])
|
||||
mel_pred, postnet_pred, attn_probs, stop_preds, attn_enc, attn_dec = model(
|
||||
text, mel_input, pos_text, pos_mel)
|
||||
mel_input = fluid.layers.concat(
|
||||
[mel_input, postnet_pred[:, -1:, :]], axis=1)
|
||||
|
||||
mag_pred = model_vocoder(postnet_pred)
|
||||
mag_pred = model_vocoder(postnet_pred)
|
||||
|
||||
_ljspeech_processor = audio.AudioProcessor(
|
||||
sample_rate=cfg['audio']['sr'],
|
||||
num_mels=cfg['audio']['num_mels'],
|
||||
min_level_db=cfg['audio']['min_level_db'],
|
||||
ref_level_db=cfg['audio']['ref_level_db'],
|
||||
n_fft=cfg['audio']['n_fft'],
|
||||
win_length=cfg['audio']['win_length'],
|
||||
hop_length=cfg['audio']['hop_length'],
|
||||
power=cfg['audio']['power'],
|
||||
preemphasis=cfg['audio']['preemphasis'],
|
||||
signal_norm=True,
|
||||
symmetric_norm=False,
|
||||
max_norm=1.,
|
||||
mel_fmin=0,
|
||||
mel_fmax=None,
|
||||
clip_norm=True,
|
||||
griffin_lim_iters=60,
|
||||
do_trim_silence=False,
|
||||
sound_norm=False)
|
||||
_ljspeech_processor = audio.AudioProcessor(
|
||||
sample_rate=cfg['audio']['sr'],
|
||||
num_mels=cfg['audio']['num_mels'],
|
||||
min_level_db=cfg['audio']['min_level_db'],
|
||||
ref_level_db=cfg['audio']['ref_level_db'],
|
||||
n_fft=cfg['audio']['n_fft'],
|
||||
win_length=cfg['audio']['win_length'],
|
||||
hop_length=cfg['audio']['hop_length'],
|
||||
power=cfg['audio']['power'],
|
||||
preemphasis=cfg['audio']['preemphasis'],
|
||||
signal_norm=True,
|
||||
symmetric_norm=False,
|
||||
max_norm=1.,
|
||||
mel_fmin=0,
|
||||
mel_fmax=None,
|
||||
clip_norm=True,
|
||||
griffin_lim_iters=60,
|
||||
do_trim_silence=False,
|
||||
sound_norm=False)
|
||||
|
||||
wav = _ljspeech_processor.inv_spectrogram(
|
||||
fluid.layers.transpose(
|
||||
fluid.layers.squeeze(mag_pred, [0]), [1, 0]).numpy())
|
||||
global_step = 0
|
||||
for i, prob in enumerate(attn_probs):
|
||||
for j in range(4):
|
||||
x = np.uint8(cm.viridis(prob.numpy()[j]) * 255)
|
||||
writer.add_image(
|
||||
'Attention_%d_0' % global_step,
|
||||
x,
|
||||
i * 4 + j,
|
||||
dataformats="HWC")
|
||||
# synthesis with cbhg
|
||||
wav = _ljspeech_processor.inv_spectrogram(
|
||||
fluid.layers.transpose(fluid.layers.squeeze(mag_pred, [0]), [1, 0])
|
||||
.numpy())
|
||||
global_step = 0
|
||||
for i, prob in enumerate(attn_probs):
|
||||
for j in range(4):
|
||||
x = np.uint8(cm.viridis(prob.numpy()[j]) * 255)
|
||||
writer.add_image(
|
||||
'Attention_%d_0' % global_step,
|
||||
x,
|
||||
i * 4 + j,
|
||||
dataformats="HWC")
|
||||
|
||||
for i, prob in enumerate(attn_enc):
|
||||
for j in range(4):
|
||||
x = np.uint8(cm.viridis(prob.numpy()[j]) * 255)
|
||||
writer.add_image(
|
||||
'Attention_enc_%d_0' % global_step,
|
||||
x,
|
||||
i * 4 + j,
|
||||
dataformats="HWC")
|
||||
writer.add_audio(text_input + '(cbhg)', wav, 0, cfg['audio']['sr'])
|
||||
|
||||
for i, prob in enumerate(attn_dec):
|
||||
for j in range(4):
|
||||
x = np.uint8(cm.viridis(prob.numpy()[j]) * 255)
|
||||
writer.add_image(
|
||||
'Attention_dec_%d_0' % global_step,
|
||||
x,
|
||||
i * 4 + j,
|
||||
dataformats="HWC")
|
||||
writer.add_audio(text_input, wav, 0, cfg['audio']['sr'])
|
||||
if not os.path.exists(args.sample_path):
|
||||
os.mkdir(args.sample_path)
|
||||
write(
|
||||
os.path.join(args.sample_path, 'test.wav'), cfg['audio']['sr'],
|
||||
wav)
|
||||
if not os.path.exists(os.path.join(args.output, 'samples')):
|
||||
os.mkdir(os.path.join(args.output, 'samples'))
|
||||
write(
|
||||
os.path.join(os.path.join(args.output, 'samples'), 'cbhg.wav'),
|
||||
cfg['audio']['sr'], wav)
|
||||
|
||||
# synthesis with griffin-lim
|
||||
wav = _ljspeech_processor.inv_melspectrogram(
|
||||
fluid.layers.transpose(
|
||||
fluid.layers.squeeze(postnet_pred, [0]), [1, 0]).numpy())
|
||||
writer.add_audio(text_input + '(griffin)', wav, 0, cfg['audio']['sr'])
|
||||
|
||||
write(
|
||||
os.path.join(os.path.join(args.output, 'samples'), 'griffin.wav'),
|
||||
cfg['audio']['sr'], wav)
|
||||
print("Synthesis completed !!!")
|
||||
writer.close()
|
||||
|
||||
|
||||
|
@ -157,5 +167,7 @@ if __name__ == '__main__':
|
|||
parser = argparse.ArgumentParser(description="Synthesis model")
|
||||
add_config_options_to_parser(parser)
|
||||
args = parser.parse_args()
|
||||
# Print the whole config setting.
|
||||
pprint(vars(args))
|
||||
synthesis("Parakeet stands for Paddle PARAllel text-to-speech toolkit.",
|
||||
args)
|
||||
|
|
|
@ -3,13 +3,11 @@
|
|||
CUDA_VISIBLE_DEVICES=0 \
|
||||
python -u synthesis.py \
|
||||
--max_len=300 \
|
||||
--transformer_step=120000 \
|
||||
--vocoder_step=100000 \
|
||||
--use_gpu=1 \
|
||||
--checkpoint_path='./checkpoint' \
|
||||
--log_dir='./log' \
|
||||
--sample_path='./sample' \
|
||||
--config_path='configs/synthesis.yaml' \
|
||||
--output='./synthesis' \
|
||||
--config='configs/ljspeech.yaml' \
|
||||
--checkpoint_transformer='./checkpoint/transformer/step-120000' \
|
||||
--checkpoint_vocoder='./checkpoint/vocoder/step-100000' \
|
||||
|
||||
if [ $? -ne 0 ]; then
|
||||
echo "Failed in training!"
|
||||
|
|
|
@ -16,7 +16,6 @@ from tqdm import tqdm
|
|||
from tensorboardX import SummaryWriter
|
||||
from collections import OrderedDict
|
||||
import argparse
|
||||
from parse import add_config_options_to_parser
|
||||
from pprint import pprint
|
||||
from ruamel import yaml
|
||||
from matplotlib import cm
|
||||
|
@ -26,186 +25,191 @@ import paddle.fluid.dygraph as dg
|
|||
import paddle.fluid.layers as layers
|
||||
from parakeet.models.transformer_tts.utils import cross_entropy
|
||||
from data import LJSpeechLoader
|
||||
from parakeet.models.transformer_tts.transformer_tts import TransformerTTS
|
||||
from parakeet.models.transformer_tts import TransformerTTS
|
||||
from parakeet.utils import io
|
||||
|
||||
|
||||
def load_checkpoint(step, model_path):
|
||||
model_dict, opti_dict = fluid.dygraph.load_dygraph(
|
||||
os.path.join(model_path, step))
|
||||
new_state_dict = OrderedDict()
|
||||
for param in model_dict:
|
||||
if param.startswith('_layers.'):
|
||||
new_state_dict[param[8:]] = model_dict[param]
|
||||
else:
|
||||
new_state_dict[param] = model_dict[param]
|
||||
return new_state_dict, opti_dict
|
||||
def add_config_options_to_parser(parser):
|
||||
parser.add_argument("--config", type=str, help="path of the config file")
|
||||
parser.add_argument("--use_gpu", type=int, default=0, help="device to use")
|
||||
parser.add_argument("--data", type=str, help="path of LJspeech dataset")
|
||||
|
||||
g = parser.add_mutually_exclusive_group()
|
||||
g.add_argument("--checkpoint", type=str, help="checkpoint to resume from")
|
||||
g.add_argument(
|
||||
"--iteration",
|
||||
type=int,
|
||||
help="the iteration of the checkpoint to load from output directory")
|
||||
|
||||
parser.add_argument(
|
||||
"--output",
|
||||
type=str,
|
||||
default="experiment",
|
||||
help="path to save experiment results")
|
||||
|
||||
|
||||
def main(args):
|
||||
local_rank = dg.parallel.Env().local_rank if args.use_data_parallel else 0
|
||||
nranks = dg.parallel.Env().nranks if args.use_data_parallel else 1
|
||||
local_rank = dg.parallel.Env().local_rank
|
||||
nranks = dg.parallel.Env().nranks
|
||||
parallel = nranks > 1
|
||||
|
||||
with open(args.config_path) as f:
|
||||
with open(args.config) as f:
|
||||
cfg = yaml.load(f, Loader=yaml.Loader)
|
||||
|
||||
global_step = 0
|
||||
place = (fluid.CUDAPlace(dg.parallel.Env().dev_id)
|
||||
if args.use_data_parallel else fluid.CUDAPlace(0)
|
||||
if args.use_gpu else fluid.CPUPlace())
|
||||
place = fluid.CUDAPlace(local_rank) if args.use_gpu else fluid.CPUPlace()
|
||||
|
||||
if not os.path.exists(args.log_dir):
|
||||
os.mkdir(args.log_dir)
|
||||
path = os.path.join(args.log_dir, 'transformer')
|
||||
if not os.path.exists(args.output):
|
||||
os.mkdir(args.output)
|
||||
|
||||
writer = SummaryWriter(path) if local_rank == 0 else None
|
||||
writer = SummaryWriter(os.path.join(args.output,
|
||||
'log')) if local_rank == 0 else None
|
||||
|
||||
with dg.guard(place):
|
||||
model = TransformerTTS(cfg)
|
||||
fluid.enable_dygraph(place)
|
||||
network_cfg = cfg['network']
|
||||
model = TransformerTTS(
|
||||
network_cfg['embedding_size'], network_cfg['hidden_size'],
|
||||
network_cfg['encoder_num_head'], network_cfg['encoder_n_layers'],
|
||||
cfg['audio']['num_mels'], network_cfg['outputs_per_step'],
|
||||
network_cfg['decoder_num_head'], network_cfg['decoder_n_layers'])
|
||||
|
||||
model.train()
|
||||
optimizer = fluid.optimizer.AdamOptimizer(
|
||||
learning_rate=dg.NoamDecay(1 / (
|
||||
cfg['warm_up_step'] * (args.lr**2)), cfg['warm_up_step']),
|
||||
parameter_list=model.parameters())
|
||||
model.train()
|
||||
optimizer = fluid.optimizer.AdamOptimizer(
|
||||
learning_rate=dg.NoamDecay(1 / (cfg['train']['warm_up_step'] *
|
||||
(cfg['train']['learning_rate']**2)),
|
||||
cfg['train']['warm_up_step']),
|
||||
parameter_list=model.parameters(),
|
||||
grad_clip=fluid.clip.GradientClipByGlobalNorm(cfg['train'][
|
||||
'grad_clip_thresh']))
|
||||
|
||||
if args.checkpoint_path is not None:
|
||||
model_dict, opti_dict = load_checkpoint(
|
||||
str(args.transformer_step),
|
||||
os.path.join(args.checkpoint_path, "transformer"))
|
||||
model.set_dict(model_dict)
|
||||
optimizer.set_dict(opti_dict)
|
||||
global_step = args.transformer_step
|
||||
print("load checkpoint!!!")
|
||||
# Load parameters.
|
||||
global_step = io.load_parameters(
|
||||
model=model,
|
||||
optimizer=optimizer,
|
||||
checkpoint_dir=os.path.join(args.output, 'checkpoints'),
|
||||
iteration=args.iteration,
|
||||
checkpoint_path=args.checkpoint)
|
||||
print("Rank {}: checkpoint loaded.".format(local_rank))
|
||||
|
||||
if args.use_data_parallel:
|
||||
strategy = dg.parallel.prepare_context()
|
||||
model = fluid.dygraph.parallel.DataParallel(model, strategy)
|
||||
if parallel:
|
||||
strategy = dg.parallel.prepare_context()
|
||||
model = fluid.dygraph.parallel.DataParallel(model, strategy)
|
||||
|
||||
reader = LJSpeechLoader(
|
||||
cfg, args, nranks, local_rank, shuffle=True).reader()
|
||||
reader = LJSpeechLoader(
|
||||
cfg['audio'],
|
||||
place,
|
||||
args.data,
|
||||
cfg['train']['batch_size'],
|
||||
nranks,
|
||||
local_rank,
|
||||
shuffle=True).reader()
|
||||
|
||||
for epoch in range(args.epochs):
|
||||
pbar = tqdm(reader)
|
||||
for i, data in enumerate(pbar):
|
||||
pbar.set_description('Processing at epoch %d' % epoch)
|
||||
character, mel, mel_input, pos_text, pos_mel, text_length, _, enc_slf_mask, enc_query_mask, dec_slf_mask, enc_dec_mask, dec_query_slf_mask, dec_query_mask = data
|
||||
for epoch in range(cfg['train']['max_epochs']):
|
||||
pbar = tqdm(reader)
|
||||
for i, data in enumerate(pbar):
|
||||
pbar.set_description('Processing at epoch %d' % epoch)
|
||||
character, mel, mel_input, pos_text, pos_mel = data
|
||||
|
||||
global_step += 1
|
||||
global_step += 1
|
||||
|
||||
mel_pred, postnet_pred, attn_probs, stop_preds, attn_enc, attn_dec = model(
|
||||
character,
|
||||
mel_input,
|
||||
pos_text,
|
||||
pos_mel,
|
||||
dec_slf_mask=dec_slf_mask,
|
||||
enc_slf_mask=enc_slf_mask,
|
||||
enc_query_mask=enc_query_mask,
|
||||
enc_dec_mask=enc_dec_mask,
|
||||
dec_query_slf_mask=dec_query_slf_mask,
|
||||
dec_query_mask=dec_query_mask)
|
||||
mel_pred, postnet_pred, attn_probs, stop_preds, attn_enc, attn_dec = model(
|
||||
character, mel_input, pos_text, pos_mel)
|
||||
|
||||
mel_loss = layers.mean(
|
||||
layers.abs(layers.elementwise_sub(mel_pred, mel)))
|
||||
post_mel_loss = layers.mean(
|
||||
layers.abs(layers.elementwise_sub(postnet_pred, mel)))
|
||||
loss = mel_loss + post_mel_loss
|
||||
mel_loss = layers.mean(
|
||||
layers.abs(layers.elementwise_sub(mel_pred, mel)))
|
||||
post_mel_loss = layers.mean(
|
||||
layers.abs(layers.elementwise_sub(postnet_pred, mel)))
|
||||
loss = mel_loss + post_mel_loss
|
||||
|
||||
# Note: When used stop token loss the learning did not work.
|
||||
if args.stop_token:
|
||||
label = (pos_mel == 0).astype(np.float32)
|
||||
stop_loss = cross_entropy(stop_preds, label)
|
||||
loss = loss + stop_loss
|
||||
# Note: When used stop token loss the learning did not work.
|
||||
if cfg['network']['stop_token']:
|
||||
label = (pos_mel == 0).astype(np.float32)
|
||||
stop_loss = cross_entropy(stop_preds, label)
|
||||
loss = loss + stop_loss
|
||||
|
||||
if local_rank == 0:
|
||||
writer.add_scalars('training_loss', {
|
||||
'mel_loss': mel_loss.numpy(),
|
||||
'post_mel_loss': post_mel_loss.numpy()
|
||||
if local_rank == 0:
|
||||
writer.add_scalars('training_loss', {
|
||||
'mel_loss': mel_loss.numpy(),
|
||||
'post_mel_loss': post_mel_loss.numpy()
|
||||
}, global_step)
|
||||
|
||||
if cfg['network']['stop_token']:
|
||||
writer.add_scalar('stop_loss',
|
||||
stop_loss.numpy(), global_step)
|
||||
|
||||
if parallel:
|
||||
writer.add_scalars('alphas', {
|
||||
'encoder_alpha': model._layers.encoder.alpha.numpy(),
|
||||
'decoder_alpha': model._layers.decoder.alpha.numpy(),
|
||||
}, global_step)
|
||||
else:
|
||||
writer.add_scalars('alphas', {
|
||||
'encoder_alpha': model.encoder.alpha.numpy(),
|
||||
'decoder_alpha': model.decoder.alpha.numpy(),
|
||||
}, global_step)
|
||||
|
||||
if args.stop_token:
|
||||
writer.add_scalar('stop_loss',
|
||||
stop_loss.numpy(), global_step)
|
||||
writer.add_scalar('learning_rate',
|
||||
optimizer._learning_rate.step().numpy(),
|
||||
global_step)
|
||||
|
||||
if args.use_data_parallel:
|
||||
writer.add_scalars('alphas', {
|
||||
'encoder_alpha':
|
||||
model._layers.encoder.alpha.numpy(),
|
||||
'decoder_alpha':
|
||||
model._layers.decoder.alpha.numpy(),
|
||||
}, global_step)
|
||||
else:
|
||||
writer.add_scalars('alphas', {
|
||||
'encoder_alpha': model.encoder.alpha.numpy(),
|
||||
'decoder_alpha': model.decoder.alpha.numpy(),
|
||||
}, global_step)
|
||||
if global_step % cfg['train']['image_interval'] == 1:
|
||||
for i, prob in enumerate(attn_probs):
|
||||
for j in range(cfg['network']['decoder_num_head']):
|
||||
x = np.uint8(
|
||||
cm.viridis(prob.numpy()[j * cfg['train'][
|
||||
'batch_size'] // 2]) * 255)
|
||||
writer.add_image(
|
||||
'Attention_%d_0' % global_step,
|
||||
x,
|
||||
i * 4 + j,
|
||||
dataformats="HWC")
|
||||
|
||||
writer.add_scalar('learning_rate',
|
||||
optimizer._learning_rate.step().numpy(),
|
||||
global_step)
|
||||
for i, prob in enumerate(attn_enc):
|
||||
for j in range(cfg['network']['encoder_num_head']):
|
||||
x = np.uint8(
|
||||
cm.viridis(prob.numpy()[j * cfg['train'][
|
||||
'batch_size'] // 2]) * 255)
|
||||
writer.add_image(
|
||||
'Attention_enc_%d_0' % global_step,
|
||||
x,
|
||||
i * 4 + j,
|
||||
dataformats="HWC")
|
||||
|
||||
if global_step % args.image_step == 1:
|
||||
for i, prob in enumerate(attn_probs):
|
||||
for j in range(4):
|
||||
x = np.uint8(
|
||||
cm.viridis(prob.numpy()[j * args.batch_size
|
||||
// 2]) * 255)
|
||||
writer.add_image(
|
||||
'Attention_%d_0' % global_step,
|
||||
x,
|
||||
i * 4 + j,
|
||||
dataformats="HWC")
|
||||
for i, prob in enumerate(attn_dec):
|
||||
for j in range(cfg['network']['decoder_num_head']):
|
||||
x = np.uint8(
|
||||
cm.viridis(prob.numpy()[j * cfg['train'][
|
||||
'batch_size'] // 2]) * 255)
|
||||
writer.add_image(
|
||||
'Attention_dec_%d_0' % global_step,
|
||||
x,
|
||||
i * 4 + j,
|
||||
dataformats="HWC")
|
||||
|
||||
for i, prob in enumerate(attn_enc):
|
||||
for j in range(4):
|
||||
x = np.uint8(
|
||||
cm.viridis(prob.numpy()[j * args.batch_size
|
||||
// 2]) * 255)
|
||||
writer.add_image(
|
||||
'Attention_enc_%d_0' % global_step,
|
||||
x,
|
||||
i * 4 + j,
|
||||
dataformats="HWC")
|
||||
if parallel:
|
||||
loss = model.scale_loss(loss)
|
||||
loss.backward()
|
||||
model.apply_collective_grads()
|
||||
else:
|
||||
loss.backward()
|
||||
optimizer.minimize(loss)
|
||||
model.clear_gradients()
|
||||
|
||||
for i, prob in enumerate(attn_dec):
|
||||
for j in range(4):
|
||||
x = np.uint8(
|
||||
cm.viridis(prob.numpy()[j * args.batch_size
|
||||
// 2]) * 255)
|
||||
writer.add_image(
|
||||
'Attention_dec_%d_0' % global_step,
|
||||
x,
|
||||
i * 4 + j,
|
||||
dataformats="HWC")
|
||||
# save checkpoint
|
||||
if local_rank == 0 and global_step % cfg['train'][
|
||||
'checkpoint_interval'] == 0:
|
||||
io.save_parameters(
|
||||
os.path.join(args.output, 'checkpoints'), global_step,
|
||||
model, optimizer)
|
||||
|
||||
if args.use_data_parallel:
|
||||
loss = model.scale_loss(loss)
|
||||
loss.backward()
|
||||
model.apply_collective_grads()
|
||||
else:
|
||||
loss.backward()
|
||||
optimizer.minimize(
|
||||
loss,
|
||||
grad_clip=fluid.dygraph_grad_clip.GradClipByGlobalNorm(cfg[
|
||||
'grad_clip_thresh']))
|
||||
model.clear_gradients()
|
||||
|
||||
# save checkpoint
|
||||
if local_rank == 0 and global_step % args.save_step == 0:
|
||||
if not os.path.exists(args.save_path):
|
||||
os.mkdir(args.save_path)
|
||||
save_path = os.path.join(args.save_path,
|
||||
'transformer/%d' % global_step)
|
||||
dg.save_dygraph(model.state_dict(), save_path)
|
||||
dg.save_dygraph(optimizer.state_dict(), save_path)
|
||||
if local_rank == 0:
|
||||
writer.close()
|
||||
if local_rank == 0:
|
||||
writer.close()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser(description="Train TransformerTTS model")
|
||||
add_config_options_to_parser(parser)
|
||||
|
||||
args = parser.parse_args()
|
||||
# Print the whole config setting.
|
||||
pprint(args)
|
||||
pprint(vars(args))
|
||||
main(args)
|
||||
|
|
|
@ -1,22 +1,12 @@
|
|||
|
||||
# train model
|
||||
# if you wish to resume from an exists model, uncomment --checkpoint_path and --transformer_step
|
||||
export CUDA_VISIBLE_DEVICES=2
|
||||
export CUDA_VISIBLE_DEVICES=0
|
||||
python -u train_transformer.py \
|
||||
--batch_size=32 \
|
||||
--epochs=10000 \
|
||||
--lr=0.001 \
|
||||
--save_step=1000 \
|
||||
--image_step=2000 \
|
||||
--use_gpu=1 \
|
||||
--use_data_parallel=0 \
|
||||
--stop_token=0 \
|
||||
--data_path='../../dataset/LJSpeech-1.1' \
|
||||
--save_path='./checkpoint' \
|
||||
--log_dir='./log' \
|
||||
--config_path='configs/train_transformer.yaml' \
|
||||
#--checkpoint_path='./checkpoint' \
|
||||
#--transformer_step=160000 \
|
||||
--data='../../dataset/LJSpeech-1.1' \
|
||||
--output='./experiment' \
|
||||
--config='configs/ljspeech.yaml' \
|
||||
#--checkpoint='./checkpoint/transformer/step-120000' \
|
||||
|
||||
if [ $? -ne 0 ]; then
|
||||
echo "Failed in training!"
|
||||
|
|
|
@ -18,110 +18,121 @@ from pathlib import Path
|
|||
from collections import OrderedDict
|
||||
import argparse
|
||||
from ruamel import yaml
|
||||
from parse import add_config_options_to_parser
|
||||
from pprint import pprint
|
||||
import paddle.fluid as fluid
|
||||
import paddle.fluid.dygraph as dg
|
||||
import paddle.fluid.layers as layers
|
||||
from data import LJSpeechLoader
|
||||
from parakeet.models.transformer_tts.vocoder import Vocoder
|
||||
from parakeet.models.transformer_tts import Vocoder
|
||||
from parakeet.utils import io
|
||||
|
||||
|
||||
def load_checkpoint(step, model_path):
|
||||
model_dict, opti_dict = dg.load_dygraph(os.path.join(model_path, step))
|
||||
new_state_dict = OrderedDict()
|
||||
for param in model_dict:
|
||||
if param.startswith('_layers.'):
|
||||
new_state_dict[param[8:]] = model_dict[param]
|
||||
else:
|
||||
new_state_dict[param] = model_dict[param]
|
||||
return new_state_dict, opti_dict
|
||||
def add_config_options_to_parser(parser):
|
||||
parser.add_argument("--config", type=str, help="path of the config file")
|
||||
parser.add_argument("--use_gpu", type=int, default=0, help="device to use")
|
||||
parser.add_argument("--data", type=str, help="path of LJspeech dataset")
|
||||
|
||||
g = parser.add_mutually_exclusive_group()
|
||||
g.add_argument("--checkpoint", type=str, help="checkpoint to resume from")
|
||||
g.add_argument(
|
||||
"--iteration",
|
||||
type=int,
|
||||
help="the iteration of the checkpoint to load from output directory")
|
||||
|
||||
parser.add_argument(
|
||||
"--output",
|
||||
type=str,
|
||||
default="vocoder",
|
||||
help="path to save experiment results")
|
||||
|
||||
|
||||
def main(args):
|
||||
local_rank = dg.parallel.Env().local_rank
|
||||
nranks = dg.parallel.Env().nranks
|
||||
parallel = nranks > 1
|
||||
|
||||
local_rank = dg.parallel.Env().local_rank if args.use_data_parallel else 0
|
||||
nranks = dg.parallel.Env().nranks if args.use_data_parallel else 1
|
||||
|
||||
with open(args.config_path) as f:
|
||||
with open(args.config) as f:
|
||||
cfg = yaml.load(f, Loader=yaml.Loader)
|
||||
|
||||
global_step = 0
|
||||
place = (fluid.CUDAPlace(dg.parallel.Env().dev_id)
|
||||
if args.use_data_parallel else fluid.CUDAPlace(0)
|
||||
if args.use_gpu else fluid.CPUPlace())
|
||||
place = fluid.CUDAPlace(local_rank) if args.use_gpu else fluid.CPUPlace()
|
||||
|
||||
if not os.path.exists(args.log_dir):
|
||||
os.mkdir(args.log_dir)
|
||||
path = os.path.join(args.log_dir, 'vocoder')
|
||||
if not os.path.exists(args.output):
|
||||
os.mkdir(args.output)
|
||||
|
||||
writer = SummaryWriter(path) if local_rank == 0 else None
|
||||
writer = SummaryWriter(os.path.join(args.output,
|
||||
'log')) if local_rank == 0 else None
|
||||
|
||||
with dg.guard(place):
|
||||
model = Vocoder(cfg, args.batch_size)
|
||||
fluid.enable_dygraph(place)
|
||||
model = Vocoder(cfg['train']['batch_size'], cfg['vocoder']['hidden_size'],
|
||||
cfg['audio']['num_mels'], cfg['audio']['n_fft'])
|
||||
|
||||
model.train()
|
||||
optimizer = fluid.optimizer.AdamOptimizer(
|
||||
learning_rate=dg.NoamDecay(1 / (
|
||||
cfg['warm_up_step'] * (args.lr**2)), cfg['warm_up_step']),
|
||||
parameter_list=model.parameters())
|
||||
model.train()
|
||||
optimizer = fluid.optimizer.AdamOptimizer(
|
||||
learning_rate=dg.NoamDecay(1 / (cfg['train']['warm_up_step'] *
|
||||
(cfg['train']['learning_rate']**2)),
|
||||
cfg['train']['warm_up_step']),
|
||||
parameter_list=model.parameters(),
|
||||
grad_clip=fluid.clip.GradientClipByGlobalNorm(cfg['train'][
|
||||
'grad_clip_thresh']))
|
||||
|
||||
if args.checkpoint_path is not None:
|
||||
model_dict, opti_dict = load_checkpoint(
|
||||
str(args.vocoder_step),
|
||||
os.path.join(args.checkpoint_path, "vocoder"))
|
||||
model.set_dict(model_dict)
|
||||
optimizer.set_dict(opti_dict)
|
||||
global_step = args.vocoder_step
|
||||
print("load checkpoint!!!")
|
||||
# Load parameters.
|
||||
global_step = io.load_parameters(
|
||||
model=model,
|
||||
optimizer=optimizer,
|
||||
checkpoint_dir=os.path.join(args.output, 'checkpoints'),
|
||||
iteration=args.iteration,
|
||||
checkpoint_path=args.checkpoint)
|
||||
print("Rank {}: checkpoint loaded.".format(local_rank))
|
||||
|
||||
if args.use_data_parallel:
|
||||
strategy = dg.parallel.prepare_context()
|
||||
model = fluid.dygraph.parallel.DataParallel(model, strategy)
|
||||
if parallel:
|
||||
strategy = dg.parallel.prepare_context()
|
||||
model = fluid.dygraph.parallel.DataParallel(model, strategy)
|
||||
|
||||
reader = LJSpeechLoader(
|
||||
cfg, args, nranks, local_rank, is_vocoder=True).reader()
|
||||
reader = LJSpeechLoader(
|
||||
cfg['audio'],
|
||||
place,
|
||||
args.data,
|
||||
cfg['train']['batch_size'],
|
||||
nranks,
|
||||
local_rank,
|
||||
is_vocoder=True).reader()
|
||||
|
||||
for epoch in range(args.epochs):
|
||||
pbar = tqdm(reader)
|
||||
for i, data in enumerate(pbar):
|
||||
pbar.set_description('Processing at epoch %d' % epoch)
|
||||
mel, mag = data
|
||||
mag = dg.to_variable(mag.numpy())
|
||||
mel = dg.to_variable(mel.numpy())
|
||||
global_step += 1
|
||||
for epoch in range(cfg['train']['max_epochs']):
|
||||
pbar = tqdm(reader)
|
||||
for i, data in enumerate(pbar):
|
||||
pbar.set_description('Processing at epoch %d' % epoch)
|
||||
mel, mag = data
|
||||
mag = dg.to_variable(mag.numpy())
|
||||
mel = dg.to_variable(mel.numpy())
|
||||
global_step += 1
|
||||
|
||||
mag_pred = model(mel)
|
||||
loss = layers.mean(
|
||||
layers.abs(layers.elementwise_sub(mag_pred, mag)))
|
||||
mag_pred = model(mel)
|
||||
loss = layers.mean(
|
||||
layers.abs(layers.elementwise_sub(mag_pred, mag)))
|
||||
|
||||
if args.use_data_parallel:
|
||||
loss = model.scale_loss(loss)
|
||||
loss.backward()
|
||||
model.apply_collective_grads()
|
||||
else:
|
||||
loss.backward()
|
||||
optimizer.minimize(
|
||||
loss,
|
||||
grad_clip=fluid.dygraph_grad_clip.GradClipByGlobalNorm(cfg[
|
||||
'grad_clip_thresh']))
|
||||
model.clear_gradients()
|
||||
if parallel:
|
||||
loss = model.scale_loss(loss)
|
||||
loss.backward()
|
||||
model.apply_collective_grads()
|
||||
else:
|
||||
loss.backward()
|
||||
optimizer.minimize(loss)
|
||||
model.clear_gradients()
|
||||
|
||||
if local_rank == 0:
|
||||
writer.add_scalars('training_loss', {
|
||||
'loss': loss.numpy(),
|
||||
}, global_step)
|
||||
if local_rank == 0:
|
||||
writer.add_scalars('training_loss', {'loss': loss.numpy(), },
|
||||
global_step)
|
||||
|
||||
if global_step % args.save_step == 0:
|
||||
if not os.path.exists(args.save_path):
|
||||
os.mkdir(args.save_path)
|
||||
save_path = os.path.join(args.save_path,
|
||||
'vocoder/%d' % global_step)
|
||||
dg.save_dygraph(model.state_dict(), save_path)
|
||||
dg.save_dygraph(optimizer.state_dict(), save_path)
|
||||
# save checkpoint
|
||||
if local_rank == 0 and global_step % cfg['train'][
|
||||
'checkpoint_interval'] == 0:
|
||||
io.save_parameters(
|
||||
os.path.join(args.output, 'checkpoints'), global_step,
|
||||
model, optimizer)
|
||||
|
||||
if local_rank == 0:
|
||||
writer.close()
|
||||
if local_rank == 0:
|
||||
writer.close()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
|
@ -1,20 +1,12 @@
|
|||
|
||||
# train model
|
||||
# if you wish to resume from an exists model, uncomment --checkpoint_path and --vocoder_step
|
||||
CUDA_VISIBLE_DEVICES=0 \
|
||||
python -u train_vocoder.py \
|
||||
--batch_size=32 \
|
||||
--epochs=10000 \
|
||||
--lr=0.001 \
|
||||
--save_step=1000 \
|
||||
--use_gpu=1 \
|
||||
--use_data_parallel=0 \
|
||||
--data_path='../../dataset/LJSpeech-1.1' \
|
||||
--save_path='./checkpoint' \
|
||||
--log_dir='./log' \
|
||||
--config_path='configs/train_vocoder.yaml' \
|
||||
#--checkpoint_path='./checkpoint' \
|
||||
#--vocoder_step=27000 \
|
||||
--data='../../dataset/LJSpeech-1.1' \
|
||||
--output='./vocoder' \
|
||||
--config='configs/ljspeech.yaml' \
|
||||
#--checkpoint='./checkpoint/vocoder/step-100000' \
|
||||
|
||||
|
||||
if [ $? -ne 0 ]; then
|
||||
|
|
|
@ -70,25 +70,35 @@ class Decoder(dg.Layer):
|
|||
for i, layer in enumerate(self.layer_stack):
|
||||
self.add_sublayer('fft_{}'.format(i), layer)
|
||||
|
||||
def forward(self, enc_seq, enc_pos, non_pad_mask, slf_attn_mask=None):
|
||||
def forward(self, enc_seq, enc_pos):
|
||||
"""
|
||||
Compute decoder outputs.
|
||||
|
||||
Args:
|
||||
enc_seq (Variable): shape(B, T_text, C), dtype float32,
|
||||
the output of length regulator, where T_text means the timesteps of input text,
|
||||
enc_seq (Variable): shape(B, T_mel, C), dtype float32,
|
||||
the output of length regulator, where T_mel means the timesteps of input spectrum.
|
||||
enc_pos (Variable): shape(B, T_mel), dtype int64,
|
||||
the spectrum position, where T_mel means the timesteps of input spectrum,
|
||||
non_pad_mask (Variable): shape(B, T_mel, 1), dtype int64, the mask with non pad.
|
||||
slf_attn_mask (Variable, optional): shape(B, T_mel, T_mel), dtype int64,
|
||||
the mask of mel spectrum. Defaults to None.
|
||||
the spectrum position.
|
||||
|
||||
Returns:
|
||||
dec_output (Variable): shape(B, T_mel, C), the decoder output.
|
||||
dec_slf_attn_list (list[Variable]): len(n_layers), the decoder self attention list.
|
||||
"""
|
||||
dec_slf_attn_list = []
|
||||
slf_attn_mask = layers.expand(slf_attn_mask, [self.n_head, 1, 1])
|
||||
if fluid.framework._dygraph_tracer()._train_mode:
|
||||
slf_attn_mask = get_dec_attn_key_pad_mask(enc_pos, self.n_head,
|
||||
enc_seq.dtype)
|
||||
|
||||
else:
|
||||
len_q = enc_seq.shape[1]
|
||||
slf_attn_mask = layers.triu(
|
||||
layers.ones(
|
||||
shape=[len_q, len_q], dtype=enc_seq.dtype),
|
||||
diagonal=1)
|
||||
slf_attn_mask = layers.cast(
|
||||
slf_attn_mask != 0, dtype=enc_seq.dtype) * -1e30
|
||||
|
||||
non_pad_mask = get_non_pad_mask(enc_pos, 1, enc_seq.dtype)
|
||||
|
||||
# -- Forward
|
||||
dec_output = enc_seq + self.position_enc(enc_pos)
|
||||
|
|
|
@ -76,7 +76,7 @@ class Encoder(dg.Layer):
|
|||
for i, layer in enumerate(self.layer_stack):
|
||||
self.add_sublayer('fft_{}'.format(i), layer)
|
||||
|
||||
def forward(self, character, text_pos, non_pad_mask, slf_attn_mask=None):
|
||||
def forward(self, character, text_pos):
|
||||
"""
|
||||
Encode text sequence.
|
||||
|
||||
|
@ -84,22 +84,21 @@ class Encoder(dg.Layer):
|
|||
character (Variable): shape(B, T_text), dtype float32, the input text characters,
|
||||
where T_text means the timesteps of input characters,
|
||||
text_pos (Variable): shape(B, T_text), dtype int64, the input text position.
|
||||
non_pad_mask (Variable): shape(B, T_text, 1), dtype int64, the mask with non pad.
|
||||
slf_attn_mask (Variable, optional): shape(B, T_text, T_text), dtype int64,
|
||||
the mask of input characters. Defaults to None.
|
||||
|
||||
Returns:
|
||||
enc_output (Variable): shape(B, T_text, C), the encoder output.
|
||||
non_pad_mask (Variable): shape(B, T_text, 1), the mask with non pad.
|
||||
enc_slf_attn_list (list[Variable]): len(n_layers), the encoder self attention list.
|
||||
"""
|
||||
enc_slf_attn_list = []
|
||||
slf_attn_mask = layers.expand(slf_attn_mask, [self.n_head, 1, 1])
|
||||
|
||||
# -- Forward
|
||||
enc_output = self.src_word_emb(character) + self.position_enc(
|
||||
text_pos) #(N, T, C)
|
||||
|
||||
slf_attn_mask = get_attn_key_pad_mask(text_pos, self.n_head,
|
||||
enc_output.dtype)
|
||||
non_pad_mask = get_non_pad_mask(text_pos, 1, enc_output.dtype)
|
||||
|
||||
for enc_layer in self.layer_stack:
|
||||
enc_output, enc_slf_attn = enc_layer(
|
||||
enc_output,
|
||||
|
|
|
@ -24,11 +24,13 @@ from parakeet.models.fastspeech.decoder import Decoder
|
|||
|
||||
|
||||
class FastSpeech(dg.Layer):
|
||||
def __init__(self, cfg):
|
||||
def __init__(self, cfg, num_mels=80):
|
||||
"""FastSpeech model.
|
||||
|
||||
Args:
|
||||
cfg: the yaml configs used in FastSpeech model.
|
||||
num_mels (int, optional): the number of mel bands when calculating mel spectrograms. Defaults to 80.
|
||||
|
||||
"""
|
||||
super(FastSpeech, self).__init__()
|
||||
|
||||
|
@ -37,15 +39,15 @@ class FastSpeech(dg.Layer):
|
|||
len_max_seq=cfg['max_seq_len'],
|
||||
n_layers=cfg['encoder_n_layer'],
|
||||
n_head=cfg['encoder_head'],
|
||||
d_k=cfg['fs_hidden_size'] // cfg['encoder_head'],
|
||||
d_q=cfg['fs_hidden_size'] // cfg['encoder_head'],
|
||||
d_model=cfg['fs_hidden_size'],
|
||||
d_k=cfg['hidden_size'] // cfg['encoder_head'],
|
||||
d_q=cfg['hidden_size'] // cfg['encoder_head'],
|
||||
d_model=cfg['hidden_size'],
|
||||
d_inner=cfg['encoder_conv1d_filter_size'],
|
||||
fft_conv1d_kernel=cfg['fft_conv1d_filter'],
|
||||
fft_conv1d_padding=cfg['fft_conv1d_padding'],
|
||||
dropout=0.1)
|
||||
self.length_regulator = LengthRegulator(
|
||||
input_size=cfg['fs_hidden_size'],
|
||||
input_size=cfg['hidden_size'],
|
||||
out_channels=cfg['duration_predictor_output_size'],
|
||||
filter_size=cfg['duration_predictor_filter_size'],
|
||||
dropout=cfg['dropout'])
|
||||
|
@ -53,30 +55,30 @@ class FastSpeech(dg.Layer):
|
|||
len_max_seq=cfg['max_seq_len'],
|
||||
n_layers=cfg['decoder_n_layer'],
|
||||
n_head=cfg['decoder_head'],
|
||||
d_k=cfg['fs_hidden_size'] // cfg['decoder_head'],
|
||||
d_q=cfg['fs_hidden_size'] // cfg['decoder_head'],
|
||||
d_model=cfg['fs_hidden_size'],
|
||||
d_k=cfg['hidden_size'] // cfg['decoder_head'],
|
||||
d_q=cfg['hidden_size'] // cfg['decoder_head'],
|
||||
d_model=cfg['hidden_size'],
|
||||
d_inner=cfg['decoder_conv1d_filter_size'],
|
||||
fft_conv1d_kernel=cfg['fft_conv1d_filter'],
|
||||
fft_conv1d_padding=cfg['fft_conv1d_padding'],
|
||||
dropout=0.1)
|
||||
self.weight = fluid.ParamAttr(
|
||||
initializer=fluid.initializer.XavierInitializer())
|
||||
k = math.sqrt(1.0 / cfg['fs_hidden_size'])
|
||||
k = math.sqrt(1.0 / cfg['hidden_size'])
|
||||
self.bias = fluid.ParamAttr(initializer=fluid.initializer.Uniform(
|
||||
low=-k, high=k))
|
||||
self.mel_linear = dg.Linear(
|
||||
cfg['fs_hidden_size'],
|
||||
cfg['audio']['num_mels'] * cfg['audio']['outputs_per_step'],
|
||||
cfg['hidden_size'],
|
||||
num_mels * cfg['outputs_per_step'],
|
||||
param_attr=self.weight,
|
||||
bias_attr=self.bias, )
|
||||
self.postnet = PostConvNet(
|
||||
n_mels=cfg['audio']['num_mels'],
|
||||
n_mels=num_mels,
|
||||
num_hidden=512,
|
||||
filter_size=5,
|
||||
padding=int(5 / 2),
|
||||
num_conv=5,
|
||||
outputs_per_step=cfg['audio']['outputs_per_step'],
|
||||
outputs_per_step=cfg['outputs_per_step'],
|
||||
use_cudnn=True,
|
||||
dropout=0.1,
|
||||
batchnorm_last=True)
|
||||
|
@ -84,11 +86,7 @@ class FastSpeech(dg.Layer):
|
|||
def forward(self,
|
||||
character,
|
||||
text_pos,
|
||||
enc_non_pad_mask,
|
||||
dec_non_pad_mask,
|
||||
mel_pos=None,
|
||||
enc_slf_attn_mask=None,
|
||||
dec_slf_attn_mask=None,
|
||||
length_target=None,
|
||||
alpha=1.0):
|
||||
"""
|
||||
|
@ -100,12 +98,6 @@ class FastSpeech(dg.Layer):
|
|||
text_pos (Variable): shape(B, T_text), dtype int64, the input text position.
|
||||
mel_pos (Variable, optional): shape(B, T_mel), dtype int64, the spectrum position,
|
||||
where T_mel means the timesteps of input spectrum,
|
||||
enc_non_pad_mask (Variable): shape(B, T_text, 1), dtype int64, the mask with non pad.
|
||||
dec_non_pad_mask (Variable): shape(B, T_mel, 1), dtype int64, the mask with non pad.
|
||||
enc_slf_attn_mask (Variable, optional): shape(B, T_text, T_text), dtype int64,
|
||||
the mask of input characters. Defaults to None.
|
||||
slf_attn_mask (Variable, optional): shape(B, T_mel, T_mel), dtype int64,
|
||||
the mask of mel spectrum. Defaults to None.
|
||||
length_target (Variable, optional): shape(B, T_text), dtype int64,
|
||||
the duration of phoneme compute from pretrained transformerTTS. Defaults to None.
|
||||
alpha (float32, optional): The hyperparameter to determine the length of the expanded sequence
|
||||
|
@ -119,19 +111,12 @@ class FastSpeech(dg.Layer):
|
|||
dec_slf_attn_list (List[Variable]): len(dec_n_layers), the decoder self attention list.
|
||||
"""
|
||||
|
||||
encoder_output, enc_slf_attn_list = self.encoder(
|
||||
character,
|
||||
text_pos,
|
||||
enc_non_pad_mask,
|
||||
slf_attn_mask=enc_slf_attn_mask)
|
||||
encoder_output, enc_slf_attn_list = self.encoder(character, text_pos)
|
||||
if fluid.framework._dygraph_tracer()._train_mode:
|
||||
length_regulator_output, duration_predictor_output = self.length_regulator(
|
||||
encoder_output, target=length_target, alpha=alpha)
|
||||
decoder_output, dec_slf_attn_list = self.decoder(
|
||||
length_regulator_output,
|
||||
mel_pos,
|
||||
dec_non_pad_mask,
|
||||
slf_attn_mask=dec_slf_attn_mask)
|
||||
length_regulator_output, mel_pos)
|
||||
|
||||
mel_output = self.mel_linear(decoder_output)
|
||||
mel_output_postnet = self.postnet(mel_output) + mel_output
|
||||
|
@ -140,18 +125,8 @@ class FastSpeech(dg.Layer):
|
|||
else:
|
||||
length_regulator_output, decoder_pos = self.length_regulator(
|
||||
encoder_output, alpha=alpha)
|
||||
slf_attn_mask = get_triu_tensor(
|
||||
decoder_pos.numpy(), decoder_pos.numpy()).astype(np.float32)
|
||||
slf_attn_mask = fluid.layers.cast(
|
||||
dg.to_variable(slf_attn_mask == 0), np.float32)
|
||||
slf_attn_mask = dg.to_variable(slf_attn_mask)
|
||||
dec_non_pad_mask = fluid.layers.unsqueeze(
|
||||
(decoder_pos != 0).astype(np.float32), [-1])
|
||||
decoder_output, _ = self.decoder(
|
||||
length_regulator_output,
|
||||
decoder_pos,
|
||||
dec_non_pad_mask,
|
||||
slf_attn_mask=slf_attn_mask)
|
||||
decoder_output, _ = self.decoder(length_regulator_output,
|
||||
decoder_pos)
|
||||
mel_output = self.mel_linear(decoder_output)
|
||||
mel_output_postnet = self.postnet(mel_output) + mel_output
|
||||
|
||||
|
|
|
@ -37,11 +37,10 @@ def score_F(attn):
|
|||
|
||||
|
||||
def compute_duration(attn, mel_lens):
|
||||
alignment = np.zeros([attn.shape[0], attn.shape[2]])
|
||||
mel_lens = mel_lens.numpy()
|
||||
for i in range(attn.shape[0]):
|
||||
for j in range(mel_lens[i]):
|
||||
max_index = np.argmax(attn[i, j])
|
||||
alignment[i, max_index] += 1
|
||||
alignment = np.zeros([attn.shape[2]])
|
||||
#for i in range(attn.shape[0]):
|
||||
for j in range(mel_lens):
|
||||
max_index = np.argmax(attn[0, j])
|
||||
alignment[max_index] += 1
|
||||
|
||||
return alignment
|
||||
|
|
|
@ -10,4 +10,6 @@
|
|||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# limitations under the License.
|
||||
from .transformer_tts import TransformerTTS
|
||||
from .vocoder import Vocoder
|
|
@ -22,14 +22,20 @@ from parakeet.models.transformer_tts.post_convnet import PostConvNet
|
|||
|
||||
|
||||
class Decoder(dg.Layer):
|
||||
def __init__(self, num_hidden, config, num_head=4, n_layers=3):
|
||||
def __init__(self,
|
||||
num_hidden,
|
||||
num_mels=80,
|
||||
outputs_per_step=1,
|
||||
num_head=4,
|
||||
n_layers=3):
|
||||
"""Decoder layer of TransformerTTS.
|
||||
|
||||
Args:
|
||||
num_hidden (int): the number of source vocabulary.
|
||||
config: the yaml configs used in decoder.
|
||||
n_layers (int, optional): the layers number of multihead attention. Defaults to 4.
|
||||
num_head (int, optional): the head number of multihead attention. Defaults to 3.
|
||||
n_mels (int, optional): the number of mel bands when calculating mel spectrograms. Defaults to 80.
|
||||
outputs_per_step (int, optional): the num of output frames per step . Defaults to 1.
|
||||
num_head (int, optional): the head number of multihead attention. Defaults to 4.
|
||||
n_layers (int, optional): the layers number of multihead attention. Defaults to 3.
|
||||
"""
|
||||
super(Decoder, self).__init__()
|
||||
self.num_hidden = num_hidden
|
||||
|
@ -51,7 +57,7 @@ class Decoder(dg.Layer):
|
|||
self.pos_inp),
|
||||
trainable=False))
|
||||
self.decoder_prenet = PreNet(
|
||||
input_size=config['audio']['num_mels'],
|
||||
input_size=num_mels,
|
||||
hidden_size=num_hidden * 2,
|
||||
output_size=num_hidden,
|
||||
dropout_rate=0.2)
|
||||
|
@ -85,7 +91,7 @@ class Decoder(dg.Layer):
|
|||
self.add_sublayer("ffns_{}".format(i), layer)
|
||||
self.mel_linear = dg.Linear(
|
||||
num_hidden,
|
||||
config['audio']['num_mels'] * config['audio']['outputs_per_step'],
|
||||
num_mels * outputs_per_step,
|
||||
param_attr=fluid.ParamAttr(
|
||||
initializer=fluid.initializer.XavierInitializer()),
|
||||
bias_attr=fluid.ParamAttr(initializer=fluid.initializer.Uniform(
|
||||
|
@ -99,23 +105,15 @@ class Decoder(dg.Layer):
|
|||
low=-k, high=k)))
|
||||
|
||||
self.postconvnet = PostConvNet(
|
||||
config['audio']['num_mels'],
|
||||
config['hidden_size'],
|
||||
num_mels,
|
||||
num_hidden,
|
||||
filter_size=5,
|
||||
padding=4,
|
||||
num_conv=5,
|
||||
outputs_per_step=config['audio']['outputs_per_step'],
|
||||
outputs_per_step=outputs_per_step,
|
||||
use_cudnn=True)
|
||||
|
||||
def forward(self,
|
||||
key,
|
||||
value,
|
||||
query,
|
||||
positional,
|
||||
mask,
|
||||
m_mask=None,
|
||||
m_self_mask=None,
|
||||
zero_mask=None):
|
||||
def forward(self, key, value, query, positional, c_mask):
|
||||
"""
|
||||
Compute decoder outputs.
|
||||
|
||||
|
@ -126,11 +124,7 @@ class Decoder(dg.Layer):
|
|||
query (Variable): shape(B, T_mel, C), dtype float32, the input query of decoder,
|
||||
where T_mel means the timesteps of input spectrum,
|
||||
positional (Variable): shape(B, T_mel), dtype int64, the spectrum position.
|
||||
mask (Variable): shape(B, T_mel, T_mel), dtype int64, the mask of decoder self attention.
|
||||
m_mask (Variable, optional): shape(B, T_mel, 1), dtype int64, the query mask of encoder-decoder attention. Defaults to None.
|
||||
m_self_mask (Variable, optional): shape(B, T_mel, 1), dtype int64, the query mask of decoder self attention. Defaults to None.
|
||||
zero_mask (Variable, optional): shape(B, T_mel, T_text), dtype int64, query mask of encoder-decoder attention. Defaults to None.
|
||||
|
||||
c_mask (Variable): shape(B, T_text, 1), dtype float32, query mask returned from encoder.
|
||||
Returns:
|
||||
mel_out (Variable): shape(B, T_mel, C), the decoder output after mel linear projection.
|
||||
out (Variable): shape(B, T_mel, C), the decoder output after post mel network.
|
||||
|
@ -142,14 +136,20 @@ class Decoder(dg.Layer):
|
|||
# get decoder mask with triangular matrix
|
||||
|
||||
if fluid.framework._dygraph_tracer()._train_mode:
|
||||
m_mask = layers.expand(m_mask, [self.num_head, 1, key.shape[1]])
|
||||
m_self_mask = layers.expand(m_self_mask,
|
||||
[self.num_head, 1, query.shape[1]])
|
||||
mask = layers.expand(mask, [self.num_head, 1, 1])
|
||||
zero_mask = layers.expand(zero_mask, [self.num_head, 1, 1])
|
||||
mask = get_dec_attn_key_pad_mask(positional, self.num_head,
|
||||
query.dtype)
|
||||
m_mask = get_non_pad_mask(positional, self.num_head, query.dtype)
|
||||
zero_mask = layers.cast(c_mask == 0, dtype=query.dtype) * -1e30
|
||||
zero_mask = layers.transpose(zero_mask, perm=[0, 2, 1])
|
||||
|
||||
else:
|
||||
m_mask, m_self_mask, zero_mask = None, None, None
|
||||
len_q = query.shape[1]
|
||||
mask = layers.triu(
|
||||
layers.ones(
|
||||
shape=[len_q, len_q], dtype=query.dtype),
|
||||
diagonal=1)
|
||||
mask = layers.cast(mask != 0, dtype=query.dtype) * -1e30
|
||||
m_mask, zero_mask = None, None
|
||||
|
||||
# Decoder pre-network
|
||||
query = self.decoder_prenet(query)
|
||||
|
@ -172,7 +172,7 @@ class Decoder(dg.Layer):
|
|||
for selfattn, attn, ffn in zip(self.selfattn_layers, self.attn_layers,
|
||||
self.ffns):
|
||||
query, attn_dec = selfattn(
|
||||
query, query, query, mask=mask, query_mask=m_self_mask)
|
||||
query, query, query, mask=mask, query_mask=m_mask)
|
||||
query, attn_dot = attn(
|
||||
key, value, query, mask=zero_mask, query_mask=m_mask)
|
||||
query = ffn(query)
|
||||
|
|
|
@ -26,8 +26,8 @@ class Encoder(dg.Layer):
|
|||
Args:
|
||||
embedding_size (int): the size of position embedding.
|
||||
num_hidden (int): the size of hidden layer in network.
|
||||
n_layers (int, optional): the layers number of multihead attention. Defaults to 4.
|
||||
num_head (int, optional): the head number of multihead attention. Defaults to 3.
|
||||
num_head (int, optional): the head number of multihead attention. Defaults to 4.
|
||||
n_layers (int, optional): the layers number of multihead attention. Defaults to 3.
|
||||
"""
|
||||
super(Encoder, self).__init__()
|
||||
self.num_hidden = num_hidden
|
||||
|
@ -64,7 +64,7 @@ class Encoder(dg.Layer):
|
|||
for i, layer in enumerate(self.ffns):
|
||||
self.add_sublayer("ffns_{}".format(i), layer)
|
||||
|
||||
def forward(self, x, positional, mask=None, query_mask=None):
|
||||
def forward(self, x, positional):
|
||||
"""
|
||||
Encode text sequence.
|
||||
|
||||
|
@ -72,24 +72,22 @@ class Encoder(dg.Layer):
|
|||
x (Variable): shape(B, T_text), dtype float32, the input character,
|
||||
where T_text means the timesteps of input text,
|
||||
positional (Variable): shape(B, T_text), dtype int64, the characters position.
|
||||
mask (Variable, optional): shape(B, T_text, T_text), dtype int64, the mask of encoder self attention. Defaults to None.
|
||||
query_mask (Variable, optional): shape(B, T_text, 1), dtype int64, the query mask of encoder self attention. Defaults to None.
|
||||
|
||||
Returns:
|
||||
x (Variable): shape(B, T_text, C), the encoder output.
|
||||
attentions (list[Variable]): len(n_layers), the encoder self attention list.
|
||||
"""
|
||||
|
||||
if fluid.framework._dygraph_tracer()._train_mode:
|
||||
seq_len_key = x.shape[1]
|
||||
query_mask = layers.expand(query_mask,
|
||||
[self.num_head, 1, seq_len_key])
|
||||
mask = layers.expand(mask, [self.num_head, 1, 1])
|
||||
else:
|
||||
query_mask, mask = None, None
|
||||
# Encoder pre_network
|
||||
x = self.encoder_prenet(x)
|
||||
|
||||
if fluid.framework._dygraph_tracer()._train_mode:
|
||||
mask = get_attn_key_pad_mask(positional, self.num_head, x.dtype)
|
||||
query_mask = get_non_pad_mask(positional, self.num_head, x.dtype)
|
||||
|
||||
else:
|
||||
query_mask, mask = None, None
|
||||
|
||||
# Get positional encoding
|
||||
positional = self.pos_emb(positional)
|
||||
|
||||
|
@ -105,4 +103,4 @@ class Encoder(dg.Layer):
|
|||
x = ffn(x)
|
||||
attentions.append(attention)
|
||||
|
||||
return x, attentions
|
||||
return x, attentions, query_mask
|
||||
|
|
|
@ -18,28 +18,34 @@ from parakeet.models.transformer_tts.decoder import Decoder
|
|||
|
||||
|
||||
class TransformerTTS(dg.Layer):
|
||||
def __init__(self, config):
|
||||
def __init__(self,
|
||||
embedding_size,
|
||||
num_hidden,
|
||||
encoder_num_head=4,
|
||||
encoder_n_layers=3,
|
||||
n_mels=80,
|
||||
outputs_per_step=1,
|
||||
decoder_num_head=4,
|
||||
decoder_n_layers=3):
|
||||
"""TransformerTTS model.
|
||||
|
||||
Args:
|
||||
config: the yaml configs used in TransformerTTS model.
|
||||
embedding_size (int): the size of position embedding.
|
||||
num_hidden (int): the size of hidden layer in network.
|
||||
encoder_num_head (int, optional): the head number of multihead attention in encoder. Defaults to 4.
|
||||
encoder_n_layers (int, optional): the layers number of multihead attention in encoder. Defaults to 3.
|
||||
n_mels (int, optional): the number of mel bands when calculating mel spectrograms. Defaults to 80.
|
||||
outputs_per_step (int, optional): the num of output frames per step . Defaults to 1.
|
||||
decoder_num_head (int, optional): the head number of multihead attention in decoder. Defaults to 4.
|
||||
decoder_n_layers (int, optional): the layers number of multihead attention in decoder. Defaults to 3.
|
||||
"""
|
||||
super(TransformerTTS, self).__init__()
|
||||
self.encoder = Encoder(config['embedding_size'], config['hidden_size'])
|
||||
self.decoder = Decoder(config['hidden_size'], config)
|
||||
self.config = config
|
||||
self.encoder = Encoder(embedding_size, num_hidden, encoder_num_head,
|
||||
encoder_n_layers)
|
||||
self.decoder = Decoder(num_hidden, n_mels, outputs_per_step,
|
||||
decoder_num_head, decoder_n_layers)
|
||||
|
||||
def forward(self,
|
||||
characters,
|
||||
mel_input,
|
||||
pos_text,
|
||||
pos_mel,
|
||||
dec_slf_mask,
|
||||
enc_slf_mask=None,
|
||||
enc_query_mask=None,
|
||||
enc_dec_mask=None,
|
||||
dec_query_slf_mask=None,
|
||||
dec_query_mask=None):
|
||||
def forward(self, characters, mel_input, pos_text, pos_mel):
|
||||
"""
|
||||
TransformerTTS network.
|
||||
|
||||
|
@ -49,13 +55,6 @@ class TransformerTTS(dg.Layer):
|
|||
mel_input (Variable): shape(B, T_mel, C), dtype float32, the input query of decoder,
|
||||
where T_mel means the timesteps of input spectrum,
|
||||
pos_text (Variable): shape(B, T_text), dtype int64, the characters position.
|
||||
dec_slf_mask (Variable): shape(B, T_mel), dtype int64, the spectrum position.
|
||||
mask (Variable): shape(B, T_mel, T_mel), dtype int64, the mask of decoder self attention.
|
||||
enc_slf_mask (Variable, optional): shape(B, T_text, T_text), dtype int64, the mask of encoder self attention. Defaults to None.
|
||||
enc_query_mask (Variable, optional): shape(B, T_text, 1), dtype int64, the query mask of encoder self attention. Defaults to None.
|
||||
dec_query_mask (Variable, optional): shape(B, T_mel, 1), dtype int64, the query mask of encoder-decoder attention. Defaults to None.
|
||||
dec_query_slf_mask (Variable, optional): shape(B, T_mel, 1), dtype int64, the query mask of decoder self attention. Defaults to None.
|
||||
enc_dec_mask (Variable, optional): shape(B, T_mel, T_text), dtype int64, query mask of encoder-decoder attention. Defaults to None.
|
||||
|
||||
Returns:
|
||||
mel_output (Variable): shape(B, T_mel, C), the decoder output after mel linear projection.
|
||||
|
@ -65,16 +64,8 @@ class TransformerTTS(dg.Layer):
|
|||
attns_enc (list[Variable]): len(n_layers), the encoder self attention list.
|
||||
attns_dec (list[Variable]): len(n_layers), the decoder self attention list.
|
||||
"""
|
||||
key, attns_enc = self.encoder(
|
||||
characters, pos_text, mask=enc_slf_mask, query_mask=enc_query_mask)
|
||||
key, attns_enc, query_mask = self.encoder(characters, pos_text)
|
||||
|
||||
mel_output, postnet_output, attn_probs, stop_preds, attns_dec = self.decoder(
|
||||
key,
|
||||
key,
|
||||
mel_input,
|
||||
pos_mel,
|
||||
mask=dec_slf_mask,
|
||||
zero_mask=enc_dec_mask,
|
||||
m_self_mask=dec_query_slf_mask,
|
||||
m_mask=dec_query_mask)
|
||||
key, key, mel_input, pos_mel, query_mask)
|
||||
return mel_output, postnet_output, attn_probs, stop_preds, attns_enc, attns_dec
|
||||
|
|
|
@ -50,41 +50,37 @@ def get_sinusoid_encoding_table(n_position, d_hid, padding_idx=None):
|
|||
return sinusoid_table
|
||||
|
||||
|
||||
def get_non_pad_mask(seq):
|
||||
mask = (seq != 0).astype(np.float32)
|
||||
mask = np.expand_dims(mask, axis=-1)
|
||||
def get_non_pad_mask(seq, num_head, dtype):
|
||||
mask = layers.cast(seq != 0, dtype=dtype)
|
||||
mask = layers.unsqueeze(mask, axes=[-1])
|
||||
mask = layers.expand(mask, [num_head, 1, 1])
|
||||
return mask
|
||||
|
||||
|
||||
def get_attn_key_pad_mask(seq_k):
|
||||
def get_attn_key_pad_mask(seq_k, num_head, dtype):
|
||||
''' For masking out the padding part of key sequence. '''
|
||||
# Expand to fit the shape of key query attention matrix.
|
||||
padding_mask = (seq_k != 0).astype(np.float32)
|
||||
padding_mask = np.expand_dims(padding_mask, axis=1)
|
||||
padding_mask = (
|
||||
padding_mask == 0).astype(np.float32) * -1e30 #* (-2**32 + 1)
|
||||
padding_mask = layers.cast(seq_k == 0, dtype=dtype) * -1e30
|
||||
padding_mask = layers.unsqueeze(padding_mask, axes=[1])
|
||||
padding_mask = layers.expand(padding_mask, [num_head, 1, 1])
|
||||
return padding_mask
|
||||
|
||||
|
||||
def get_dec_attn_key_pad_mask(seq_k, seq_q):
|
||||
def get_dec_attn_key_pad_mask(seq_k, num_head, dtype):
|
||||
''' For masking out the padding part of key sequence. '''
|
||||
|
||||
# Expand to fit the shape of key query attention matrix.
|
||||
padding_mask = (seq_k == 0).astype(np.float32)
|
||||
padding_mask = np.expand_dims(padding_mask, axis=1)
|
||||
triu_tensor = get_triu_tensor(seq_q, seq_q)
|
||||
padding_mask = padding_mask + triu_tensor
|
||||
padding_mask = (
|
||||
padding_mask != 0).astype(np.float32) * -1e30 #* (-2**32 + 1)
|
||||
return padding_mask
|
||||
|
||||
|
||||
def get_triu_tensor(seq_k, seq_q):
|
||||
''' For make a triu tensor '''
|
||||
padding_mask = layers.cast(seq_k == 0, dtype=dtype)
|
||||
padding_mask = layers.unsqueeze(padding_mask, axes=[1])
|
||||
len_k = seq_k.shape[1]
|
||||
len_q = seq_q.shape[1]
|
||||
triu_tensor = np.triu(np.ones([len_k, len_q]), 1)
|
||||
return triu_tensor
|
||||
triu = layers.triu(
|
||||
layers.ones(
|
||||
shape=[len_k, len_k], dtype=dtype), diagonal=1)
|
||||
padding_mask = padding_mask + triu
|
||||
padding_mask = layers.cast(
|
||||
padding_mask != 0, dtype=dtype) * -1e30 #* (-2**32 + 1)
|
||||
padding_mask = layers.expand(padding_mask, [num_head, 1, 1])
|
||||
return padding_mask
|
||||
|
||||
|
||||
def guided_attention(N, T, g=0.2):
|
||||
|
|
|
@ -19,22 +19,22 @@ from parakeet.models.transformer_tts.cbhg import CBHG
|
|||
|
||||
|
||||
class Vocoder(dg.Layer):
|
||||
def __init__(self, config, batch_size):
|
||||
def __init__(self, batch_size, hidden_size, num_mels=80, n_fft=2048):
|
||||
"""CBHG Network (mel -> linear)
|
||||
|
||||
Args:
|
||||
config: the yaml configs used in Vocoder model.
|
||||
batch_size (int): the batch size of input.
|
||||
hidden_size (int): the size of hidden layer in network.
|
||||
n_mels (int, optional): the number of mel bands when calculating mel spectrograms. Defaults to 80.
|
||||
n_fft (int, optional): length of the windowed signal after padding with zeros. Defaults to 2048.
|
||||
"""
|
||||
super(Vocoder, self).__init__()
|
||||
self.pre_proj = Conv1D(
|
||||
num_channels=config['audio']['num_mels'],
|
||||
num_filters=config['hidden_size'],
|
||||
filter_size=1)
|
||||
self.cbhg = CBHG(config['hidden_size'], batch_size)
|
||||
num_channels=num_mels, num_filters=hidden_size, filter_size=1)
|
||||
self.cbhg = CBHG(hidden_size, batch_size)
|
||||
self.post_proj = Conv1D(
|
||||
num_channels=config['hidden_size'],
|
||||
num_filters=(config['audio']['n_fft'] // 2) + 1,
|
||||
num_channels=hidden_size,
|
||||
num_filters=(n_fft // 2) + 1,
|
||||
filter_size=1)
|
||||
|
||||
def forward(self, mel):
|
||||
|
|
|
@ -125,6 +125,7 @@ def load_parameters(model,
|
|||
model_dict, optimizer_dict = dg.load_dygraph(checkpoint_path)
|
||||
|
||||
state_dict = model.state_dict()
|
||||
|
||||
# cast to desired data type, for mixed-precision training/inference.
|
||||
for k, v in model_dict.items():
|
||||
if k in state_dict and convert_np_dtype(v.dtype) != state_dict[
|
||||
|
@ -132,6 +133,7 @@ def load_parameters(model,
|
|||
model_dict[k] = v.astype(state_dict[k].numpy().dtype)
|
||||
|
||||
model.set_dict(model_dict)
|
||||
|
||||
print("[checkpoint] Rank {}: loaded model from {}.pdparams".format(
|
||||
local_rank, checkpoint_path))
|
||||
|
||||
|
|
Loading…
Reference in New Issue