add e2e inference script
This commit is contained in:
parent
acc02c9b79
commit
a62eeb9b06
|
@ -0,0 +1,92 @@
|
|||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import re
|
||||
import numpy as np
|
||||
import paddle
|
||||
import pypinyin
|
||||
from pypinyin import lazy_pinyin, Style
|
||||
import jieba
|
||||
import phkit
|
||||
phkit.initialize()
|
||||
from parakeet.frontend.vocab import Vocab
|
||||
|
||||
with open("phones.txt", 'rt') as f:
|
||||
phones = [line.strip() for line in f.readlines()]
|
||||
|
||||
with open("tones.txt", 'rt') as f:
|
||||
tones = [line.strip() for line in f.readlines()]
|
||||
voc_phones = Vocab(phones, start_symbol=None, end_symbol=None)
|
||||
voc_tones = Vocab(tones, start_symbol=None, end_symbol=None)
|
||||
|
||||
|
||||
def segment(sentence):
|
||||
segments = re.split(r'[:,;。?!]', sentence)
|
||||
segments = [seg for seg in segments if len(seg)]
|
||||
return segments
|
||||
|
||||
|
||||
def g2p(sentence):
|
||||
segments = segment(sentence)
|
||||
phones = []
|
||||
phones.append('sil')
|
||||
tones = []
|
||||
tones.append('0')
|
||||
|
||||
for seg in segments:
|
||||
seg = jieba.lcut(seg)
|
||||
initials = lazy_pinyin(
|
||||
seg, neutral_tone_with_five=True, style=Style.INITIALS)
|
||||
finals = lazy_pinyin(
|
||||
seg, neutral_tone_with_five=True, style=Style.FINALS_TONE3)
|
||||
for c, v in zip(initials, finals):
|
||||
# NOTE: post process for pypinyin outputs
|
||||
# we discriminate i, ii and iii
|
||||
if re.match(r'i\d', v):
|
||||
if c in ['z', 'c', 's']:
|
||||
v = re.sub('i', 'ii', v)
|
||||
elif c in ['zh', 'ch', 'sh', 'r']:
|
||||
v = re.sub('i', 'iii', v)
|
||||
if c:
|
||||
phones.append(c)
|
||||
tones.append('0')
|
||||
if v:
|
||||
phones.append(v[:-1])
|
||||
tones.append(v[-1])
|
||||
phones.append('sp')
|
||||
tones.append('0')
|
||||
phones[-1] = 'sil'
|
||||
tones[-1] = '0'
|
||||
return (phones, tones)
|
||||
|
||||
|
||||
def p2id(voc, phonemes):
|
||||
phone_ids = [voc.lookup(item) for item in phonemes]
|
||||
return np.array(phone_ids, np.int64)
|
||||
|
||||
|
||||
def t2id(voc, tones):
|
||||
tone_ids = [voc.lookup(item) for item in tones]
|
||||
return np.array(tone_ids, np.int64)
|
||||
|
||||
|
||||
def text_analysis(sentence):
|
||||
phonemes, tones = g2p(sentence)
|
||||
print(sentence)
|
||||
print([p + t if t != '0' else p for p, t in zip(phonemes, tones)])
|
||||
phone_ids = p2id(voc_phones, phonemes)
|
||||
tone_ids = t2id(voc_tones, tones)
|
||||
phones = paddle.to_tensor(phone_ids)
|
||||
tones = paddle.to_tensor(tone_ids)
|
||||
return phones, tones
|
|
@ -0,0 +1,16 @@
|
|||
001 凯莫瑞安联合体的经济崩溃,迫在眉睫。
|
||||
002 对于所有想要离开那片废土,去寻找更美好生活的人来说。
|
||||
003 克哈,是你们所有人安全的港湾。
|
||||
004 为了保护尤摩扬人民不受异虫的残害,我所做的,比他们自己的领导委员会都多。
|
||||
005 无论他们如何诽谤我,我将继续为所有泰伦人的最大利益,而努力奋斗。
|
||||
006 身为你们的元首,我带领泰伦人实现了人类统治领地和经济的扩张。
|
||||
007 我们将继续成长,用行动回击那些只会说风凉话,不愿意和我们相向而行的害群之马。
|
||||
008 帝国武装力量,无数的优秀儿女,正时刻守卫着我们的家园大门,但是他们孤木难支。
|
||||
009 凡是今天应征入伍者,所获的所有刑罚罪责,减半。
|
||||
010 激进分子和异见者希望你们一听见枪声,就背弃多年的和平与繁荣。
|
||||
011 他们没有勇气和能力,带领人类穿越一个充满危险的星系。
|
||||
012 法治是我们的命脉,然而它却受到前所未有的挑战。
|
||||
013 我将恢复我们帝国的荣光,绝不会向任何外星势力低头。
|
||||
014 我已经驯服了异虫,荡平了星灵。如今它们的创造者,想要夺走我们拥有的一切。
|
||||
015 永远记住,谁才是最能保护你们的人。
|
||||
016 不要听信别人的谗言,我不是什么克隆人。
|
|
@ -121,7 +121,7 @@ def main():
|
|||
type=str,
|
||||
help="mean and standard deviation used to normalize spectrogram when training speedyspeech."
|
||||
)
|
||||
parser.add_argument("--test-metadata", type=str, help="training data")
|
||||
parser.add_argument("--test-metadata", type=str, help="test metadata")
|
||||
parser.add_argument("--output-dir", type=str, help="output dir")
|
||||
parser.add_argument(
|
||||
"--device", type=str, default="gpu", help="device type to use")
|
||||
|
|
|
@ -0,0 +1,150 @@
|
|||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
import sys
|
||||
import logging
|
||||
import argparse
|
||||
import dataclasses
|
||||
from pathlib import Path
|
||||
|
||||
import yaml
|
||||
import jsonlines
|
||||
import paddle
|
||||
import numpy as np
|
||||
import soundfile as sf
|
||||
import paddle
|
||||
from paddle import nn
|
||||
from paddle.nn import functional as F
|
||||
from paddle import distributed as dist
|
||||
from yacs.config import CfgNode
|
||||
|
||||
from parakeet.datasets.data_table import DataTable
|
||||
from parakeet.models.speedyspeech import SpeedySpeech, SpeedySpeechInference
|
||||
from parakeet.models.parallel_wavegan import PWGGenerator, PWGInference
|
||||
from parakeet.modules.normalizer import ZScore
|
||||
|
||||
from frontend import text_analysis
|
||||
|
||||
|
||||
def evaluate(args, speedyspeech_config, pwg_config):
|
||||
# dataloader has been too verbose
|
||||
logging.getLogger("DataLoader").disabled = True
|
||||
|
||||
# construct dataset for evaluation
|
||||
sentences = []
|
||||
with open(args.text, 'rt') as f:
|
||||
for line in f:
|
||||
utt_id, sentence = line.strip().split()
|
||||
sentences.append((utt_id, sentence))
|
||||
|
||||
model = SpeedySpeech(**speedyspeech_config["model"])
|
||||
model.set_state_dict(
|
||||
paddle.load(args.speedyspeech_checkpoint)["main_params"])
|
||||
model.eval()
|
||||
|
||||
vocoder = PWGGenerator(**pwg_config["generator_params"])
|
||||
vocoder.set_state_dict(paddle.load(args.pwg_params))
|
||||
vocoder.remove_weight_norm()
|
||||
vocoder.eval()
|
||||
print("model done!")
|
||||
|
||||
stat = np.load(args.speedyspeech_stat)
|
||||
mu, std = stat
|
||||
mu = paddle.to_tensor(mu)
|
||||
std = paddle.to_tensor(std)
|
||||
speedyspeech_normalizer = ZScore(mu, std)
|
||||
|
||||
stat = np.load(args.pwg_stat)
|
||||
mu, std = stat
|
||||
mu = paddle.to_tensor(mu)
|
||||
std = paddle.to_tensor(std)
|
||||
pwg_normalizer = ZScore(mu, std)
|
||||
|
||||
speedyspeech_inferencce = SpeedySpeechInference(speedyspeech_normalizer,
|
||||
model)
|
||||
pwg_inference = PWGInference(pwg_normalizer, vocoder)
|
||||
|
||||
output_dir = Path(args.output_dir)
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
for utt_id, sentence in sentences:
|
||||
phones, tones = text_analysis(sentence)
|
||||
|
||||
with paddle.no_grad():
|
||||
wav = pwg_inference(speedyspeech_inferencce(phones, tones))
|
||||
sf.write(
|
||||
output_dir / (utt_id + ".wav"),
|
||||
wav.numpy(),
|
||||
samplerate=speedyspeech_config.sr)
|
||||
print(f"{utt_id} done!")
|
||||
|
||||
|
||||
def main():
|
||||
# parse args and config and redirect to train_sp
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Synthesize with speedyspeech & parallel wavegan.")
|
||||
parser.add_argument(
|
||||
"--speedyspeech-config",
|
||||
type=str,
|
||||
help="config file to overwrite default config")
|
||||
parser.add_argument(
|
||||
"--speedyspeech-checkpoint",
|
||||
type=str,
|
||||
help="speedyspeech checkpoint to load.")
|
||||
parser.add_argument(
|
||||
"--speedyspeech-stat",
|
||||
type=str,
|
||||
help="mean and standard deviation used to normalize spectrogram when training speedyspeech."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pwg-config",
|
||||
type=str,
|
||||
help="mean and standard deviation used to normalize spectrogram when training speedyspeech."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pwg-params",
|
||||
type=str,
|
||||
help="parallel wavegan generator parameters to load.")
|
||||
parser.add_argument(
|
||||
"--pwg-stat",
|
||||
type=str,
|
||||
help="mean and standard deviation used to normalize spectrogram when training speedyspeech."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--text",
|
||||
type=str,
|
||||
help="text to synthesize, a 'utt_id sentence' pair per line")
|
||||
parser.add_argument("--output-dir", type=str, help="output dir")
|
||||
parser.add_argument(
|
||||
"--device", type=str, default="gpu", help="device type to use")
|
||||
parser.add_argument("--verbose", type=int, default=1, help="verbose")
|
||||
|
||||
args = parser.parse_args()
|
||||
with open(args.speedyspeech_config) as f:
|
||||
speedyspeech_config = CfgNode(yaml.safe_load(f))
|
||||
with open(args.pwg_config) as f:
|
||||
pwg_config = CfgNode(yaml.safe_load(f))
|
||||
|
||||
print("========Args========")
|
||||
print(yaml.safe_dump(vars(args)))
|
||||
print("========Config========")
|
||||
print(speedyspeech_config)
|
||||
print(pwg_config)
|
||||
|
||||
evaluate(args, speedyspeech_config, pwg_config)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
Loading…
Reference in New Issue