restructure frontend example
This commit is contained in:
parent
309228ddbf
commit
e8991c973c
|
@ -61,14 +61,18 @@ class Frontend():
|
|||
phone = match.group(1)
|
||||
tone = match.group(2)
|
||||
# if the merged erhua not in the vocab
|
||||
# assume that the input is ['iaor3'] and 'iaor' not in self.vocab_phones, we split 'iaor' into ['iao','er']
|
||||
# and the tones accordingly change from ['3'] to ['3','2'], while '2' is the tone of 'er2'
|
||||
if len(phone) >= 2 and phone != "er" and phone[
|
||||
-1] == 'r' and phone not in self.vocab_phones and phone[:
|
||||
-1] in self.vocab_phones:
|
||||
phones.append(phone[:-1])
|
||||
phones.append("er")
|
||||
else:
|
||||
tones.append(tone)
|
||||
tones.append("2")
|
||||
else:
|
||||
phones.append(phone)
|
||||
tones.append(tone)
|
||||
else:
|
||||
phones.append(full_phone)
|
||||
tones.append('0')
|
||||
|
@ -76,9 +80,11 @@ class Frontend():
|
|||
tone_ids = paddle.to_tensor(tone_ids)
|
||||
result["tone_ids"] = tone_ids
|
||||
else:
|
||||
# if the merged erhua not in the vocab
|
||||
|
||||
phones = []
|
||||
for phone in phonemes:
|
||||
# if the merged erhua not in the vocab
|
||||
# assume that the input is ['iaor3'] and 'iaor' not in self.vocab_phones, change ['iaor3'] to ['iao3','er2']
|
||||
if len(phone) >= 3 and phone[:-1] != "er" and phone[
|
||||
-2] == 'r' and phone not in self.vocab_phones and (
|
||||
phone[:-2] + phone[-1]) in self.vocab_phones:
|
||||
|
|
|
@ -15,6 +15,6 @@ Run the command below to get the results of test.
|
|||
```bash
|
||||
./run.sh
|
||||
```
|
||||
The `avg WER` of g2p is: 0.02785753389811866
|
||||
The `avg WER` of g2p is: 0.027124048652822204
|
||||
|
||||
The `avg CER` of text normalization is: 0.014229233983486172
|
||||
The `avg CER` of text normalization is: 0.0061629764893859846
|
||||
|
|
|
@ -120,4 +120,6 @@ iPad Pro的秒控键盘这次也推出白色版本。|iPad Pro的秒控键盘这
|
|||
今年有望超三百亿美元|今年有望超三百亿美元
|
||||
就连一向看多的任志强|就连一向看多的任志强
|
||||
近期也一反常态地发表看空言论|近期也一反常态地发表看空言论
|
||||
985|九八五
|
||||
985|九八五
|
||||
12~23|十二到二十三
|
||||
12-23|十二到二十三
|
|
@ -0,0 +1,90 @@
|
|||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import argparse
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
|
||||
from praatio import tgio
|
||||
|
||||
|
||||
def get_baker_data(root_dir):
|
||||
alignment_files = sorted(
|
||||
list((root_dir / "PhoneLabeling").rglob("*.interval")))
|
||||
text_file = root_dir / "ProsodyLabeling/000001-010000.txt"
|
||||
text_file = Path(text_file).expanduser()
|
||||
# filter out several files that have errors in annotation
|
||||
exclude = {'000611', '000662', '002365', '005107'}
|
||||
alignment_files = [f for f in alignment_files if f.stem not in exclude]
|
||||
data_dict = defaultdict(dict)
|
||||
for alignment_fp in alignment_files:
|
||||
alignment = tgio.openTextgrid(alignment_fp)
|
||||
# only with baker's annotation
|
||||
utt_id = alignment.tierNameList[0].split(".")[0]
|
||||
intervals = alignment.tierDict[alignment.tierNameList[0]].entryList
|
||||
phones = []
|
||||
for interval in intervals:
|
||||
label = interval.label
|
||||
phones.append(label)
|
||||
data_dict[utt_id]["phones"] = phones
|
||||
for line in open(text_file, "r"):
|
||||
if line.startswith("0"):
|
||||
utt_id, raw_text = line.strip().split()
|
||||
if utt_id in data_dict:
|
||||
data_dict[utt_id]['text'] = raw_text
|
||||
else:
|
||||
pinyin = line.strip().split()
|
||||
if utt_id in data_dict:
|
||||
data_dict[utt_id]['pinyin'] = pinyin
|
||||
return data_dict
|
||||
|
||||
|
||||
def get_g2p_phones(data_dict, frontend):
|
||||
for utt_id in data_dict:
|
||||
g2p_phones = frontend.get_phonemes(data_dict[utt_id]['text'])
|
||||
data_dict[utt_id]["g2p_phones"] = g2p_phones
|
||||
return data_dict
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="g2p example.")
|
||||
parser.add_argument(
|
||||
"--root-dir",
|
||||
default=None,
|
||||
type=str,
|
||||
help="directory to baker dataset.")
|
||||
parser.add_argument(
|
||||
"--output-dir",
|
||||
default="data/g2p",
|
||||
type=str,
|
||||
help="directory to output.")
|
||||
|
||||
args = parser.parse_args()
|
||||
root_dir = Path(args.root_dir).expanduser()
|
||||
output_dir = Path(args.output_dir).expanduser()
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
assert root_dir.is_dir()
|
||||
data_dict = get_baker_data(root_dir)
|
||||
raw_path = output_dir / "text"
|
||||
ref_path = output_dir / "text.ref"
|
||||
wf_raw = open(raw_path, "w")
|
||||
wf_ref = open(ref_path, "w")
|
||||
for utt_id in data_dict:
|
||||
wf_raw.write(utt_id + " " + data_dict[utt_id]['text'] + "\n")
|
||||
wf_ref.write(utt_id + " " + " ".join(data_dict[utt_id]['phones']) +
|
||||
"\n")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
|
@ -0,0 +1,51 @@
|
|||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import argparse
|
||||
import re
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="text normalization example.")
|
||||
parser.add_argument(
|
||||
"--test-file",
|
||||
default="data/textnorm_test_cases.txt",
|
||||
type=str,
|
||||
help="path of text normalization test file.")
|
||||
parser.add_argument(
|
||||
"--output-dir",
|
||||
default="data/textnorm",
|
||||
type=str,
|
||||
help="directory to output.")
|
||||
|
||||
args = parser.parse_args()
|
||||
test_file = Path(args.test_file).expanduser()
|
||||
output_dir = Path(args.output_dir).expanduser()
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
raw_path = output_dir / "text"
|
||||
ref_path = output_dir / "text.ref"
|
||||
wf_raw = open(raw_path, "w")
|
||||
wf_ref = open(ref_path, "w")
|
||||
|
||||
with open(test_file, "r") as rf:
|
||||
for i, line in enumerate(rf):
|
||||
raw_text, normed_text = line.strip().split("|")
|
||||
wf_raw.write("utt_" + str(i) + " " + raw_text + "\n")
|
||||
wf_ref.write("utt_" + str(i) + " " + normed_text + "\n")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
|
@ -1,8 +1,14 @@
|
|||
#!/bin/bash
|
||||
|
||||
# test g2p
|
||||
echo "Start get g2p test data."
|
||||
python3 get_g2p_data.py --root-dir=~/datasets/BZNSYP --output-dir=data/g2p
|
||||
echo "Start test g2p."
|
||||
python3 test_g2p.py --root-dir=~/datasets/BZNSYP
|
||||
python3 test_g2p.py --input-dir=data/g2p --output-dir=exp/g2p
|
||||
|
||||
# test text normalization
|
||||
echo "Start get text normalization test data."
|
||||
python3 get_textnorm_data.py --test-file=data/textnorm_test_cases.txt --output-dir=data/textnorm
|
||||
echo "Start test text normalization."
|
||||
python3 test_textnorm.py --test-file=data/textnorm_test_cases.txt
|
||||
python3 test_textnorm.py --input-dir=data/textnorm --output-dir=exp/textnorm
|
||||
|
||||
|
|
|
@ -14,12 +14,12 @@
|
|||
|
||||
import argparse
|
||||
import re
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
|
||||
from parakeet.frontend.cn_frontend import Frontend as cnFrontend
|
||||
from parakeet.utils.error_rate import wer
|
||||
from praatio import tgio
|
||||
from parakeet.utils.error_rate import word_errors
|
||||
|
||||
SILENCE_TOKENS = {"sp", "sil", "sp1", "spl"}
|
||||
|
||||
|
||||
def text_cleaner(raw_text):
|
||||
|
@ -29,80 +29,68 @@ def text_cleaner(raw_text):
|
|||
return text
|
||||
|
||||
|
||||
def get_baker_data(root_dir):
|
||||
alignment_files = sorted(
|
||||
list((root_dir / "PhoneLabeling").rglob("*.interval")))
|
||||
text_file = root_dir / "ProsodyLabeling/000001-010000.txt"
|
||||
text_file = Path(text_file).expanduser()
|
||||
data_dict = defaultdict(dict)
|
||||
# filter out several files that have errors in annotation
|
||||
exclude = {'000611', '000662', '002365', '005107'}
|
||||
alignment_files = [f for f in alignment_files if f.stem not in exclude]
|
||||
# biaobei 前后有 sil ,中间没有 sp
|
||||
data_dict = defaultdict(dict)
|
||||
for alignment_fp in alignment_files:
|
||||
alignment = tgio.openTextgrid(alignment_fp)
|
||||
# only with baker's annotation
|
||||
utt_id = alignment.tierNameList[0].split(".")[0]
|
||||
intervals = alignment.tierDict[alignment.tierNameList[0]].entryList
|
||||
phones = []
|
||||
for interval in intervals:
|
||||
label = interval.label
|
||||
# Baker has sp1 rather than sp
|
||||
label = label.replace("sp1", "sp")
|
||||
phones.append(label)
|
||||
data_dict[utt_id]["phones"] = phones
|
||||
for line in open(text_file, "r"):
|
||||
if line.startswith("0"):
|
||||
utt_id, raw_text = line.strip().split()
|
||||
text = text_cleaner(raw_text)
|
||||
if utt_id in data_dict:
|
||||
data_dict[utt_id]['text'] = text
|
||||
else:
|
||||
pinyin = line.strip().split()
|
||||
if utt_id in data_dict:
|
||||
data_dict[utt_id]['pinyin'] = pinyin
|
||||
return data_dict
|
||||
|
||||
|
||||
def get_g2p_phones(data_dict, frontend):
|
||||
for utt_id in data_dict:
|
||||
g2p_phones = frontend.get_phonemes(data_dict[utt_id]['text'])
|
||||
data_dict[utt_id]["g2p_phones"] = g2p_phones
|
||||
return data_dict
|
||||
|
||||
|
||||
def get_avg_wer(data_dict):
|
||||
wer_list = []
|
||||
for utt_id in data_dict:
|
||||
g2p_phones = data_dict[utt_id]['g2p_phones']
|
||||
# delete silence tokens in predicted phones
|
||||
g2p_phones = [phn for phn in g2p_phones if phn not in {"sp", "sil"}]
|
||||
gt_phones = data_dict[utt_id]['phones']
|
||||
# delete silence tokens in baker phones
|
||||
gt_phones = [phn for phn in gt_phones if phn not in {"sp", "sil"}]
|
||||
def get_avg_wer(raw_dict, ref_dict, frontend, output_dir):
|
||||
edit_distances = []
|
||||
ref_lens = []
|
||||
wf_g2p = open(output_dir / "text.g2p", "w")
|
||||
wf_ref = open(output_dir / "text.ref.clean", "w")
|
||||
for utt_id in raw_dict:
|
||||
if utt_id not in ref_dict:
|
||||
continue
|
||||
raw_text = raw_dict[utt_id]
|
||||
text = text_cleaner(raw_text)
|
||||
g2p_phones = frontend.get_phonemes(text)
|
||||
gt_phones = ref_dict[utt_id].split(" ")
|
||||
# delete silence tokens in predicted phones and ground truth phones
|
||||
g2p_phones = [phn for phn in g2p_phones if phn not in SILENCE_TOKENS]
|
||||
gt_phones = [phn for phn in gt_phones if phn not in SILENCE_TOKENS]
|
||||
gt_phones = " ".join(gt_phones)
|
||||
g2p_phones = " ".join(g2p_phones)
|
||||
single_wer = wer(gt_phones, g2p_phones)
|
||||
wer_list.append(single_wer)
|
||||
return sum(wer_list) / len(wer_list)
|
||||
wf_ref.write(utt_id + " " + gt_phones + "\n")
|
||||
wf_g2p.write(utt_id + " " + g2p_phones + "\n")
|
||||
edit_distance, ref_len = word_errors(gt_phones, g2p_phones)
|
||||
edit_distances.append(edit_distance)
|
||||
ref_lens.append(ref_len)
|
||||
|
||||
return sum(edit_distances) / sum(ref_lens)
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="g2p example.")
|
||||
parser.add_argument(
|
||||
"--root-dir",
|
||||
default=None,
|
||||
"--input-dir",
|
||||
default="data/g2p",
|
||||
type=str,
|
||||
help="directory to baker dataset.")
|
||||
help="directory to preprocessed test data.")
|
||||
parser.add_argument(
|
||||
"--output-dir",
|
||||
default="exp/g2p",
|
||||
type=str,
|
||||
help="directory to save g2p results.")
|
||||
|
||||
args = parser.parse_args()
|
||||
root_dir = Path(args.root_dir).expanduser()
|
||||
assert root_dir.is_dir()
|
||||
input_dir = Path(args.input_dir).expanduser()
|
||||
output_dir = Path(args.output_dir).expanduser()
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
assert input_dir.is_dir()
|
||||
raw_dict, ref_dict = dict(), dict()
|
||||
raw_path = input_dir / "text"
|
||||
ref_path = input_dir / "text.ref"
|
||||
|
||||
with open(raw_path, "r") as rf:
|
||||
for line in rf:
|
||||
line = line.strip()
|
||||
line_list = line.split(" ")
|
||||
utt_id, raw_text = line_list[0], " ".join(line_list[1:])
|
||||
raw_dict[utt_id] = raw_text
|
||||
with open(ref_path, "r") as rf:
|
||||
for line in rf:
|
||||
line = line.strip()
|
||||
line_list = line.split(" ")
|
||||
utt_id, phones = line_list[0], " ".join(line_list[1:])
|
||||
ref_dict[utt_id] = phones
|
||||
frontend = cnFrontend()
|
||||
data_dict = get_baker_data(root_dir)
|
||||
data_dict = get_g2p_phones(data_dict, frontend)
|
||||
avg_wer = get_avg_wer(data_dict)
|
||||
avg_wer = get_avg_wer(raw_dict, ref_dict, frontend, output_dir)
|
||||
print("The avg WER of g2p is:", avg_wer)
|
||||
|
||||
|
||||
|
|
|
@ -17,7 +17,7 @@ import re
|
|||
from pathlib import Path
|
||||
|
||||
from parakeet.frontend.cn_normalization.text_normlization import TextNormalizer
|
||||
from parakeet.utils.error_rate import cer
|
||||
from parakeet.utils.error_rate import char_errors
|
||||
|
||||
|
||||
# delete english characters
|
||||
|
@ -29,31 +29,67 @@ def del_en_add_space(input: str):
|
|||
return output
|
||||
|
||||
|
||||
def get_avg_cer(test_file, text_normalizer):
|
||||
cer_list = []
|
||||
for line in open(test_file, "r"):
|
||||
line = line.strip()
|
||||
raw_text, gt_text = line.split("|")
|
||||
def get_avg_cer(raw_dict, ref_dict, text_normalizer, output_dir):
|
||||
edit_distances = []
|
||||
ref_lens = []
|
||||
wf_ref = open(output_dir / "text.ref.clean", "w")
|
||||
wf_tn = open(output_dir / "text.tn", "w")
|
||||
for text_id in raw_dict:
|
||||
if text_id not in ref_dict:
|
||||
continue
|
||||
raw_text = raw_dict[text_id]
|
||||
gt_text = ref_dict[text_id]
|
||||
textnorm_text = text_normalizer.normalize_sentence(raw_text)
|
||||
|
||||
gt_text = del_en_add_space(gt_text)
|
||||
textnorm_text = del_en_add_space(textnorm_text)
|
||||
single_cer = cer(gt_text, textnorm_text)
|
||||
cer_list.append(single_cer)
|
||||
return sum(cer_list) / len(cer_list)
|
||||
wf_ref.write(text_id + " " + gt_text + "\n")
|
||||
wf_tn.write(text_id + " " + textnorm_text + "\n")
|
||||
edit_distance, ref_len = char_errors(gt_text, textnorm_text)
|
||||
edit_distances.append(edit_distance)
|
||||
ref_lens.append(ref_len)
|
||||
|
||||
return sum(edit_distances) / sum(ref_lens)
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="text normalization example.")
|
||||
parser.add_argument(
|
||||
"--test-file",
|
||||
default=None,
|
||||
"--input-dir",
|
||||
default="data/textnorm",
|
||||
type=str,
|
||||
help="path of text normalization test file.")
|
||||
help="directory to preprocessed test data.")
|
||||
parser.add_argument(
|
||||
"--output-dir",
|
||||
default="exp/textnorm",
|
||||
type=str,
|
||||
help="directory to save textnorm results.")
|
||||
|
||||
args = parser.parse_args()
|
||||
test_file = Path(args.test_file).expanduser()
|
||||
input_dir = Path(args.input_dir).expanduser()
|
||||
output_dir = Path(args.output_dir).expanduser()
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
assert input_dir.is_dir()
|
||||
raw_dict, ref_dict = dict(), dict()
|
||||
raw_path = input_dir / "text"
|
||||
ref_path = input_dir / "text.ref"
|
||||
|
||||
with open(raw_path, "r") as rf:
|
||||
for line in rf:
|
||||
line = line.strip()
|
||||
line_list = line.split(" ")
|
||||
text_id, raw_text = line_list[0], " ".join(line_list[1:])
|
||||
raw_dict[text_id] = raw_text
|
||||
with open(ref_path, "r") as rf:
|
||||
for line in rf:
|
||||
line = line.strip()
|
||||
line_list = line.split(" ")
|
||||
text_id, normed_text = line_list[0], " ".join(line_list[1:])
|
||||
ref_dict[text_id] = normed_text
|
||||
|
||||
text_normalizer = TextNormalizer()
|
||||
avg_cer = get_avg_cer(test_file, text_normalizer)
|
||||
|
||||
avg_cer = get_avg_cer(raw_dict, ref_dict, text_normalizer, output_dir)
|
||||
print("The avg CER of text normalization is:", avg_cer)
|
||||
|
||||
|
||||
|
|
|
@ -60,7 +60,7 @@ def replace_percentage(match: re.Match) -> str:
|
|||
|
||||
|
||||
# 整数表达式
|
||||
# 带负号或者不带负号的整数 12, -10
|
||||
# 带负号的整数 -10
|
||||
RE_INTEGER = re.compile(r'(-)' r'(\d+)')
|
||||
|
||||
|
||||
|
@ -116,7 +116,7 @@ def replace_number(match: re.Match) -> str:
|
|||
|
||||
# 范围表达式
|
||||
# 12-23, 12~23
|
||||
RE_RANGE = re.compile(r'(\d+)[~](\d+)')
|
||||
RE_RANGE = re.compile(r'(\d+)[-~](\d+)')
|
||||
|
||||
|
||||
def replace_range(match: re.Match) -> str:
|
||||
|
|
|
@ -55,11 +55,11 @@ class TextNormalizer():
|
|||
sentence = RE_DATE2.sub(replace_date2, sentence)
|
||||
sentence = RE_TIME.sub(replace_time, sentence)
|
||||
sentence = RE_TEMPERATURE.sub(replace_temperature, sentence)
|
||||
sentence = RE_RANGE.sub(replace_range, sentence)
|
||||
sentence = RE_FRAC.sub(replace_frac, sentence)
|
||||
sentence = RE_PERCENTAGE.sub(replace_percentage, sentence)
|
||||
sentence = RE_MOBILE_PHONE.sub(replace_mobile, sentence)
|
||||
sentence = RE_TELEPHONE.sub(replace_phone, sentence)
|
||||
sentence = RE_RANGE.sub(replace_range, sentence)
|
||||
sentence = RE_INTEGER.sub(replace_negative_num, sentence)
|
||||
sentence = RE_DECIMAL_NUM.sub(replace_number, sentence)
|
||||
sentence = RE_POSITIVE_QUANTIFIERS.sub(replace_positive_quantifier,
|
||||
|
|
Loading…
Reference in New Issue