add aishell3 example

parent c4615e3bba
commit 3d10fec409

@@ -0,0 +1,74 @@
# FastSpeech2 with AISHELL-3

## Introduction

AISHELL-3 is a large-scale, high-fidelity multi-speaker Mandarin speech corpus that can be used to train multi-speaker Text-to-Speech (TTS) systems.

We use AISHELL-3 to train a multi-speaker FastSpeech2 model here.

## Dataset

### Download and Extract the dataset.
Download AISHELL-3.
```bash
wget https://www.openslr.org/resources/93/data_aishell3.tgz
```
Extract AISHELL-3.
```bash
mkdir data_aishell3
tar zxvf data_aishell3.tgz -C data_aishell3
```

### Get the MFA result of AISHELL-3 and extract it.
We use [MFA2.x](https://github.com/MontrealCorpusTools/Montreal-Forced-Aligner) to get durations for aishell3_fastspeech2.
You can download it from [aishell3_alignment_tone.tar.gz](https://paddlespeech.bj.bcebos.com/MFA/AISHELL-3/with_tone/aishell3_alignment_tone.tar.gz), or train your own MFA model by following the [use_mfa example](https://github.com/PaddlePaddle/Parakeet/tree/develop/examples/use_mfa) in our repo.
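For example, the archive can be fetched and unpacked in the example directory so that the TextGrid files end up under `./aishell3_alignment_tone` (an illustrative invocation; adjust paths to your setup):
```bash
wget https://paddlespeech.bj.bcebos.com/MFA/AISHELL-3/with_tone/aishell3_alignment_tone.tar.gz
tar zxvf aishell3_alignment_tone.tar.gz
```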

### Preprocess the dataset.
Assume the path to the dataset is `~/datasets/data_aishell3`.
Assume the path to the MFA result of AISHELL-3 is `./aishell3_alignment_tone`.
Run the command below to preprocess the dataset.
```bash
./preprocess.sh
```
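Judging from the stages in `preprocess.sh`, the outputs should end up laid out roughly like this (a sketch, not an exhaustive listing):
```text
durations.txt                # phone durations derived from the MFA TextGrids
dump
├── phone_id_map.txt         # phone vocabulary
├── speaker_id_map.txt       # speaker id map
├── train
│   ├── raw                  # extracted features + metadata.jsonl
│   ├── norm                 # normalized features + metadata.jsonl
│   ├── speech_stats.npy     # per-feature mean/std computed on train
│   ├── pitch_stats.npy
│   └── energy_stats.npy
├── dev
│   ├── raw
│   └── norm
└── test
    ├── raw
    └── norm
```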

## Train the model
```bash
./run.sh
```
If you want to train FastSpeech2 on CPU, add the `--device=cpu` argument to `python3 train.py` in `run.sh`.

## Synthesize
We use [parallel wavegan](https://github.com/PaddlePaddle/Parakeet/tree/develop/examples/parallelwave_gan/baker) as the neural vocoder.
Download the pretrained parallel wavegan model (trained on baker) from [parallel_wavegan_baker_ckpt_0.4.zip](https://paddlespeech.bj.bcebos.com/Parakeet/parallel_wavegan_baker_ckpt_0.4.zip) and unzip it.
```bash
unzip parallel_wavegan_baker_ckpt_0.4.zip
```
`synthesize.sh` synthesizes waveforms from `metadata.jsonl`.
`synthesize_e2e.sh` synthesizes waveforms from a text list.
```bash
./synthesize.sh
```
or
```bash
./synthesize_e2e.sh
```
See the bash files for more details on the input parameters.

## Pretrained Model
A pretrained model trained with no silence at the edges of the audio can be downloaded from [fastspeech2_nosil_baker_ckpt_0.4.zip](https://paddlespeech.bj.bcebos.com/Parakeet/fastspeech2_nosil_baker_ckpt_0.4.zip).

Then you can use the following script to synthesize `sentences.txt` with the pretrained FastSpeech2 model.
```bash
python3 synthesize_e2e.py \
    --fastspeech2-config=fastspeech2_nosil_baker_ckpt_0.4/default.yaml \
    --fastspeech2-checkpoint=fastspeech2_nosil_baker_ckpt_0.4/snapshot_iter_76000.pdz \
    --fastspeech2-stat=fastspeech2_nosil_baker_ckpt_0.4/speech_stats.npy \
    --pwg-config=parallel_wavegan_baker_ckpt_0.4/pwg_default.yaml \
    --pwg-params=parallel_wavegan_baker_ckpt_0.4/pwg_generator.pdparams \
    --pwg-stat=parallel_wavegan_baker_ckpt_0.4/pwg_stats.npy \
    --text=sentences.txt \
    --output-dir=exp/debug/test_e2e \
    --device="gpu" \
    --phones-dict=fastspeech2_nosil_baker_ckpt_0.4/phone_id_map.txt
```

@@ -0,0 +1,59 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import paddle

from parakeet.data.batch import batch_sequences


def collate_aishell3_examples(examples):
    # fields = ["text", "text_lengths", "speech", "speech_lengths", "durations", "pitch", "energy", "spk_id"]
    text = [np.array(item["text"], dtype=np.int64) for item in examples]
    speech = [np.array(item["speech"], dtype=np.float32) for item in examples]
    pitch = [np.array(item["pitch"], dtype=np.float32) for item in examples]
    energy = [np.array(item["energy"], dtype=np.float32) for item in examples]
    durations = [
        np.array(item["durations"], dtype=np.int64) for item in examples
    ]
    text_lengths = np.array([item["text_lengths"] for item in examples])
    speech_lengths = np.array([item["speech_lengths"] for item in examples])
    spk_id = np.array([item["spk_id"] for item in examples])

    # pad variable-length sequences to the longest one in the batch
    text = batch_sequences(text)
    pitch = batch_sequences(pitch)
    speech = batch_sequences(speech)
    durations = batch_sequences(durations)
    energy = batch_sequences(energy)

    # convert each batch to paddle.Tensor
    text = paddle.to_tensor(text)
    pitch = paddle.to_tensor(pitch)
    speech = paddle.to_tensor(speech)
    durations = paddle.to_tensor(durations)
    energy = paddle.to_tensor(energy)
    text_lengths = paddle.to_tensor(text_lengths)
    speech_lengths = paddle.to_tensor(speech_lengths)
    spk_id = paddle.to_tensor(spk_id)

    batch = {
        "text": text,
        "text_lengths": text_lengths,
        "durations": durations,
        "speech": speech,
        "speech_lengths": speech_lengths,
        "pitch": pitch,
        "energy": energy,
        "spk_id": spk_id
    }
    return batch
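`batch_sequences` pads each list of variable-length arrays to the length of the longest element before stacking. A minimal self-contained numpy sketch of that padding behavior (an illustration, not parakeet's actual implementation):

```python
import numpy as np

def pad_to_longest(seqs, pad_value=0):
    # pad 1-D arrays with pad_value so they all match the longest one
    max_len = max(len(s) for s in seqs)
    return np.stack(
        [np.pad(s, (0, max_len - len(s)), constant_values=pad_value)
         for s in seqs])

batch = pad_to_longest([np.array([1, 2, 3]), np.array([4, 5])])
print(batch)  # [[1 2 3]
              #  [4 5 0]]
```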

@@ -0,0 +1,106 @@
###########################################################
#                FEATURE EXTRACTION SETTING               #
###########################################################

fs: 24000            # sr
n_fft: 2048          # FFT size.
n_shift: 300         # Hop size.
win_length: 1200     # Window length.
                     # If set to null, it will be the same as fft_size.
window: "hann"       # Window function.

# Only used for feats_type != raw

fmin: 80             # Minimum frequency of Mel basis.
fmax: 7600           # Maximum frequency of Mel basis.
n_mels: 80           # The number of mel basis.

# Only used for the model using pitch features (e.g. FastSpeech2)
f0min: 80            # Minimum f0 for pitch extraction.
f0max: 400           # Maximum f0 for pitch extraction.
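# Note: with fs = 24000 and n_shift = 300, each frame advances
# 300 / 24000 = 12.5 ms (80 frames per second); win_length = 1200 samples
# corresponds to a 50 ms analysis window.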

###########################################################
#                       DATA SETTING                      #
###########################################################
batch_size: 64
num_workers: 4


###########################################################
#                       MODEL SETTING                     #
###########################################################
model:
    adim: 384         # attention dimension
    aheads: 2         # number of attention heads
    elayers: 4        # number of encoder layers
    eunits: 1536      # number of encoder ff units
    dlayers: 4        # number of decoder layers
    dunits: 1536      # number of decoder ff units
    positionwise_layer_type: conv1d   # type of position-wise layer
    positionwise_conv_kernel_size: 3  # kernel size of position wise conv layer
    duration_predictor_layers: 2      # number of layers of duration predictor
    duration_predictor_chans: 256     # number of channels of duration predictor
    duration_predictor_kernel_size: 3 # filter size of duration predictor
    postnet_layers: 5                 # number of layers of postnet
    postnet_filts: 5                  # filter size of conv layers in postnet
    postnet_chans: 256                # number of channels of conv layers in postnet
    use_masking: True                 # whether to apply masking for padded part in loss calculation
    use_scaled_pos_enc: True          # whether to use scaled positional encoding
    encoder_normalize_before: True    # whether to perform layer normalization before the input
    decoder_normalize_before: True    # whether to perform layer normalization before the input
    reduction_factor: 1               # reduction factor
    init_type: xavier_uniform         # initialization type
    init_enc_alpha: 1.0               # initial value of alpha of encoder scaled position encoding
    init_dec_alpha: 1.0               # initial value of alpha of decoder scaled position encoding
    transformer_enc_dropout_rate: 0.2            # dropout rate for transformer encoder layer
    transformer_enc_positional_dropout_rate: 0.2 # dropout rate for transformer encoder positional encoding
    transformer_enc_attn_dropout_rate: 0.2       # dropout rate for transformer encoder attention layer
    transformer_dec_dropout_rate: 0.2            # dropout rate for transformer decoder layer
    transformer_dec_positional_dropout_rate: 0.2 # dropout rate for transformer decoder positional encoding
    transformer_dec_attn_dropout_rate: 0.2       # dropout rate for transformer decoder attention layer
    pitch_predictor_layers: 5                  # number of conv layers in pitch predictor
    pitch_predictor_chans: 256                 # number of channels of conv layers in pitch predictor
    pitch_predictor_kernel_size: 5             # kernel size of conv layers in pitch predictor
    pitch_predictor_dropout: 0.5               # dropout rate in pitch predictor
    pitch_embed_kernel_size: 1                 # kernel size of conv embedding layer for pitch
    pitch_embed_dropout: 0.0                   # dropout rate after conv embedding layer for pitch
    stop_gradient_from_pitch_predictor: true   # whether to stop the gradient from pitch predictor to encoder
    energy_predictor_layers: 2                 # number of conv layers in energy predictor
    energy_predictor_chans: 256                # number of channels of conv layers in energy predictor
    energy_predictor_kernel_size: 3            # kernel size of conv layers in energy predictor
    energy_predictor_dropout: 0.5              # dropout rate in energy predictor
    energy_embed_kernel_size: 1                # kernel size of conv embedding layer for energy
    energy_embed_dropout: 0.0                  # dropout rate after conv embedding layer for energy
    stop_gradient_from_energy_predictor: false # whether to stop the gradient from energy predictor to encoder
    spk_embed_dim: 256                         # speaker embedding dimension
    spk_embed_integration_type: concat         # speaker embedding integration type


###########################################################
#                     UPDATER SETTING                     #
###########################################################
updater:
    use_masking: True     # whether to apply masking for padded part in loss calculation


###########################################################
#                    OPTIMIZER SETTING                    #
###########################################################
optimizer:
    optim: adam           # optimizer type
    learning_rate: 0.001  # learning rate

###########################################################
#                    TRAINING SETTING                     #
###########################################################
max_epoch: 200
num_snapshots: 5


###########################################################
#                      OTHER SETTING                      #
###########################################################
seed: 10086

@@ -0,0 +1,31 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pathlib import Path

import yaml
from yacs.config import CfgNode as Configuration

# load the default config shipped with this example
config_path = (Path(__file__).parent / "conf" / "default.yaml").resolve()

with open(config_path, 'rt') as f:
    _C = yaml.safe_load(f)
    _C = Configuration(_C)


def get_cfg_default():
    config = _C.clone()
    return config


if __name__ == "__main__":
    print(get_cfg_default())
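A quick usage sketch, mirroring how `train.py` below consumes this module (the override file path is illustrative):

```python
from config import get_cfg_default

config = get_cfg_default()
# optionally overwrite defaults from a user-supplied YAML file:
# config.merge_from_file("conf/my_overrides.yaml")
print(config.fs, config.batch_size)
```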

@@ -0,0 +1,116 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from parakeet.models.fastspeech2 import FastSpeech2Loss
from parakeet.training.extensions.evaluator import StandardEvaluator
from parakeet.training.reporter import report
from parakeet.training.updaters.standard_updater import StandardUpdater


class FastSpeech2Updater(StandardUpdater):
    def __init__(self,
                 model,
                 optimizer,
                 dataloader,
                 init_state=None,
                 use_masking=False,
                 use_weighted_masking=False):
        # forward the caller's init_state instead of hard-coding None
        super().__init__(model, optimizer, dataloader, init_state=init_state)
        self.use_masking = use_masking
        self.use_weighted_masking = use_weighted_masking

    def update_core(self, batch):
        before_outs, after_outs, d_outs, p_outs, e_outs, ys, olens = self.model(
            text=batch["text"],
            text_lengths=batch["text_lengths"],
            speech=batch["speech"],
            speech_lengths=batch["speech_lengths"],
            durations=batch["durations"],
            pitch=batch["pitch"],
            energy=batch["energy"],
            spk_id=batch["spk_id"], )

        criterion = FastSpeech2Loss(
            use_masking=self.use_masking,
            use_weighted_masking=self.use_weighted_masking)

        l1_loss, duration_loss, pitch_loss, energy_loss = criterion(
            after_outs=after_outs,
            before_outs=before_outs,
            d_outs=d_outs,
            p_outs=p_outs,
            e_outs=e_outs,
            ys=ys,
            ds=batch["durations"],
            ps=batch["pitch"],
            es=batch["energy"],
            ilens=batch["text_lengths"],
            olens=olens)

        loss = l1_loss + duration_loss + pitch_loss + energy_loss

        optimizer = self.optimizer
        optimizer.clear_grad()
        loss.backward()
        optimizer.step()

        report("train/loss", float(loss))
        report("train/l1_loss", float(l1_loss))
        report("train/duration_loss", float(duration_loss))
        report("train/pitch_loss", float(pitch_loss))
        report("train/energy_loss", float(energy_loss))


class FastSpeech2Evaluator(StandardEvaluator):
    def __init__(self,
                 model,
                 dataloader,
                 use_masking=False,
                 use_weighted_masking=False):
        super().__init__(model, dataloader)
        self.use_masking = use_masking
        self.use_weighted_masking = use_weighted_masking

    def evaluate_core(self, batch):
        before_outs, after_outs, d_outs, p_outs, e_outs, ys, olens = self.model(
            text=batch["text"],
            text_lengths=batch["text_lengths"],
            speech=batch["speech"],
            speech_lengths=batch["speech_lengths"],
            durations=batch["durations"],
            pitch=batch["pitch"],
            energy=batch["energy"],
            spk_id=batch["spk_id"], )

        criterion = FastSpeech2Loss(
            use_masking=self.use_masking,
            use_weighted_masking=self.use_weighted_masking)
        l1_loss, duration_loss, pitch_loss, energy_loss = criterion(
            after_outs=after_outs,
            before_outs=before_outs,
            d_outs=d_outs,
            p_outs=p_outs,
            e_outs=e_outs,
            ys=ys,
            ds=batch["durations"],
            ps=batch["pitch"],
            es=batch["energy"],
            ilens=batch["text_lengths"],
            olens=olens, )
        loss = l1_loss + duration_loss + pitch_loss + energy_loss

        report("eval/loss", float(loss))
        report("eval/l1_loss", float(l1_loss))
        report("eval/duration_loss", float(duration_loss))
        report("eval/pitch_loss", float(pitch_loss))
        report("eval/energy_loss", float(energy_loss))
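These two classes plug into parakeet's Trainer; the wiring below is lifted from `train.py` in this commit (variable names as defined there):

```python
updater = FastSpeech2Updater(
    model=model, optimizer=optimizer, dataloader=train_dataloader,
    **config["updater"])
trainer = Trainer(updater, (config.max_epoch, 'epoch'), output_dir)
evaluator = FastSpeech2Evaluator(model, dev_dataloader, **config["updater"])
trainer.extend(evaluator, trigger=(1, "epoch"))  # run evaluation once per epoch
trainer.run()
```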

@@ -0,0 +1,78 @@
#!/bin/bash

stage=0
stop_stage=100
fs=24000
n_shift=300

export MAIN_ROOT=`realpath ${PWD}/../../../`

if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
    # get durations from MFA's result
    echo "Generate durations.txt from MFA results ..."
    python3 ${MAIN_ROOT}/utils/gen_duration_from_textgrid.py \
        --inputdir=./aishell3_alignment_tone \
        --output durations.txt \
        --sample-rate=${fs} \
        --n-shift=${n_shift}
fi

if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
    # extract features
    echo "Extract features ..."
    python3 ${MAIN_ROOT}/utils/fastspeech2_preprocess.py \
        --dataset=aishell3 \
        --rootdir=~/datasets/data_aishell3/ \
        --dumpdir=dump \
        --dur-file=durations.txt \
        --config-path=conf/default.yaml \
        --num-cpu=8 \
        --cut-sil=True
fi

if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
    # get features' stats (mean and std)
    echo "Get features' stats ..."
    python3 ${MAIN_ROOT}/utils/compute_statistics.py \
        --metadata=dump/train/raw/metadata.jsonl \
        --field-name="speech"

    python3 ${MAIN_ROOT}/utils/compute_statistics.py \
        --metadata=dump/train/raw/metadata.jsonl \
        --field-name="pitch"

    python3 ${MAIN_ROOT}/utils/compute_statistics.py \
        --metadata=dump/train/raw/metadata.jsonl \
        --field-name="energy"
fi

if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
    # normalize and convert phone to id; dev and test should use train's stats
    echo "Normalize ..."
    python3 ${MAIN_ROOT}/utils/fastspeech2_normalize.py \
        --metadata=dump/train/raw/metadata.jsonl \
        --dumpdir=dump/train/norm \
        --speech-stats=dump/train/speech_stats.npy \
        --pitch-stats=dump/train/pitch_stats.npy \
        --energy-stats=dump/train/energy_stats.npy \
        --phones-dict=dump/phone_id_map.txt \
        --speaker-dict=dump/speaker_id_map.txt

    python3 ${MAIN_ROOT}/utils/fastspeech2_normalize.py \
        --metadata=dump/dev/raw/metadata.jsonl \
        --dumpdir=dump/dev/norm \
        --speech-stats=dump/train/speech_stats.npy \
        --pitch-stats=dump/train/pitch_stats.npy \
        --energy-stats=dump/train/energy_stats.npy \
        --phones-dict=dump/phone_id_map.txt \
        --speaker-dict=dump/speaker_id_map.txt

    python3 ${MAIN_ROOT}/utils/fastspeech2_normalize.py \
        --metadata=dump/test/raw/metadata.jsonl \
        --dumpdir=dump/test/norm \
        --speech-stats=dump/train/speech_stats.npy \
        --pitch-stats=dump/train/pitch_stats.npy \
        --energy-stats=dump/train/energy_stats.npy \
        --phones-dict=dump/phone_id_map.txt \
        --speaker-dict=dump/speaker_id_map.txt
fi
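Each step is guarded by the `stage`/`stop_stage` pair, so a single step can be rerun in isolation by editing the two variables at the top of the script; for example, setting `stage=2` and `stop_stage=2` recomputes only the feature statistics.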

@@ -0,0 +1,163 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import logging
from pathlib import Path

import jsonlines
import numpy as np
import paddle
import soundfile as sf
import yaml
from yacs.config import CfgNode
from parakeet.datasets.data_table import DataTable
from parakeet.models.fastspeech2 import FastSpeech2, FastSpeech2Inference
from parakeet.models.parallel_wavegan import PWGGenerator, PWGInference
from parakeet.modules.normalizer import ZScore


def evaluate(args, fastspeech2_config, pwg_config):
    # the DataLoader logger is too verbose
    logging.getLogger("DataLoader").disabled = True

    # construct dataset for evaluation
    with jsonlines.open(args.test_metadata, 'r') as reader:
        test_metadata = list(reader)
    test_dataset = DataTable(
        data=test_metadata, fields=["utt_id", "text", "spk_id"])

    with open(args.phones_dict, "r") as f:
        phn_id = [line.strip().split() for line in f.readlines()]
    vocab_size = len(phn_id)
    print("vocab_size:", vocab_size)

    with open(args.speaker_dict, 'rt') as f:
        spk_id = [line.strip().split() for line in f.readlines()]
    num_speakers = len(spk_id)
    print("num_speakers:", num_speakers)

    odim = fastspeech2_config.n_mels
    model = FastSpeech2(
        idim=vocab_size,
        odim=odim,
        num_speakers=num_speakers,
        **fastspeech2_config["model"])

    model.set_state_dict(
        paddle.load(args.fastspeech2_checkpoint)["main_params"])
    model.eval()

    vocoder = PWGGenerator(**pwg_config["generator_params"])
    vocoder.set_state_dict(paddle.load(args.pwg_params))
    vocoder.remove_weight_norm()
    vocoder.eval()
    print("model done!")

    stat = np.load(args.fastspeech2_stat)
    mu, std = stat
    mu = paddle.to_tensor(mu)
    std = paddle.to_tensor(std)
    fastspeech2_normalizer = ZScore(mu, std)

    stat = np.load(args.pwg_stat)
    mu, std = stat
    mu = paddle.to_tensor(mu)
    std = paddle.to_tensor(std)
    pwg_normalizer = ZScore(mu, std)

    fastspeech2_inference = FastSpeech2Inference(fastspeech2_normalizer, model)
    pwg_inference = PWGInference(pwg_normalizer, vocoder)

    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    for datum in test_dataset:
        utt_id = datum["utt_id"]
        text = paddle.to_tensor(datum["text"])
        spk_id = paddle.to_tensor(datum["spk_id"])

        with paddle.no_grad():
            wav = pwg_inference(fastspeech2_inference(text, spk_id=spk_id))
        sf.write(
            str(output_dir / (utt_id + ".wav")),
            wav.numpy(),
            samplerate=fastspeech2_config.fs)
        print(f"{utt_id} done!")


def main():
    # parse args and config and redirect to evaluate
    parser = argparse.ArgumentParser(
        description="Synthesize with fastspeech2 & parallel wavegan.")
    parser.add_argument(
        "--fastspeech2-config",
        type=str,
        help="fastspeech2 config file to overwrite default config.")
    parser.add_argument(
        "--fastspeech2-checkpoint",
        type=str,
        help="fastspeech2 checkpoint to load.")
    parser.add_argument(
        "--fastspeech2-stat",
        type=str,
        help="mean and standard deviation used to normalize spectrogram when training fastspeech2."
    )
    parser.add_argument(
        "--pwg-config",
        type=str,
        help="parallel wavegan config file to overwrite default config.")
    parser.add_argument(
        "--pwg-params",
        type=str,
        help="parallel wavegan generator parameters to load.")
    parser.add_argument(
        "--pwg-stat",
        type=str,
        help="mean and standard deviation used to normalize spectrogram when training parallel wavegan."
    )
    parser.add_argument(
        "--phones-dict",
        type=str,
        default="phone_id_map.txt",
        help="phone vocabulary file.")
    parser.add_argument(
        "--speaker-dict",
        type=str,
        default="speaker_id_map.txt",
        help="speaker id map file.")
    parser.add_argument("--test-metadata", type=str, help="test metadata.")
    parser.add_argument("--output-dir", type=str, help="output dir.")
    parser.add_argument(
        "--device", type=str, default="gpu", help="device type to use.")
    parser.add_argument("--verbose", type=int, default=1, help="verbose.")

    args = parser.parse_args()
    with open(args.fastspeech2_config) as f:
        fastspeech2_config = CfgNode(yaml.safe_load(f))
    with open(args.pwg_config) as f:
        pwg_config = CfgNode(yaml.safe_load(f))

    print("========Args========")
    print(yaml.safe_dump(vars(args)))
    print("========Config========")
    print(fastspeech2_config)
    print(pwg_config)

    evaluate(args, fastspeech2_config, pwg_config)


if __name__ == "__main__":
    main()

@@ -0,0 +1,15 @@
#!/bin/bash

python3 synthesize.py \
    --fastspeech2-config=conf/default.yaml \
    --fastspeech2-checkpoint=exp/default/checkpoints/snapshot_iter_153.pdz \
    --fastspeech2-stat=dump/train/speech_stats.npy \
    --pwg-config=parallel_wavegan_baker_ckpt_0.4/pwg_default.yaml \
    --pwg-params=parallel_wavegan_baker_ckpt_0.4/pwg_generator.pdparams \
    --pwg-stat=parallel_wavegan_baker_ckpt_0.4/pwg_stats.npy \
    --test-metadata=dump/test/norm/metadata.jsonl \
    --output-dir=exp/debug/test \
    --device="gpu" \
    --phones-dict=dump/phone_id_map.txt \
    --speaker-dict=dump/speaker_id_map.txt

@@ -0,0 +1,176 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import logging
from pathlib import Path

import numpy as np
import paddle
import soundfile as sf
import yaml
from yacs.config import CfgNode
from parakeet.models.fastspeech2 import FastSpeech2, FastSpeech2Inference
from parakeet.models.parallel_wavegan import PWGGenerator, PWGInference
from parakeet.modules.normalizer import ZScore

from frontend import Frontend


def evaluate(args, fastspeech2_config, pwg_config):
    # the DataLoader logger is too verbose
    logging.getLogger("DataLoader").disabled = True

    # construct dataset for evaluation
    sentences = []
    with open(args.text, 'rt') as f:
        for line in f:
            utt_id, sentence = line.strip().split()
            sentences.append((utt_id, sentence))

    with open(args.phones_dict, "r") as f:
        phn_id = [line.strip().split() for line in f.readlines()]
    vocab_size = len(phn_id)
    print("vocab_size:", vocab_size)
    with open(args.speaker_dict, 'rt') as f:
        spk_id = [line.strip().split() for line in f.readlines()]
    num_speakers = len(spk_id)
    print("num_speakers:", num_speakers)

    odim = fastspeech2_config.n_mels
    model = FastSpeech2(
        idim=vocab_size,
        odim=odim,
        num_speakers=num_speakers,
        **fastspeech2_config["model"])

    model.set_state_dict(
        paddle.load(args.fastspeech2_checkpoint)["main_params"])
    model.eval()

    vocoder = PWGGenerator(**pwg_config["generator_params"])
    vocoder.set_state_dict(paddle.load(args.pwg_params))
    vocoder.remove_weight_norm()
    vocoder.eval()
    print("model done!")

    frontend = Frontend(args.phones_dict)
    print("frontend done!")

    stat = np.load(args.fastspeech2_stat)
    mu, std = stat
    mu = paddle.to_tensor(mu)
    std = paddle.to_tensor(std)
    fastspeech2_normalizer = ZScore(mu, std)

    stat = np.load(args.pwg_stat)
    mu, std = stat
    mu = paddle.to_tensor(mu)
    std = paddle.to_tensor(std)
    pwg_normalizer = ZScore(mu, std)

    fastspeech2_inference = FastSpeech2Inference(fastspeech2_normalizer, model)
    pwg_inference = PWGInference(pwg_normalizer, vocoder)

    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    # only test speaker 0
    spk_id = 0
    for utt_id, sentence in sentences:
        input_ids = frontend.get_input_ids(sentence, merge_sentences=True)
        phone_ids = input_ids["phone_ids"]
        flags = 0
        for part_phone_ids in phone_ids:
            with paddle.no_grad():
                mel = fastspeech2_inference(
                    part_phone_ids, spk_id=paddle.to_tensor(spk_id))
                temp_wav = pwg_inference(mel)
            # concatenate the wav pieces of all sub-sentences
            if flags == 0:
                wav = temp_wav
                flags = 1
            else:
                wav = paddle.concat([wav, temp_wav])
        sf.write(
            str(output_dir / (str(spk_id) + "_" + utt_id + ".wav")),
            wav.numpy(),
            samplerate=fastspeech2_config.fs)
        print(f"{utt_id} done!")


def main():
    # parse args and config and redirect to evaluate
    parser = argparse.ArgumentParser(
        description="Synthesize with fastspeech2 & parallel wavegan.")
    parser.add_argument(
        "--fastspeech2-config",
        type=str,
        help="fastspeech2 config file to overwrite default config.")
    parser.add_argument(
        "--fastspeech2-checkpoint",
        type=str,
        help="fastspeech2 checkpoint to load.")
    parser.add_argument(
        "--fastspeech2-stat",
        type=str,
        help="mean and standard deviation used to normalize spectrogram when training fastspeech2."
    )
    parser.add_argument(
        "--pwg-config",
        type=str,
        help="parallel wavegan config file to overwrite default config.")
    parser.add_argument(
        "--pwg-params",
        type=str,
        help="parallel wavegan generator parameters to load.")
    parser.add_argument(
        "--pwg-stat",
        type=str,
        help="mean and standard deviation used to normalize spectrogram when training parallel wavegan."
    )
    parser.add_argument(
        "--phones-dict",
        type=str,
        default="phone_id_map.txt",
        help="phone vocabulary file.")
    parser.add_argument(
        "--speaker-dict",
        type=str,
        default="speaker_id_map.txt",
        help="speaker id map file.")
    parser.add_argument(
        "--text",
        type=str,
        help="text to synthesize, a 'utt_id sentence' pair per line.")
    parser.add_argument("--output-dir", type=str, help="output dir.")
    parser.add_argument(
        "--device", type=str, default="gpu", help="device type to use.")
    parser.add_argument("--verbose", type=int, default=1, help="verbose.")

    args = parser.parse_args()
    with open(args.fastspeech2_config) as f:
        fastspeech2_config = CfgNode(yaml.safe_load(f))
    with open(args.pwg_config) as f:
        pwg_config = CfgNode(yaml.safe_load(f))

    print("========Args========")
    print(yaml.safe_dump(vars(args)))
    print("========Config========")
    print(fastspeech2_config)
    print(pwg_config)

    evaluate(args, fastspeech2_config, pwg_config)


if __name__ == "__main__":
    main()

@@ -0,0 +1,15 @@
#!/bin/bash

python3 synthesize_e2e.py \
    --fastspeech2-config=conf/default.yaml \
    --fastspeech2-checkpoint=exp/default/checkpoints/snapshot_iter_153.pdz \
    --fastspeech2-stat=dump/train/speech_stats.npy \
    --pwg-config=parallel_wavegan_baker_ckpt_0.4/pwg_default.yaml \
    --pwg-params=parallel_wavegan_baker_ckpt_0.4/pwg_generator.pdparams \
    --pwg-stat=parallel_wavegan_baker_ckpt_0.4/pwg_stats.npy \
    --text=../sentences.txt \
    --output-dir=exp/debug/test_e2e \
    --device="gpu" \
    --phones-dict=dump/phone_id_map.txt \
    --speaker-dict=dump/speaker_id_map.txt

@@ -0,0 +1,226 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import os
import logging
from pathlib import Path

import jsonlines
import numpy as np
import paddle
from paddle import DataParallel
from paddle import distributed as dist
from paddle import nn
from paddle.io import DataLoader, DistributedBatchSampler
from parakeet.datasets.data_table import DataTable
from parakeet.models.fastspeech2 import FastSpeech2
from parakeet.training.extensions.snapshot import Snapshot
from parakeet.training.extensions.visualizer import VisualDL
from parakeet.training.seeding import seed_everything
from parakeet.training.trainer import Trainer
from visualdl import LogWriter
import yaml

from batch_fn import collate_aishell3_examples
from config import get_cfg_default
from fastspeech2_updater import FastSpeech2Updater, FastSpeech2Evaluator

optim_classes = dict(
    adadelta=paddle.optimizer.Adadelta,
    adagrad=paddle.optimizer.Adagrad,
    adam=paddle.optimizer.Adam,
    adamax=paddle.optimizer.Adamax,
    adamw=paddle.optimizer.AdamW,
    lamb=paddle.optimizer.Lamb,
    momentum=paddle.optimizer.Momentum,
    rmsprop=paddle.optimizer.RMSProp,
    sgd=paddle.optimizer.SGD, )


def build_optimizers(model: nn.Layer, optim='adadelta',
                     learning_rate=0.01) -> paddle.optimizer:
    optim_class = optim_classes.get(optim)
    if optim_class is None:
        raise ValueError(f"must be one of {list(optim_classes)}: {optim}")
    else:
        optim = optim_class(
            parameters=model.parameters(), learning_rate=learning_rate)

    optimizers = optim
    return optimizers


def train_sp(args, config):
    # decides device type and whether to run in parallel
    # setup running environment correctly
    # query world_size up front so it is defined on both CPU and GPU paths
    world_size = paddle.distributed.get_world_size()
    if not paddle.is_compiled_with_cuda():
        paddle.set_device("cpu")
    else:
        paddle.set_device("gpu")
        if world_size > 1:
            paddle.distributed.init_parallel_env()

    # set the random seed, it is a must for multiprocess training
    seed_everything(config.seed)

    print(
        f"rank: {dist.get_rank()}, pid: {os.getpid()}, parent_pid: {os.getppid()}",
    )

    # dataloader has been too verbose
    logging.getLogger("DataLoader").disabled = True

    # construct dataset for training and validation
    with jsonlines.open(args.train_metadata, 'r') as reader:
        train_metadata = list(reader)
    train_dataset = DataTable(
        data=train_metadata,
        fields=[
            "text", "text_lengths", "speech", "speech_lengths", "durations",
            "pitch", "energy", "spk_id"
        ],
        converters={"speech": np.load,
                    "pitch": np.load,
                    "energy": np.load}, )
    with jsonlines.open(args.dev_metadata, 'r') as reader:
        dev_metadata = list(reader)
    dev_dataset = DataTable(
        data=dev_metadata,
        fields=[
            "text", "text_lengths", "speech", "speech_lengths", "durations",
            "pitch", "energy", "spk_id"
        ],
        converters={"speech": np.load,
                    "pitch": np.load,
                    "energy": np.load}, )

    # collate function and dataloader
    train_sampler = DistributedBatchSampler(
        train_dataset,
        batch_size=config.batch_size,
        shuffle=True,
        drop_last=True)

    print("samplers done!")

    train_dataloader = DataLoader(
        train_dataset,
        batch_sampler=train_sampler,
        collate_fn=collate_aishell3_examples,
        num_workers=config.num_workers)

    dev_dataloader = DataLoader(
        dev_dataset,
        shuffle=False,
        drop_last=False,
        batch_size=config.batch_size,
        collate_fn=collate_aishell3_examples,
        num_workers=config.num_workers)
    print("dataloaders done!")

    with open(args.phones_dict, "r") as f:
        phn_id = [line.strip().split() for line in f.readlines()]
    vocab_size = len(phn_id)
    print("vocab_size:", vocab_size)

    with open(args.speaker_dict, 'rt') as f:
        spk_id = [line.strip().split() for line in f.readlines()]
    num_speakers = len(spk_id)
    print("num_speakers:", num_speakers)

    odim = config.n_mels
    model = FastSpeech2(
        idim=vocab_size,
        odim=odim,
        num_speakers=num_speakers,
        **config["model"])
    if world_size > 1:
        model = DataParallel(model)
    print("model done!")

    optimizer = build_optimizers(model, **config["optimizer"])
    print("optimizer done!")

    updater = FastSpeech2Updater(
        model=model,
        optimizer=optimizer,
        dataloader=train_dataloader,
        **config["updater"])

    output_dir = Path(args.output_dir)
    trainer = Trainer(updater, (config.max_epoch, 'epoch'), output_dir)

    evaluator = FastSpeech2Evaluator(model, dev_dataloader, **config["updater"])

    if dist.get_rank() == 0:
        trainer.extend(evaluator, trigger=(1, "epoch"))
        writer = LogWriter(str(output_dir))
        trainer.extend(VisualDL(writer), trigger=(1, "iteration"))
        trainer.extend(
            Snapshot(max_size=config.num_snapshots), trigger=(1, 'epoch'))
    print(trainer.extensions)
    trainer.run()


def main():
    # parse args and config and redirect to train_sp
    parser = argparse.ArgumentParser(description="Train a FastSpeech2 "
                                     "model with the AISHELL-3 Mandarin TTS dataset.")
    parser.add_argument(
        "--config", type=str, help="config file to overwrite default config.")
    parser.add_argument("--train-metadata", type=str, help="training data.")
    parser.add_argument("--dev-metadata", type=str, help="dev data.")
    parser.add_argument("--output-dir", type=str, help="output dir.")
    parser.add_argument(
        "--device", type=str, default="gpu", help="device type to use.")
    parser.add_argument(
        "--nprocs", type=int, default=1, help="number of processes.")
    parser.add_argument("--verbose", type=int, default=1, help="verbose.")
    parser.add_argument(
        "--phones-dict",
        type=str,
        default="phone_id_map.txt",
        help="phone vocabulary file.")
    parser.add_argument(
        "--speaker-dict",
        type=str,
        default="speaker_id_map.txt",
        help="speaker id map file.")

    args = parser.parse_args()
    if args.device == "cpu" and args.nprocs > 1:
        raise RuntimeError("Multiprocess training on CPU is not supported.")
    config = get_cfg_default()
    if args.config:
        config.merge_from_file(args.config)

    print("========Args========")
    print(yaml.safe_dump(vars(args)))
    print("========Config========")
    print(config)
    print(
        f"master sees the world size: {dist.get_world_size()}, from pid: {os.getpid()}"
    )

    # dispatch
    if args.nprocs > 1:
        dist.spawn(train_sp, (args, config), nprocs=args.nprocs)
    else:
        train_sp(args, config)


if __name__ == "__main__":
    main()
@@ -1,18 +1,78 @@
 #!/bin/bash
-# get durations from MFA's result
-python3 gen_duration_from_textgrid.py --inputdir ./baker_alignment_tone --output durations.txt
-# extract features
-python3 preprocess.py --rootdir=~/datasets/BZNSYP/ --dumpdir=dump --dur-file durations.txt --num-cpu 4 --cut-sil True
-# # get features' stats(mean and std)
-python3 compute_statistics.py --metadata=dump/train/raw/metadata.jsonl --field-name="speech"
-python3 compute_statistics.py --metadata=dump/train/raw/metadata.jsonl --field-name="pitch"
-python3 compute_statistics.py --metadata=dump/train/raw/metadata.jsonl --field-name="energy"
-# normalize and covert phone to id, dev and test should use train's stats
-python3 normalize.py --metadata=dump/train/raw/metadata.jsonl --dumpdir=dump/train/norm --speech-stats=dump/train/speech_stats.npy --pitch-stats=dump/train/pitch_stats.npy --energy-stats=dump/train/energy_stats.npy --phones-dict dump/phone_id_map.txt
-python3 normalize.py --metadata=dump/dev/raw/metadata.jsonl --dumpdir=dump/dev/norm --speech-stats=dump/train/speech_stats.npy --pitch-stats=dump/train/pitch_stats.npy --energy-stats=dump/train/energy_stats.npy --phones-dict dump/phone_id_map.txt
-python3 normalize.py --metadata=dump/test/raw/metadata.jsonl --dumpdir=dump/test/norm --speech-stats=dump/train/speech_stats.npy --pitch-stats=dump/train/pitch_stats.npy --energy-stats=dump/train/energy_stats.npy --phones-dict dump/phone_id_map.txt
+stage=0
+stop_stage=100
+fs=24000
+n_shift=300
+
+export MAIN_ROOT=`realpath ${PWD}/../../../`
+
+if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
+    # get durations from MFA's result
+    echo "Generate durations.txt from MFA results ..."
+    python3 ${MAIN_ROOT}/utils/gen_duration_from_textgrid.py \
+        --inputdir=./baker_alignment_tone \
+        --output=durations.txt \
+        --sample-rate=${fs} \
+        --n-shift=${n_shift}
+fi
+
+if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
+    # extract features
+    echo "Extract features ..."
+    python3 ${MAIN_ROOT}/utils/fastspeech2_preprocess.py \
+        --dataset=baker \
+        --rootdir=~/datasets/BZNSYP/ \
+        --dumpdir=dump \
+        --dur-file=durations.txt \
+        --config-path=conf/default.yaml \
+        --num-cpu=8 \
+        --cut-sil=True
+fi
+
+if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
+    # get features' stats (mean and std)
+    echo "Get features' stats ..."
+    python3 ${MAIN_ROOT}/utils/compute_statistics.py \
+        --metadata=dump/train/raw/metadata.jsonl \
+        --field-name="speech"
+
+    python3 ${MAIN_ROOT}/utils/compute_statistics.py \
+        --metadata=dump/train/raw/metadata.jsonl \
+        --field-name="pitch"
+
+    python3 ${MAIN_ROOT}/utils/compute_statistics.py \
+        --metadata=dump/train/raw/metadata.jsonl \
+        --field-name="energy"
+fi
+
+if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
+    # normalize and convert phone to id; dev and test should use train's stats
+    echo "Normalize ..."
+    python3 ${MAIN_ROOT}/utils/fastspeech2_normalize.py \
+        --metadata=dump/train/raw/metadata.jsonl \
+        --dumpdir=dump/train/norm \
+        --speech-stats=dump/train/speech_stats.npy \
+        --pitch-stats=dump/train/pitch_stats.npy \
+        --energy-stats=dump/train/energy_stats.npy \
+        --phones-dict=dump/phone_id_map.txt \
+        --speaker-dict=dump/speaker_id_map.txt
+
+    python3 ${MAIN_ROOT}/utils/fastspeech2_normalize.py \
+        --metadata=dump/dev/raw/metadata.jsonl \
+        --dumpdir=dump/dev/norm \
+        --speech-stats=dump/train/speech_stats.npy \
+        --pitch-stats=dump/train/pitch_stats.npy \
+        --energy-stats=dump/train/energy_stats.npy \
+        --phones-dict=dump/phone_id_map.txt \
+        --speaker-dict=dump/speaker_id_map.txt
+
+    python3 ${MAIN_ROOT}/utils/fastspeech2_normalize.py \
+        --metadata=dump/test/raw/metadata.jsonl \
+        --dumpdir=dump/test/norm \
+        --speech-stats=dump/train/speech_stats.npy \
+        --pitch-stats=dump/train/pitch_stats.npy \
+        --energy-stats=dump/train/energy_stats.npy \
+        --phones-dict=dump/phone_id_map.txt \
+        --speaker-dict=dump/speaker_id_map.txt
+fi
@@ -8,7 +8,7 @@ python3 synthesize_e2e.py \
     --pwg-config=parallel_wavegan_baker_ckpt_0.4/pwg_default.yaml \
     --pwg-params=parallel_wavegan_baker_ckpt_0.4/pwg_generator.pdparams \
     --pwg-stat=parallel_wavegan_baker_ckpt_0.4/pwg_stats.npy \
-    --text=sentences.txt \
+    --text=../sentences.txt \
     --output-dir=exp/debug/test_e2e \
     --device="gpu" \
     --phones-dict=dump/phone_id_map.txt

@@ -0,0 +1,94 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Script to reorganize the AISHELL-3 dataset so that Montreal Forced
Aligner can be used to align transcriptions and audio.

Please refer to https://montreal-forced-aligner.readthedocs.io/en/latest/data_prep.html
for more details about Montreal Forced Aligner's requirements on corpora.

For scripts to reorganize other corpora, please refer to
https://github.com/MontrealCorpusTools/MFA-reorganization-scripts
for more details.
"""
import argparse
import os
from pathlib import Path
from typing import Union


def link_wav(root_dir: Union[str, Path], output_dir: Union[str, Path]):
    for sub_set in {'train', 'test'}:
        wav_dir = root_dir / sub_set / 'wav'
        new_dir = output_dir / sub_set
        new_dir.mkdir(parents=True, exist_ok=True)

        for spk_dir in os.listdir(wav_dir):
            sub_dir = wav_dir / spk_dir
            new_sub_dir = new_dir / spk_dir
            os.symlink(sub_dir, new_sub_dir)


def write_lab(root_dir: Union[str, Path],
              output_dir: Union[str, Path],
              script_type='pinyin'):
    for sub_set in {'train', 'test'}:
        text_path = root_dir / sub_set / 'content.txt'
        new_dir = output_dir / sub_set

        with open(text_path, 'r') as rf:
            for line in rf:
                wav_id, context = line.strip().split('\t')
                spk_id = wav_id[:7]
                transcript_name = wav_id.split('.')[0] + '.lab'
                transcript_path = new_dir / spk_id / transcript_name
                context_list = context.split()
                word_list = context_list[0:-1:2]
                pinyin_list = context_list[1::2]
                if script_type == 'word':
                    # add space between chinese characters
                    new_context = ' '.join(word_list)
                elif script_type == 'pinyin':
                    new_context = ' '.join(pinyin_list)
                else:
                    raise ValueError(f"unsupported script type: {script_type}")
                with open(transcript_path, 'w') as wf:
                    wf.write(new_context + '\n')


def reorganize_aishell3(root_dir: Union[str, Path],
                        output_dir: Union[str, Path]=None,
                        script_type='pinyin'):
    output_dir.mkdir(parents=True, exist_ok=True)
    link_wav(root_dir, output_dir)
    write_lab(root_dir, output_dir, script_type)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Reorganize AISHELL-3 dataset for MFA")
    parser.add_argument(
        "--root-dir", type=str, default="", help="path to AISHELL-3 dataset.")
    parser.add_argument(
        "--output-dir",
        type=str,
        help="path to save outputs (audio and transcriptions)")
    parser.add_argument(
        "--script-type",
        type=str,
        default="pinyin",
        help="type of lab ('word'/'pinyin')")

    args = parser.parse_args()
    root_dir = Path(args.root_dir).expanduser()
    output_dir = Path(args.output_dir).expanduser()
    output_dir.mkdir(parents=True, exist_ok=True)
    reorganize_aishell3(root_dir, output_dir, args.script_type)
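For reference, each line of AISHELL-3's `content.txt` is a tab-separated `wav_id<TAB>transcript` pair in which Chinese characters and pinyin syllables alternate, roughly like this (an illustrative line, not copied from the corpus):

```text
SSB00050001.wav	广 guang3 州 zhou1 女 nv3 大 da4 学 xue2 生 sheng1
```

`write_lab` relies on that alternation: the even-indexed tokens are characters and the odd-indexed ones are pinyin.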

@@ -21,13 +21,12 @@ For scripts to reorganize other corpus, please refer to
 https://github.com/MontrealCorpusTools/MFA-reorganization-scripts
 for more details.
 """
+import argparse
 import os
 import shutil
-import argparse
-
-from typing import Union
-from pathlib import Path
 from concurrent.futures import ThreadPoolExecutor
+from pathlib import Path
+from typing import Union
 
 import librosa
 import soundfile as sf
@@ -103,7 +102,7 @@ if __name__ == "__main__":
     parser.add_argument(
         "--output-dir",
         type=str,
-        help="path to save outputs(audio and transcriptions)")
+        help="path to save outputs (audio and transcriptions)")
     parser.add_argument(
         "--resample-audio",
         action="store_true",
@@ -12,10 +12,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Fastspeech2 related modules for paddle"""
+from typing import Dict
 from typing import Sequence
 from typing import Tuple
 
 import paddle
+import paddle.nn.functional as F
 from paddle import nn
 from typeguard import check_argument_types

@@ -92,6 +94,14 @@ class FastSpeech2(nn.Layer):
             pitch_embed_kernel_size: int=9,
             pitch_embed_dropout: float=0.5,
             stop_gradient_from_pitch_predictor: bool=False,
+            # spk emb
+            num_speakers: int=None,
+            spk_embed_dim: int=None,
+            spk_embed_integration_type: str="add",
+            # tone emb
+            num_tones: int=None,
+            tone_embed_dim: int=None,
+            tone_embed_integration_type: str="add",
             # training related
             transformer_enc_dropout_rate: float=0.1,
             transformer_enc_positional_dropout_rate: float=0.1,
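For orientation, a hedged sketch (not part of the commit) of how the new multi-speaker arguments might be passed when constructing the model; `vocab_size`, the dimensions, and the speaker count are placeholders, with `num_speakers` assumed to match the size of the generated `speaker_id_map.txt`:

```python
# Illustrative construction only; argument values are assumptions.
model = FastSpeech2(
    idim=vocab_size,                   # phone vocabulary size (placeholder)
    odim=80,                           # number of mel bins (placeholder)
    num_speakers=218,                  # len(speaker_id_map.txt), assumed
    spk_embed_dim=256,                 # learned speaker embedding size
    spk_embed_integration_type="add")  # or "concat"
```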
@@ -121,12 +131,32 @@ class FastSpeech2(nn.Layer):
         self.stop_gradient_from_energy_predictor = stop_gradient_from_energy_predictor
         self.use_scaled_pos_enc = use_scaled_pos_enc
 
+        self.spk_embed_dim = spk_embed_dim
+        if self.spk_embed_dim is not None:
+            self.spk_embed_integration_type = spk_embed_integration_type
+
+        self.tone_embed_dim = tone_embed_dim
+        if self.tone_embed_dim is not None:
+            self.tone_embed_integration_type = tone_embed_integration_type
+
         # use idx 0 as padding idx
         self.padding_idx = 0
 
         # initialize parameters
         initialize(self, init_type)
 
+        if self.spk_embed_dim is not None:
+            self.spk_embedding_table = nn.Embedding(
+                num_embeddings=num_speakers,
+                embedding_dim=self.spk_embed_dim,
+                padding_idx=self.padding_idx)
+
+        if self.tone_embed_dim is not None:
+            self.tone_embedding_table = nn.Embedding(
+                num_embeddings=num_tones,
+                embedding_dim=self.tone_embed_dim,
+                padding_idx=self.padding_idx)
+
         # get positional encoding class
         pos_enc_class = (ScaledPositionalEncoding
                          if self.use_scaled_pos_enc else PositionalEncoding)
@@ -156,6 +186,21 @@ class FastSpeech2(nn.Layer):
         else:
             raise ValueError(f"{encoder_type} is not supported.")
 
+        # define additional projection for speaker embedding
+        if self.spk_embed_dim is not None:
+            if self.spk_embed_integration_type == "add":
+                self.spk_projection = nn.Linear(self.spk_embed_dim, adim)
+            else:
+                self.spk_projection = nn.Linear(adim + self.spk_embed_dim, adim)
+
+        # define additional projection for tone embedding
+        if self.tone_embed_dim is not None:
+            if self.tone_embed_integration_type == "add":
+                self.tone_projection = nn.Linear(self.tone_embed_dim, adim)
+            else:
+                self.tone_projection = nn.Linear(adim + self.tone_embed_dim,
+                                                 adim)
+
         # define duration predictor
         self.duration_predictor = DurationPredictor(
             idim=adim,
@@ -251,7 +296,11 @@ class FastSpeech2(nn.Layer):
                 speech_lengths: paddle.Tensor,
                 durations: paddle.Tensor,
                 pitch: paddle.Tensor,
-                energy: paddle.Tensor, ) -> Sequence[paddle.Tensor]:
+                energy: paddle.Tensor,
+                tone_id: paddle.Tensor=None,
+                spembs: paddle.Tensor=None,
+                spk_id: paddle.Tensor=None
+                ) -> Tuple[paddle.Tensor, Dict[str, paddle.Tensor], paddle.Tensor]:
         """Calculate forward propagation.
 
         Parameters

@@ -270,6 +319,13 @@ class FastSpeech2(nn.Layer):
             Batch of padded token-averaged pitch (B, Tmax, 1).
         energy : Tensor
             Batch of padded token-averaged energy (B, Tmax, 1).
+        tone_id : Tensor
+            Batch of padded tone ids (B, Tmax).
+        spembs : Tensor, optional
+            Batch of speaker embeddings (B, spk_embed_dim).
+        spk_id : Tensor
+            Batch of speaker ids (B,)
 
         Returns
         ----------
         Tensor

@@ -295,7 +351,16 @@ class FastSpeech2(nn.Layer):
 
         # forward propagation
         before_outs, after_outs, d_outs, p_outs, e_outs = self._forward(
-            xs, ilens, ys, olens, ds, ps, es, is_inference=False)
+            xs,
+            ilens,
+            olens,
+            ds,
+            ps,
+            es,
+            is_inference=False,
+            spembs=spembs,
+            spk_id=spk_id,
+            tone_id=tone_id)
         # modify mod part of groundtruth
         if self.reduction_factor > 1:
             olens = paddle.to_tensor(
@@ -305,21 +370,38 @@ class FastSpeech2(nn.Layer):
 
         return before_outs, after_outs, d_outs, p_outs, e_outs, ys, olens
 
-    def _forward(
-            self,
-            xs: paddle.Tensor,
-            ilens: paddle.Tensor,
-            ys: paddle.Tensor=None,
-            olens: paddle.Tensor=None,
-            ds: paddle.Tensor=None,
-            ps: paddle.Tensor=None,
-            es: paddle.Tensor=None,
-            is_inference: bool=False,
-            alpha: float=1.0, ) -> Sequence[paddle.Tensor]:
+    def _forward(self,
+                 xs: paddle.Tensor,
+                 ilens: paddle.Tensor,
+                 olens: paddle.Tensor=None,
+                 ds: paddle.Tensor=None,
+                 ps: paddle.Tensor=None,
+                 es: paddle.Tensor=None,
+                 is_inference: bool=False,
+                 alpha: float=1.0,
+                 spembs=None,
+                 spk_id=None,
+                 tone_id=None) -> Sequence[paddle.Tensor]:
         # forward encoder
         x_masks = self._source_mask(ilens)
 
-        hs, _ = self.encoder(xs, x_masks)  # (B, Tmax, adim)
+        # (B, Tmax, adim)
+        hs, _ = self.encoder(xs, x_masks)
+
+        # integrate speaker embedding
+        if self.spk_embed_dim is not None:
+            if spembs is not None:
+                hs = self._integrate_with_spk_embed(hs, spembs)
+            elif spk_id is not None:
+                spembs = self.spk_embedding_table(spk_id)
+                hs = self._integrate_with_spk_embed(hs, spembs)
+
+        # integrate tone embedding
+        if self.tone_embed_dim is not None:
+            if tone_id is not None:
+                tone_embs = self.tone_embedding_table(tone_id)
+                hs = self._integrate_with_tone_embed(hs, tone_embs)
+
         # forward duration predictor and variance predictors
         d_masks = make_pad_mask(ilens)
@@ -387,7 +469,11 @@ class FastSpeech2(nn.Layer):
                   pitch: paddle.Tensor=None,
                   energy: paddle.Tensor=None,
                   alpha: float=1.0,
                   use_teacher_forcing: bool=False,
+                  spembs=None,
+                  spk_id=None,
+                  tone_id=None,
-                  ) -> paddle.Tensor:
+                  ) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]:
         """Generate the sequence of features given the sequences of characters.
 
         Parameters

@@ -407,6 +493,10 @@ class FastSpeech2(nn.Layer):
         use_teacher_forcing : bool, optional
             Whether to use teacher forcing.
             If true, groundtruth of duration, pitch and energy will be used.
+        spembs : Tensor, optional
+            Speaker embedding vector (spk_embed_dim,).
+        spk_id : Tensor, optional
+            Speaker id.
 
         Returns
         ----------

@@ -414,7 +504,7 @@ class FastSpeech2(nn.Layer):
             Output sequence of features (L, odim).
         """
         x, y = text, speech
-        d, p, e = durations, pitch, energy
+        spemb, d, p, e = spembs, durations, pitch, energy
 
         # setup batch axis
         ilens = paddle.to_tensor(
@@ -424,6 +514,11 @@ class FastSpeech2(nn.Layer):
         if y is not None:
             ys = y.unsqueeze(0)
 
+        if spemb is not None:
+            spembs = spemb.unsqueeze(0)
+        else:
+            spembs = None
+
         if use_teacher_forcing:
             # use groundtruth of duration, pitch, and energy
             ds, ps, es = d.unsqueeze(0), p.unsqueeze(0), e.unsqueeze(0)

@@ -434,7 +529,10 @@ class FastSpeech2(nn.Layer):
                 ys,
                 ds=ds,
                 ps=ps,
-                es=es, )
+                es=es,
+                spembs=spembs,
+                spk_id=spk_id,
+                tone_id=tone_id)
         else:
             # (1, L, odim)
             _, outs, *_ = self._forward(
@@ -442,10 +540,71 @@ class FastSpeech2(nn.Layer):
                 ilens,
                 ys,
                 is_inference=True,
-                alpha=alpha, )
+                alpha=alpha,
+                spembs=spembs,
+                spk_id=spk_id,
+                tone_id=tone_id)
 
         return outs[0]
 
+    def _integrate_with_spk_embed(self, hs, spembs):
+        """Integrate speaker embedding with hidden states.
+
+        Parameters
+        ----------
+        hs : Tensor
+            Batch of hidden state sequences (B, Tmax, adim).
+        spembs : Tensor
+            Batch of speaker embeddings (B, spk_embed_dim).
+
+        Returns
+        ----------
+        Tensor
+            Batch of integrated hidden state sequences (B, Tmax, adim)
+        """
+        if self.spk_embed_integration_type == "add":
+            # apply projection and then add to hidden states
+            spembs = self.spk_projection(F.normalize(spembs))
+            hs = hs + spembs.unsqueeze(1)
+        elif self.spk_embed_integration_type == "concat":
+            # concat hidden states with spk embeds and then apply projection
+            spembs = F.normalize(spembs).unsqueeze(1).expand(
+                shape=[-1, hs.shape[1], -1])
+            hs = self.spk_projection(paddle.concat([hs, spembs], axis=-1))
+        else:
+            raise NotImplementedError("support only add or concat.")
+
+        return hs
+
+    def _integrate_with_tone_embed(self, hs, tone_embs):
+        """Integrate tone embedding with hidden states.
+
+        Parameters
+        ----------
+        hs : Tensor
+            Batch of hidden state sequences (B, Tmax, adim).
+        tone_embs : Tensor
+            Batch of tone embeddings (B, Tmax, tone_embed_dim).
+
+        Returns
+        ----------
+        Tensor
+            Batch of integrated hidden state sequences (B, Tmax, adim)
+        """
+        if self.tone_embed_integration_type == "add":
+            # apply projection and then add to hidden states
+            tone_embs = self.tone_projection(F.normalize(tone_embs))
+            hs = hs + tone_embs
+        elif self.tone_embed_integration_type == "concat":
+            # concat hidden states with tone embeds and then apply projection
+            tone_embs = F.normalize(tone_embs).expand(
+                shape=[-1, hs.shape[1], -1])
+            hs = self.tone_projection(paddle.concat([hs, tone_embs], axis=-1))
+        else:
+            raise NotImplementedError("support only add or concat.")
+        return hs
+
     def _source_mask(self, ilens: paddle.Tensor) -> paddle.Tensor:
         """Make masks for self-attention.
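As a shape-level sanity check on the two integration types above, here is a small self-contained sketch (dimensions are placeholders, values meaningless) of what "add" and "concat" do to the hidden states:

```python
import paddle
import paddle.nn as nn
import paddle.nn.functional as F

B, T, adim, spk_dim = 2, 5, 8, 4
hs = paddle.randn([B, T, adim])          # hidden states
spembs = paddle.randn([B, spk_dim])      # one embedding per utterance

# "add": project the normalized embedding to adim, broadcast over time
add_proj = nn.Linear(spk_dim, adim)
out_add = hs + add_proj(F.normalize(spembs)).unsqueeze(1)      # (B, T, adim)

# "concat": tile over time, concat on the feature axis, project back to adim
concat_proj = nn.Linear(adim + spk_dim, adim)
tiled = F.normalize(spembs).unsqueeze(1).expand([-1, T, -1])
out_concat = concat_proj(paddle.concat([hs, tiled], axis=-1))  # (B, T, adim)
```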
@@ -496,8 +655,8 @@ class FastSpeech2Inference(nn.Layer):
         self.normalizer = normalizer
         self.acoustic_model = model
 
-    def forward(self, text):
-        normalized_mel = self.acoustic_model.inference(text)
+    def forward(self, text, spk_id=None):
+        normalized_mel = self.acoustic_model.inference(text, spk_id=spk_id)
         logmel = self.normalizer.inverse(normalized_mel)
         return logmel
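A hedged usage sketch of the updated wrapper; the surrounding setup (building `fastspeech2_inference` from a checkpoint and turning text into `phone_ids`) is assumed, and the names are illustrative:

```python
import paddle

# phone_ids: int64 tensor of phone ids for one sentence (assumed prepared)
spk_id = paddle.to_tensor([0])  # an id from speaker_id_map.txt (assumed)
with paddle.no_grad():
    mel = fastspeech2_inference(phone_ids, spk_id=spk_id)  # (T, n_mels) log-mel
```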
@@ -18,7 +18,6 @@ from pathlib import Path
 
 import jsonlines
 import numpy as np
-from config import get_cfg_default
 from sklearn.preprocessing import StandardScaler
 from tqdm import tqdm

@@ -35,8 +34,7 @@ def main():
         "--field-name",
         type=str,
         help="name of the field to compute statistics for.")
-    parser.add_argument(
-        "--config", type=str, help="yaml format configuration file.")
     parser.add_argument(
         "--output",
         type=str,

@@ -67,11 +65,6 @@ def main():
     )
     logging.warning('Skip DEBUG/INFO messages')
 
-    config = get_cfg_default()
-    # load config
-    if args.config:
-        config.merge_from_file(args.config)
-
     # check directory existence
     if args.output is None:
         args.output = Path(

@@ -95,7 +88,6 @@ def main():
             scaler.partial_fit(datum[args.field_name])
 
     stats = np.stack([scaler.mean_, scaler.scale_], axis=0)
-
     np.save(str(args.output), stats.astype(np.float32), allow_pickle=False)
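For context, a small sketch (file name and shapes are assumptions, not from the commit) of what the dumped statistics contain and how they are typically applied: row 0 holds the per-dimension mean and row 1 the scale (standard deviation) from `StandardScaler`:

```python
import numpy as np

stats = np.load("speech_stats.npy")  # assumed name; shape (2, feat_dim)
mean, scale = stats[0], stats[1]
feats = np.random.rand(100, stats.shape[1]).astype(np.float32)  # fake (T, feat_dim)
normalized = (feats - mean) / scale  # z-score normalization
```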
@@ -20,11 +20,9 @@ from pathlib import Path
 
 import jsonlines
 import numpy as np
+from parakeet.datasets.data_table import DataTable
 from sklearn.preprocessing import StandardScaler
 from tqdm import tqdm
-from parakeet.datasets.data_table import DataTable
-
-from config import get_cfg_default
 
 
 def main():

@@ -62,7 +60,10 @@ def main():
         default="phone_id_map.txt ",
         help="phone vocabulary file.")
     parser.add_argument(
-        "--config", type=str, help="yaml format configuration file.")
+        "--speaker-dict",
+        type=str,
+        default="speaker_id_map.txt ",
+        help="speaker id map file.")
     parser.add_argument(
         "--verbose",
         type=int,

@@ -88,11 +89,6 @@ def main():
     )
     logging.warning('Skip DEBUG/INFO messages')
 
-    # load config
-    config = get_cfg_default()
-    if args.config:
-        config.merge_from_file(args.config)
-
     # check directory existence
     dumpdir = Path(args.dumpdir).resolve()
     dumpdir.mkdir(parents=True, exist_ok=True)

@@ -131,6 +127,12 @@ def main():
     for phn, id in phn_id:
         vocab_phones[phn] = int(id)
 
+    vocab_speaker = {}
+    with open(args.speaker_dict, 'rt') as f:
+        spk_id = [line.strip().split() for line in f.readlines()]
+    for spk, id in spk_id:
+        vocab_speaker[spk] = int(id)
+
     # process each file
     output_metadata = []

@@ -158,8 +160,10 @@ def main():
         energy_path = energy_dir / f"{utt_id}_energy.npy"
         np.save(energy_path, energy.astype(np.float32), allow_pickle=False)
         phone_ids = [vocab_phones[p] for p in item['phones']]
+        spk_id = vocab_speaker[item["speaker"]]
         record = {
             "utt_id": item['utt_id'],
+            "spk_id": spk_id,
             "text": phone_ids,
             "text_lengths": item['text_lengths'],
             "speech_lengths": item['speech_lengths'],
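To make the new field concrete, an illustrative `metadata.jsonl` record after this change; all ids, lengths, and paths are made up:

```python
# One dumped record; "spk_id" is the integer looked up in speaker_id_map.txt.
record = {
    "utt_id": "SSB00050001",  # hypothetical AISHELL-3 utterance
    "spk_id": 4,
    "text": [31, 55, 12, 7],  # phone ids
    "text_lengths": 4,
    "speech_lengths": 210,
    "speech": "/abs/path/SSB00050001_speech.npy",
    "pitch": "/abs/path/SSB00050001_pitch.npy",
    "energy": "/abs/path/SSB00050001_energy.npy",
}
```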
@@ -13,6 +13,7 @@
 # limitations under the License.
 
 import argparse
+import os
 from concurrent.futures import ThreadPoolExecutor
 from operator import itemgetter
 from pathlib import Path

@@ -21,12 +22,13 @@ from typing import List, Dict, Any
 import jsonlines
 import librosa
 import numpy as np
-from parakeet.data.get_feats import LogMelFBank, Energy, Pitch
 import tqdm
-from config import get_cfg_default
+import yaml
+from parakeet.data.get_feats import LogMelFBank, Energy, Pitch
+from yacs.config import CfgNode as Configuration
 
 
+# one line of durations.txt: utt_id|speaker|phn dur phn dur ...
 def get_phn_dur(file_name):
     '''
     read MFA duration.txt
@@ -41,16 +43,20 @@ def get_phn_dur(file_name):
     '''
     f = open(file_name, 'r')
     sentence = {}
+    speaker_set = set()
     for line in f:
-        utt = line.strip().split('|')[0]
-        p_d = line.strip().split('|')[-1]
+        line_list = line.strip().split('|')
+        utt = line_list[0]
+        speaker = line_list[1]
+        p_d = line_list[-1]
+        speaker_set.add(speaker)
         phn_dur = p_d.split()
         phn = phn_dur[::2]
         dur = phn_dur[1::2]
         assert len(phn) == len(dur)
-        sentence[utt] = (phn, [int(i) for i in dur])
+        sentence[utt] = (phn, [int(i) for i in dur], speaker)
     f.close()
-    return sentence
+    return sentence, speaker_set
 
 
 def deal_silence(sentence):

@@ -59,10 +65,10 @@ def deal_silence(sentence):
     Parameters
     ----------
     sentence : Dict
-        sentence: {'utt': ([char], [int])}
+        sentence: {'utt': ([char], [int], str)}
     '''
     for utt in sentence:
-        cur_phn, cur_dur = sentence[utt]
+        cur_phn, cur_dur, speaker = sentence[utt]
         new_phn = []
         new_dur = []

@@ -83,7 +89,7 @@ def deal_silence(sentence):
                 new_phn[i] = 'spl'
 
         assert len(new_phn) == len(new_dur)
-        sentence[utt] = [new_phn, new_dur]
+        sentence[utt] = [new_phn, new_dur, speaker]
 
 
 def get_input_token(sentence, output_path):
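A worked example of the new `durations.txt` line format read by `get_phn_dur` above; the ids and durations are invented:

```python
# utt_id|speaker|phn dur phn dur ...
line = "SSB00050001|SSB0005|sil 10 g 3 uang3 12 sil 8"
line_list = line.strip().split('|')
utt, speaker, p_d = line_list[0], line_list[1], line_list[-1]
phn_dur = p_d.split()
phn = phn_dur[::2]                     # ['sil', 'g', 'uang3', 'sil']
dur = [int(i) for i in phn_dur[1::2]]  # [10, 3, 12, 8]
```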
@@ -106,10 +112,16 @@ def get_input_token(sentence, output_path):
     phn_token = ["<pad>", "<unk>"] + phn_token
     phn_token += [",", "。", "?", "!", "<eos>"]
 
-    f = open(output_path, 'w')
-    for i, phn in enumerate(phn_token):
-        f.write(phn + ' ' + str(i) + '\n')
-    f.close()
+    with open(output_path, 'w') as f:
+        for i, phn in enumerate(phn_token):
+            f.write(phn + ' ' + str(i) + '\n')
+
+
+def get_spk_id_map(speaker_set, output_path):
+    speakers = sorted(list(speaker_set))
+    with open(output_path, 'w') as f:
+        for i, spk in enumerate(speakers):
+            f.write(spk + ' ' + str(i) + '\n')
 
 
 def compare_duration_and_mel_length(sentences, utt, mel):
@@ -152,11 +164,14 @@ def process_sentence(config: Dict[str, Any],
     if utt_id in sentences:
         # reading, resampling may occur
         wav, _ = librosa.load(str(fp), sr=config.fs)
+        if len(wav.shape) != 1 or np.abs(wav).max() > 1.0:
+            return record
         assert len(wav.shape) == 1, f"{utt_id} is not a mono-channel audio."
         assert np.abs(wav).max(
         ) <= 1.0, f"{utt_id} is seems to be different that 16 bit PCM."
         phones = sentences[utt_id][0]
         durations = sentences[utt_id][1]
+        speaker = sentences[utt_id][2]
         d_cumsum = np.pad(np.array(durations).cumsum(0), (1, 0), 'constant')
         # little imprecise than use *.TextGrid directly
         times = librosa.frames_to_time(

@@ -210,7 +225,8 @@ def process_sentence(config: Dict[str, Any],
             # use absolute path
             "speech": str(mel_path.resolve()),
             "pitch": str(f0_path.resolve()),
-            "energy": str(energy_path.resolve())
+            "energy": str(energy_path.resolve()),
+            "speaker": speaker
         }
         return record
@@ -261,20 +277,34 @@ def main():
     # parse config and args
     parser = argparse.ArgumentParser(
         description="Preprocess audio and then extract features.")
     parser.add_argument(
-        "--rootdir", default=None, type=str, help="directory to baker dataset.")
+        "--dataset",
+        default="baker",
+        type=str,
+        help="name of dataset, should be in {baker, aishell3} now")
+    parser.add_argument(
+        "--rootdir", default=None, type=str, help="directory to dataset.")
     parser.add_argument(
         "--dur-file",
         default=None,
         type=str,
         help="path to baker durations.txt.")
     parser.add_argument(
         "--dumpdir",
         type=str,
         required=True,
         help="directory to dump feature files.")
     parser.add_argument(
-        "--config", type=str, help="yaml format configuration file.")
+        "--config-path",
+        default="conf/default.yaml",
+        type=str,
+        help="yaml format configuration file.")
     parser.add_argument(
         "--verbose",
         type=int,

@@ -291,17 +321,10 @@ def main():
         type=str2bool,
         default=True,
         help="whether cut sil in the edge of audio")
     args = parser.parse_args()
 
-    C = get_cfg_default()
-    if args.config:
-        C.merge_from_file(args.config)
-    C.freeze()
-
-    if args.verbose > 1:
-        print(vars(args))
-        print(C)
-
+    config_path = Path(args.config_path).resolve()
     root_dir = Path(args.rootdir).expanduser()
     dumpdir = Path(args.dumpdir).expanduser()
     dumpdir.mkdir(parents=True, exist_ok=True)
@@ -310,20 +333,45 @@ def main():
     assert root_dir.is_dir()
     assert dur_file.is_file()
 
-    sentences = get_phn_dur(dur_file)
+    with open(config_path, 'rt') as f:
+        _C = yaml.safe_load(f)
+        _C = Configuration(_C)
+        config = _C.clone()
+
+    if args.verbose > 1:
+        print(vars(args))
+        print(config)
+
+    sentences, speaker_set = get_phn_dur(dur_file)
+
     deal_silence(sentences)
     phone_id_map_path = dumpdir / "phone_id_map.txt"
+    speaker_id_map_path = dumpdir / "speaker_id_map.txt"
     get_input_token(sentences, phone_id_map_path)
-    wav_files = sorted(list((root_dir / "Wave").rglob("*.wav")))
-    # split data into 3 sections
-    num_train = 9800
-    num_dev = 100
-    train_wav_files = wav_files[:num_train]
-    dev_wav_files = wav_files[num_train:num_train + num_dev]
-    test_wav_files = wav_files[num_train + num_dev:]
+    get_spk_id_map(speaker_set, speaker_id_map_path)
+
+    if args.dataset == "baker":
+        wav_files = sorted(list((root_dir / "Wave").rglob("*.wav")))
+        # split data into 3 sections
+        num_train = 9800
+        num_dev = 100
+        train_wav_files = wav_files[:num_train]
+        dev_wav_files = wav_files[num_train:num_train + num_dev]
+        test_wav_files = wav_files[num_train + num_dev:]
+    elif args.dataset == "aishell3":
+        sub_num_dev = 5
+        wav_dir = root_dir / "train" / "wav"
+        train_wav_files = []
+        dev_wav_files = []
+        test_wav_files = []
+        for speaker in os.listdir(wav_dir):
+            wav_files = sorted(list((wav_dir / speaker).rglob("*.wav")))
+            if len(wav_files) > 100:
+                train_wav_files += wav_files[:-sub_num_dev * 2]
+                dev_wav_files += wav_files[-sub_num_dev * 2:-sub_num_dev]
+                test_wav_files += wav_files[-sub_num_dev:]
+            else:
+                train_wav_files += wav_files
 
     train_dump_dir = dumpdir / "train" / "raw"
     train_dump_dir.mkdir(parents=True, exist_ok=True)
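A worked example of the per-speaker AISHELL-3 split above, with `sub_num_dev = 5`; the file names are fabricated:

```python
# A speaker with 120 utterances contributes 110/5/5 to train/dev/test;
# a speaker with <= 100 utterances goes entirely to train.
wav_files = [f"SSB0005W{i:04d}.wav" for i in range(120)]
sub_num_dev = 5
train = wav_files[:-sub_num_dev * 2]            # 110 files
dev = wav_files[-sub_num_dev * 2:-sub_num_dev]  # 5 files
test = wav_files[-sub_num_dev:]                 # 5 files
```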
@@ -334,55 +382,59 @@ def main():
 
     # Extractor
     mel_extractor = LogMelFBank(
-        sr=C.fs,
-        n_fft=C.n_fft,
-        hop_length=C.n_shift,
-        win_length=C.win_length,
-        window=C.window,
-        n_mels=C.n_mels,
-        fmin=C.fmin,
-        fmax=C.fmax)
+        sr=config.fs,
+        n_fft=config.n_fft,
+        hop_length=config.n_shift,
+        win_length=config.win_length,
+        window=config.window,
+        n_mels=config.n_mels,
+        fmin=config.fmin,
+        fmax=config.fmax)
     pitch_extractor = Pitch(
-        sr=C.fs, hop_length=C.n_shift, f0min=C.f0min, f0max=C.f0max)
+        sr=config.fs,
+        hop_length=config.n_shift,
+        f0min=config.f0min,
+        f0max=config.f0max)
     energy_extractor = Energy(
-        sr=C.fs,
-        n_fft=C.n_fft,
-        hop_length=C.n_shift,
-        win_length=C.win_length,
-        window=C.window)
+        sr=config.fs,
+        n_fft=config.n_fft,
+        hop_length=config.n_shift,
+        win_length=config.win_length,
+        window=config.window)
 
     # process for the 3 sections
-    process_sentences(
-        C,
-        train_wav_files,
-        sentences,
-        train_dump_dir,
-        mel_extractor,
-        pitch_extractor,
-        energy_extractor,
-        nprocs=args.num_cpu,
-        cut_sil=args.cut_sil)
-    process_sentences(
-        C,
-        dev_wav_files,
-        sentences,
-        dev_dump_dir,
-        mel_extractor,
-        pitch_extractor,
-        energy_extractor,
-        cut_sil=args.cut_sil)
-    process_sentences(
-        C,
-        test_wav_files,
-        sentences,
-        test_dump_dir,
-        mel_extractor,
-        pitch_extractor,
-        energy_extractor,
-        nprocs=args.num_cpu,
-        cut_sil=args.cut_sil)
+    if train_wav_files:
+        process_sentences(
+            config,
+            train_wav_files,
+            sentences,
+            train_dump_dir,
+            mel_extractor,
+            pitch_extractor,
+            energy_extractor,
+            nprocs=args.num_cpu,
+            cut_sil=args.cut_sil)
+    if dev_wav_files:
+        process_sentences(
+            config,
+            dev_wav_files,
+            sentences,
+            dev_dump_dir,
+            mel_extractor,
+            pitch_extractor,
+            energy_extractor,
+            cut_sil=args.cut_sil)
+    if test_wav_files:
+        process_sentences(
+            config,
+            test_wav_files,
+            sentences,
+            test_dump_dir,
+            mel_extractor,
+            pitch_extractor,
+            energy_extractor,
+            nprocs=args.num_cpu,
+            cut_sil=args.cut_sil)
 
 
 if __name__ == "__main__":
@@ -17,11 +17,10 @@ from pathlib import Path
 
 import librosa
 import numpy as np
-from config import get_cfg_default
 from praatio import tgio
 
 
-def readtg(config, tg_path):
+def readtg(tg_path, sample_rate=24000, n_shift=300):
     alignment = tgio.openTextgrid(tg_path, readRaw=True)
     phones = []
     ends = []
@@ -29,40 +28,55 @@ def readtg(tg_path, sample_rate=24000, n_shift=300):
         phone = interval.label
         phones.append(phone)
         ends.append(interval.end)
-    frame_pos = librosa.time_to_frames(
-        ends, sr=config.fs, hop_length=config.n_shift)
+    frame_pos = librosa.time_to_frames(ends, sr=sample_rate, hop_length=n_shift)
     durations = np.diff(frame_pos, prepend=0)
     assert len(durations) == len(phones)
     # merge "" and sp in the end
-    if phones[-1] == "":
+    if phones[-1] == "" and len(phones) > 1 and phones[-2] == "sp":
        phones = phones[:-1]
         durations[-2] += durations[-1]
         durations = durations[:-1]
-    # replace the last sp with sil
+    # replace the last "sp" with "sil" in MFA1.x
     phones[-1] = "sil" if phones[-1] == "sp" else phones[-1]
+    # replace the edge "" with "sil", replace the inner "" with "sp"
+    new_phones = []
+    for i, phn in enumerate(phones):
+        if phn == "":
+            if i in {0, len(phones) - 1}:
+                new_phones.append("sil")
+            else:
+                new_phones.append("sp")
+        else:
+            new_phones.append(phn)
+    phones = new_phones
     results = ""
     for (p, d) in zip(phones, durations):
         results += p + " " + str(d) + " "
     return results.strip()
 
 
 # assume that the directory structure of inputdir is inputdir/speaker/*.TextGrid
-# in MFA1.x, there are blank labels("") in the end, we replace it with "sil"
-def gen_duration_from_textgrid(config, inputdir, output):
+# in MFA1.x, there are blank labels("") in the end, and maybe "sp" before them
+# in MFA2.x, there are blank labels("") at the beginning and the end, with no "sp" and "sil" anymore
+# we replace them with "sil"
+def gen_duration_from_textgrid(inputdir, output, sample_rate=24000,
+                               n_shift=300):
+    # key: utt_id, value: (speaker, phn_durs)
     durations_dict = {}
-    for speaker in os.listdir(inputdir):
+    list_dir = os.listdir(inputdir)
+    speakers = [dir for dir in list_dir if os.path.isdir(inputdir / dir)]
+    for speaker in speakers:
         subdir = inputdir / speaker
         for file in os.listdir(subdir):
             if file.endswith(".TextGrid"):
                 tg_path = subdir / file
                 name = file.split(".")[0]
-                durations_dict[name] = readtg(config, tg_path)
+                durations_dict[name] = (speaker, readtg(
+                    tg_path, sample_rate=sample_rate, n_shift=n_shift))
     with open(output, "w") as wf:
         for name in sorted(durations_dict.keys()):
-            wf.write(name + "|" + durations_dict[name] + "\n")
+            wf.write(name + "|" + durations_dict[name][0] + "|" +
+                     durations_dict[name][1] + "\n")
 
 
 def main():
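Tracing the label clean-up above on a made-up MFA2.x phone sequence shows edge blanks becoming "sil" and inner blanks becoming "sp":

```python
phones = ["", "g", "uang3", "", "zh", "ou1", ""]
new_phones = []
for i, phn in enumerate(phones):
    if phn == "":
        new_phones.append("sil" if i in {0, len(phones) - 1} else "sp")
    else:
        new_phones.append(phn)
# new_phones == ['sil', 'g', 'uang3', 'sp', 'zh', 'ou1', 'sil']
```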
@@ -75,19 +89,18 @@ def main():
         type=str,
         help="directory to alignment files.")
     parser.add_argument(
-        "--output", type=str, required=True, help="output duration file name")
+        "--output", type=str, required=True, help="output duration file.")
+    parser.add_argument("--sample-rate", type=int, help="the sample rate of wavs.")
     parser.add_argument(
-        "--config", type=str, help="yaml format configuration file.")
+        "--n-shift",
+        type=int,
+        help="the n_shift of time_to_frames, also called hop_length.")
 
     args = parser.parse_args()
-    C = get_cfg_default()
-    if args.config:
-        C.merge_from_file(args.config)
-    C.freeze()
-
     inputdir = Path(args.inputdir).expanduser()
     output = Path(args.output).expanduser()
-    gen_duration_from_textgrid(C, inputdir, output)
+    gen_duration_from_textgrid(inputdir, output, args.sample_rate, args.n_shift)
 
 
 if __name__ == "__main__":
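For reference, a hedged sketch (paths and values assumed) of calling the refactored helper directly, now that it no longer needs a config object:

```python
from pathlib import Path

# 24000 / 300 mirror the sample_rate / n_shift defaults; adjust to your config.
gen_duration_from_textgrid(
    Path("./aishell3_alignment_tone").expanduser(),
    Path("./durations.txt").expanduser(),
    sample_rate=24000,
    n_shift=300)
```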