add fastspeech2 example inference

TianYuan 2021-07-22 10:31:34 +00:00
parent 47ec051136
commit 3d39385d5e
13 changed files with 490 additions and 17 deletions


@@ -0,0 +1,42 @@
# FastSpeech2 with BZNSYP
------
## Dataset
-----
### Download and Extract the dataset.
Download BZNSYP from its [Official Website](https://test.data-baker.com/data/index/source).
### Get the MFA result of BZNSYP and Extract it.
We use [MFA](https://github.com/MontrealCorpusTools/Montreal-Forced-Aligner) to get phoneme durations for FastSpeech2.
You can download the alignment results from here, or train your own MFA model following the [use_mfa example](https://github.com/PaddlePaddle/Parakeet/tree/develop/examples/use_mfa) in our repo.
### Preprocess the dataset.
Assume the path to the dataset is `~/datasets/BZNSYP`.
Assume the path to the MFA result of BZNSYP is `./baker_alignment_tone`.
Run the command below to preprocess the dataset.
```bash
./preprocess.sh
```
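After preprocessing, the `dump` directory should look roughly as below (a sketch inferred from `preprocess.sh` in this example; `dev` and `test` are normalized with the training set's stats):
```text
dump
├── phone_id_map.txt
├── train
│   ├── raw/metadata.jsonl
│   ├── norm/metadata.jsonl
│   ├── speech_stats.npy
│   ├── pitch_stats.npy
│   └── energy_stats.npy
├── dev
│   ├── raw/metadata.jsonl
│   └── norm/metadata.jsonl
└── test
    ├── raw/metadata.jsonl
    └── norm/metadata.jsonl
```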
## Train the model
---
```bash
./run.sh
```
## Synthesize
---
We use [parallel wavegan](https://github.com/PaddlePaddle/Parakeet/tree/develop/examples/parallelwave_gan/baker) as the neural vocoder.
`synthesize.sh` synthesizes waveforms from the features in `metadata.jsonl`.
`synthesize_e2e.sh` synthesizes waveforms from a text list.
```bash
./synthesize.sh
```
or
```bash
./synthesize_e2e.sh
```
See the bash files for more details on the input parameters; the expected format of the text list is shown below.
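Each line of the text list is a `utt_id sentence` pair, as in the `sentences.txt` added by this commit, e.g.
```text
001 凯莫瑞安联合体的经济崩溃,迫在眉睫。
002 对于所有想要离开那片废土,去寻找更美好生活的人来说。
```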
## Pretrained Model


@@ -89,7 +89,7 @@ updater:
 ###########################################################
 optimizer:
   optim: adam                # optimizer type
-  learning_rate: 0.0001      # learning rate
+  learning_rate: 0.001       # learning rate
 ###########################################################
 #                     TRAINING SETTING                    #


@@ -0,0 +1,76 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import re

import numpy as np
import paddle
from pypinyin import lazy_pinyin, Style
import jieba


class Frontend():
    def __init__(self, vocab_path):
        # load the phone vocabulary ("<phone> <id>" per line)
        self.voc_phones = {}
        with open(vocab_path, 'rt') as f:
            phn_id = [line.strip().split() for line in f.readlines()]
        for phn, id in phn_id:
            self.voc_phones[phn] = int(id)

    def segment(self, sentence):
        # split the sentence into segments at punctuation marks
        segments = re.split(r'[:,;。?!]', sentence)
        segments = [seg for seg in segments if len(seg)]
        return segments

    def g2p(self, sentence):
        segments = self.segment(sentence)
        phones = []
        for seg in segments:
            seg = jieba.lcut(seg)
            initials = lazy_pinyin(
                seg, neutral_tone_with_five=True, style=Style.INITIALS)
            finals = lazy_pinyin(
                seg, neutral_tone_with_five=True, style=Style.FINALS_TONE3)
            for c, v in zip(initials, finals):
                # NOTE: post process for pypinyin outputs
                # we discriminate i, ii and iii
                if re.match(r'i\d', v):
                    if c in ['z', 'c', 's']:
                        v = re.sub('i', 'ii', v)
                    elif c in ['zh', 'ch', 'sh', 'r']:
                        v = re.sub('i', 'iii', v)
                if c:
                    phones.append(c)
                if v:
                    phones.append(v)
            # add sp between segments
            phones.append('sp')
        # replace the last sp with <eos>
        phones[-1] = '<eos>'
        return phones

    def p2id(self, phonemes):
        # replace unknown phones with sp
        phonemes = [
            phn if phn in self.voc_phones else "sp" for phn in phonemes
        ]
        phone_ids = [self.voc_phones[item] for item in phonemes]
        return np.array(phone_ids, np.int64)

    def text_analysis(self, sentence):
        phonemes = self.g2p(sentence)
        phone_ids = self.p2id(phonemes)
        phone_ids = paddle.to_tensor(phone_ids)
        return phone_ids
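A minimal usage sketch of this frontend (assuming preprocessing has already produced `dump/phone_id_map.txt`; the sentence is the first line of `sentences.txt` below):
```python
from frontend import Frontend

frontend = Frontend("dump/phone_id_map.txt")
# g2p returns a phone list such as ['k', 'ai3', ...] ending with '<eos>'
phones = frontend.g2p("凯莫瑞安联合体的经济崩溃,迫在眉睫。")
# text_analysis additionally maps the phones to ids and wraps them
# in a 1-D int64 paddle tensor, ready for FastSpeech2Inference
phone_ids = frontend.text_analysis("凯莫瑞安联合体的经济崩溃,迫在眉睫。")
```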


@@ -45,17 +45,17 @@ def main():
         required=True,
         help="directory to dump normalized feature files.")
     parser.add_argument(
-        "--speech_stats",
+        "--speech-stats",
         type=str,
         required=True,
         help="speech statistics file.")
     parser.add_argument(
-        "--pitch_stats",
+        "--pitch-stats",
         type=str,
         required=True,
         help="pitch statistics file.")
     parser.add_argument(
-        "--energy_stats",
+        "--energy-stats",
         type=str,
         required=True,
         help="energy statistics file.")


@@ -258,7 +258,7 @@ def main():
         type=str,
         help="directory to baker dataset.")
     parser.add_argument(
-        "--dur_path",
+        "--dur-path",
         default=None,
         type=str,
         help="path to baker durations.txt.")
@@ -275,7 +275,7 @@ def main():
         default=1,
         help="logging level. higher is more logging. (default=1)")
     parser.add_argument(
-        "--num_cpu", type=int, default=1, help="number of process.")
+        "--num-cpu", type=int, default=1, help="number of processes.")
     args = parser.parse_args()
     C = get_cfg_default()


@@ -4,7 +4,7 @@
 python3 gen_duration_from_textgrid.py --inputdir ./baker_alignment_tone --output durations.txt
 # extract features
-python3 preprocess.py --rootdir=~/datasets/BZNSYP/ --dumpdir=dump --dur_path durations.txt --num_cpu 16
+python3 preprocess.py --rootdir=~/datasets/BZNSYP/ --dumpdir=dump --dur-path durations.txt --num-cpu 4
 # get features' stats (mean and std)
 python3 compute_statistics.py --metadata=dump/train/raw/metadata.jsonl --field-name="speech"
@@ -12,7 +12,7 @@ python3 compute_statistics.py --metadata=dump/train/raw/metadata.jsonl --field-n
 python3 compute_statistics.py --metadata=dump/train/raw/metadata.jsonl --field-name="energy"
 # normalize and convert phone to id, dev and test should use train's stats
-python3 normalize.py --metadata=dump/train/raw/metadata.jsonl --dumpdir=dump/train/norm --speech_stats=dump/train/speech_stats.npy --pitch_stats=dump/train/pitch_stats.npy --energy_stats=dump/train/energy_stats.npy --phones dump/phone_id_map.txt
-python3 normalize.py --metadata=dump/dev/raw/metadata.jsonl --dumpdir=dump/dev/norm --speech_stats=dump/train/speech_stats.npy --pitch_stats=dump/train/pitch_stats.npy --energy_stats=dump/train/energy_stats.npy --phones dump/phone_id_map.txt
-python3 normalize.py --metadata=dump/test/raw/metadata.jsonl --dumpdir=dump/test/norm --speech_stats=dump/train/speech_stats.npy --pitch_stats=dump/train/pitch_stats.npy --energy_stats=dump/train/energy_stats.npy --phones dump/phone_id_map.txt
+python3 normalize.py --metadata=dump/train/raw/metadata.jsonl --dumpdir=dump/train/norm --speech-stats=dump/train/speech_stats.npy --pitch-stats=dump/train/pitch_stats.npy --energy-stats=dump/train/energy_stats.npy --phones dump/phone_id_map.txt
+python3 normalize.py --metadata=dump/dev/raw/metadata.jsonl --dumpdir=dump/dev/norm --speech-stats=dump/train/speech_stats.npy --pitch-stats=dump/train/pitch_stats.npy --energy-stats=dump/train/energy_stats.npy --phones dump/phone_id_map.txt
+python3 normalize.py --metadata=dump/test/raw/metadata.jsonl --dumpdir=dump/test/norm --speech-stats=dump/train/speech_stats.npy --pitch-stats=dump/train/pitch_stats.npy --energy-stats=dump/train/energy_stats.npy --phones dump/phone_id_map.txt
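Each `*_stats.npy` produced by `compute_statistics.py` stores the feature mean and standard deviation together; a sketch of how downstream scripts unpack it (mirroring `synthesize.py` in this commit):
```python
import numpy as np
import paddle
from parakeet.modules.normalizer import ZScore

# one stats file holds both mean and std, unpacked in a single assignment
mu, std = np.load("dump/train/speech_stats.npy")
normalizer = ZScore(paddle.to_tensor(mu), paddle.to_tensor(std))
```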

examples/fastspeech2/baker/run.sh Normal file → Executable file


@@ -0,0 +1,16 @@
001 凯莫瑞安联合体的经济崩溃,迫在眉睫。
002 对于所有想要离开那片废土,去寻找更美好生活的人来说。
003 克哈,是你们所有人安全的港湾。
004 为了保护尤摩扬人民不受异虫的残害,我所做的,比他们自己的领导委员会都多。
005 无论他们如何诽谤我,我将继续为所有泰伦人的最大利益,而努力奋斗。
006 身为你们的元首,我带领泰伦人实现了人类统治领地和经济的扩张。
007 我们将继续成长,用行动回击那些只会说风凉话,不愿意和我们相向而行的害群之马。
008 帝国武装力量,无数的优秀儿女,正时刻守卫着我们的家园大门,但是他们孤木难支。
009 凡是今天应征入伍者,所获的所有刑罚罪责,减半。
010 激进分子和异见者希望你们一听见枪声,就背弃多年的和平与繁荣。
011 他们没有勇气和能力,带领人类穿越一个充满危险的星系。
012 法治是我们的命脉,然而它却受到前所未有的挑战。
013 我将恢复我们帝国的荣光,绝不会向任何外星势力低头。
014 我已经驯服了异虫,荡平了星灵。如今它们的创造者,想要夺走我们拥有的一切。
015 永远记住,谁才是最能保护你们的人。
016 不要听信别人的谗言,我不是什么克隆人。


@@ -0,0 +1,148 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
from pathlib import Path

import jsonlines
import numpy as np
import paddle
import soundfile as sf
import yaml
from yacs.config import CfgNode

from parakeet.datasets.data_table import DataTable
from parakeet.models.fastspeech2 import FastSpeech2, FastSpeech2Inference
from parakeet.models.parallel_wavegan import PWGGenerator, PWGInference
from parakeet.modules.normalizer import ZScore


def evaluate(args, fastspeech2_config, pwg_config):
    # dataloader has been too verbose
    logging.getLogger("DataLoader").disabled = True

    # construct dataset for evaluation
    with jsonlines.open(args.test_metadata, 'r') as reader:
        test_metadata = list(reader)
    test_dataset = DataTable(data=test_metadata, fields=["utt_id", "text"])

    with open(args.phones, "r") as f:
        phn_id = [line.strip().split() for line in f.readlines()]
    vocab_size = len(phn_id)
    print("vocab_size:", vocab_size)

    # acoustic model
    odim = fastspeech2_config.n_mels
    model = FastSpeech2(
        idim=vocab_size, odim=odim, **fastspeech2_config["model"])
    model.set_state_dict(
        paddle.load(args.fastspeech2_checkpoint)["main_params"])
    model.eval()

    # vocoder
    vocoder = PWGGenerator(**pwg_config["generator_params"])
    vocoder.set_state_dict(paddle.load(args.pwg_params))
    vocoder.remove_weight_norm()
    vocoder.eval()
    print("model done!")

    # normalizers built from the training-set statistics
    stat = np.load(args.fastspeech2_stat)
    mu, std = stat
    mu = paddle.to_tensor(mu)
    std = paddle.to_tensor(std)
    fastspeech2_normalizer = ZScore(mu, std)

    stat = np.load(args.pwg_stat)
    mu, std = stat
    mu = paddle.to_tensor(mu)
    std = paddle.to_tensor(std)
    pwg_normalizer = ZScore(mu, std)

    fastspeech2_inference = FastSpeech2Inference(fastspeech2_normalizer, model)
    pwg_inference = PWGInference(pwg_normalizer, vocoder)

    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    for datum in test_dataset:
        utt_id = datum["utt_id"]
        text = paddle.to_tensor(datum["text"])
        with paddle.no_grad():
            wav = pwg_inference(fastspeech2_inference(text))
        sf.write(
            str(output_dir / (utt_id + ".wav")),
            wav.numpy(),
            samplerate=fastspeech2_config.fs)
        print(f"{utt_id} done!")


def main():
    # parse args and config
    parser = argparse.ArgumentParser(
        description="Synthesize with fastspeech2 & parallel wavegan.")
    parser.add_argument(
        "--fastspeech2-config",
        type=str,
        help="config file to overwrite default config")
    parser.add_argument(
        "--fastspeech2-checkpoint",
        type=str,
        help="fastspeech2 checkpoint to load.")
    parser.add_argument(
        "--fastspeech2-stat",
        type=str,
        help="mean and standard deviation used to normalize spectrogram when training fastspeech2."
    )
    parser.add_argument(
        "--pwg-config", type=str, help="config file of parallel wavegan.")
    parser.add_argument(
        "--pwg-params",
        type=str,
        help="parallel wavegan generator parameters to load.")
    parser.add_argument(
        "--pwg-stat",
        type=str,
        help="mean and standard deviation used to normalize spectrogram when training parallel wavegan."
    )
    parser.add_argument(
        "--phones",
        type=str,
        default="phone_id_map.txt",
        help="phone vocabulary file.")
    parser.add_argument("--test-metadata", type=str, help="test metadata")
    parser.add_argument("--output-dir", type=str, help="output dir")
    parser.add_argument(
        "--device", type=str, default="gpu", help="device type to use")
    parser.add_argument("--verbose", type=int, default=1, help="verbose")

    args = parser.parse_args()
    # select the device before building models so weights load onto it
    paddle.set_device(args.device)

    with open(args.fastspeech2_config) as f:
        fastspeech2_config = CfgNode(yaml.safe_load(f))
    with open(args.pwg_config) as f:
        pwg_config = CfgNode(yaml.safe_load(f))
    print("========Args========")
    print(yaml.safe_dump(vars(args)))
    print("========Config========")
    print(fastspeech2_config)
    print(pwg_config)

    evaluate(args, fastspeech2_config, pwg_config)


if __name__ == "__main__":
    main()


@@ -0,0 +1,14 @@
#!/bin/bash
python3 synthesize.py \
--fastspeech2-config=conf/default.yaml \
--fastspeech2-checkpoint=exp/default/checkpoints/snapshot_iter_62577.pdz \
--fastspeech2-stat=dump/train/speech_stats.npy \
--pwg-config=pwg_default.yaml \
--pwg-params=pwg_generator.pdparams \
--pwg-stat=pwg_stats.npy \
--test-metadata=dump/test/norm/metadata.jsonl \
--output-dir=exp/debug/test \
--device="gpu" \
--phones=dump/phone_id_map.txt


@@ -0,0 +1,154 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
from pathlib import Path

import numpy as np
import paddle
import soundfile as sf
import yaml
from yacs.config import CfgNode

from parakeet.models.fastspeech2 import FastSpeech2, FastSpeech2Inference
from parakeet.models.parallel_wavegan import PWGGenerator, PWGInference
from parakeet.modules.normalizer import ZScore

from frontend import Frontend


def evaluate(args, fastspeech2_config, pwg_config):
    # dataloader has been too verbose
    logging.getLogger("DataLoader").disabled = True

    # construct dataset for evaluation
    sentences = []
    with open(args.text, 'rt') as f:
        for line in f:
            utt_id, sentence = line.strip().split()
            sentences.append((utt_id, sentence))

    with open(args.phones, "r") as f:
        phn_id = [line.strip().split() for line in f.readlines()]
    vocab_size = len(phn_id)
    print("vocab_size:", vocab_size)

    # acoustic model
    odim = fastspeech2_config.n_mels
    model = FastSpeech2(
        idim=vocab_size, odim=odim, **fastspeech2_config["model"])
    model.set_state_dict(
        paddle.load(args.fastspeech2_checkpoint)["main_params"])
    model.eval()

    # vocoder
    vocoder = PWGGenerator(**pwg_config["generator_params"])
    vocoder.set_state_dict(paddle.load(args.pwg_params))
    vocoder.remove_weight_norm()
    vocoder.eval()
    print("model done!")

    frontend = Frontend(args.phones)
    print("frontend done!")

    # normalizers built from the training-set statistics
    stat = np.load(args.fastspeech2_stat)
    mu, std = stat
    mu = paddle.to_tensor(mu)
    std = paddle.to_tensor(std)
    fastspeech2_normalizer = ZScore(mu, std)

    stat = np.load(args.pwg_stat)
    mu, std = stat
    mu = paddle.to_tensor(mu)
    std = paddle.to_tensor(std)
    pwg_normalizer = ZScore(mu, std)

    fastspeech2_inference = FastSpeech2Inference(fastspeech2_normalizer, model)
    pwg_inference = PWGInference(pwg_normalizer, vocoder)

    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    for utt_id, sentence in sentences:
        phone_ids = frontend.text_analysis(sentence)
        with paddle.no_grad():
            wav = pwg_inference(fastspeech2_inference(phone_ids))
        sf.write(
            str(output_dir / (utt_id + ".wav")),
            wav.numpy(),
            samplerate=fastspeech2_config.fs)
        print(f"{utt_id} done!")


def main():
    # parse args and config
    parser = argparse.ArgumentParser(
        description="Synthesize with fastspeech2 & parallel wavegan.")
    parser.add_argument(
        "--fastspeech2-config",
        type=str,
        help="config file to overwrite default config")
    parser.add_argument(
        "--fastspeech2-checkpoint",
        type=str,
        help="fastspeech2 checkpoint to load.")
    parser.add_argument(
        "--fastspeech2-stat",
        type=str,
        help="mean and standard deviation used to normalize spectrogram when training fastspeech2."
    )
    parser.add_argument(
        "--pwg-config", type=str, help="config file of parallel wavegan.")
    parser.add_argument(
        "--pwg-params",
        type=str,
        help="parallel wavegan generator parameters to load.")
    parser.add_argument(
        "--pwg-stat",
        type=str,
        help="mean and standard deviation used to normalize spectrogram when training parallel wavegan."
    )
    parser.add_argument(
        "--phones",
        type=str,
        default="phone_id_map.txt",
        help="phone vocabulary file.")
    parser.add_argument(
        "--text",
        type=str,
        help="text to synthesize, a 'utt_id sentence' pair per line")
    parser.add_argument("--output-dir", type=str, help="output dir")
    parser.add_argument(
        "--device", type=str, default="gpu", help="device type to use")
    parser.add_argument("--verbose", type=int, default=1, help="verbose")

    args = parser.parse_args()
    # select the device before building models so weights load onto it
    paddle.set_device(args.device)

    with open(args.fastspeech2_config) as f:
        fastspeech2_config = CfgNode(yaml.safe_load(f))
    with open(args.pwg_config) as f:
        pwg_config = CfgNode(yaml.safe_load(f))
    print("========Args========")
    print(yaml.safe_dump(vars(args)))
    print("========Config========")
    print(fastspeech2_config)
    print(pwg_config)

    evaluate(args, fastspeech2_config, pwg_config)


if __name__ == "__main__":
    main()


@@ -0,0 +1,14 @@
#!/bin/bash
python3 synthesize_e2e.py \
--fastspeech2-config=conf/default.yaml \
--fastspeech2-checkpoint=exp/default/checkpoints/snapshot_iter_136017.pdz \
--fastspeech2-stat=dump/train/speech_stats.npy \
--pwg-config=pwg_default.yaml \
--pwg-params=pwg_generator.pdparams \
--pwg-stat=pwg_stats.npy \
--text=sentences.txt \
--output-dir=exp/debug/test_e2e \
--device="gpu" \
--phones=dump/phone_id_map.txt


@@ -397,11 +397,11 @@ class FastSpeech2(nn.Layer):
         speech : Tensor, optional
             Feature sequence to extract style (N, idim).
         durations : LongTensor, optional
-            Groundtruth of duration (T + 1,).
+            Groundtruth of duration (T,).
         pitch : Tensor, optional
-            Groundtruth of token-averaged pitch (T + 1, 1).
+            Groundtruth of token-averaged pitch (T, 1).
         energy : Tensor, optional
-            Groundtruth of token-averaged energy (T + 1, 1).
+            Groundtruth of token-averaged energy (T, 1).
         alpha : float, optional
             Alpha to control the speed.
         use_teacher_forcing : bool, optional
@@ -412,9 +412,6 @@
         ----------
         Tensor
             Output sequence of features (L, odim).
-        None
-            Dummy for compatibility.
         """
         x, y = text, speech
         d, p, e = durations, pitch, energy
@@ -455,7 +452,7 @@
             is_inference=True,
             alpha=alpha, )
-        return outs[0], None, None
+        return outs[0]

     def _source_mask(self, ilens: paddle.Tensor) -> paddle.Tensor:
         """Make masks for self-attention.
@@ -501,6 +498,18 @@
                 init_dec_alpha))


+class FastSpeech2Inference(nn.Layer):
+    def __init__(self, normalizer, model):
+        super().__init__()
+        self.normalizer = normalizer
+        self.acoustic_model = model
+
+    def forward(self, text):
+        normalized_mel = self.acoustic_model.inference(text)
+        logmel = self.normalizer.inverse(normalized_mel)
+        return logmel
+
+
 class FastSpeech2Loss(nn.Layer):
     """Loss function module for FastSpeech2."""