restructure frontend example

TianYuan 2021-08-16 08:31:37 +00:00
parent 309228ddbf
commit e8991c973c
10 changed files with 269 additions and 90 deletions

View File

@@ -61,14 +61,18 @@ class Frontend():
                    phone = match.group(1)
                    tone = match.group(2)
                    # if the merged erhua is not in the vocab:
                    # assume the input is ['iaor3'] and 'iaor' is not in
                    # self.vocab_phones; we split 'iaor' into ['iao', 'er'],
                    # and the tones change from ['3'] to ['3', '2'], where
                    # '2' is the tone of 'er2'
                    if len(phone) >= 2 and phone != "er" and phone[-1] == 'r' \
                            and phone not in self.vocab_phones \
                            and phone[:-1] in self.vocab_phones:
                        phones.append(phone[:-1])
                        phones.append("er")
                        tones.append(tone)
                        tones.append("2")
                    else:
                        phones.append(phone)
                        tones.append(tone)
                else:
                    phones.append(full_phone)
                    tones.append('0')
@@ -76,9 +80,11 @@ class Frontend():
            tone_ids = paddle.to_tensor(tone_ids)
            result["tone_ids"] = tone_ids
        else:
            phones = []
            for phone in phonemes:
                # if the merged erhua is not in the vocab:
                # assume the input is ['iaor3'] and 'iaor' is not in
                # self.vocab_phones; we change ['iaor3'] to ['iao3', 'er2']
                if len(phone) >= 3 and phone[:-1] != "er" and phone[-2] == 'r' \
                        and phone not in self.vocab_phones \
                        and (phone[:-2] + phone[-1]) in self.vocab_phones:
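Both hunks implement the same fallback for merged erhua syllables: when a form such as `iaor` is missing from the phone vocabulary, it is split into the base final plus `er`. A standalone sketch of the tone-split variant (a minimal illustration; `vocab` stands in for `self.vocab_phones`, and `split_erhua` is a hypothetical helper, not a function in the codebase):

```python
def split_erhua(phone, tone, vocab):
    """Split a merged erhua final into the base final plus 'er'.

    e.g. phone='iaor', tone='3' with 'iaor' not in vocab
    -> (['iao', 'er'], ['3', '2']), '2' being the tone of 'er2'.
    """
    if (len(phone) >= 2 and phone != "er" and phone[-1] == 'r'
            and phone not in vocab and phone[:-1] in vocab):
        return [phone[:-1], "er"], [tone, "2"]
    return [phone], [tone]

assert split_erhua("iaor", "3", {"iao", "er"}) == (["iao", "er"], ["3", "2"])
```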

View File

@@ -15,6 +15,6 @@ Run the command below to get the test results.
```bash
./run.sh
```
-The `avg WER` of g2p is: 0.02785753389811866
+The `avg WER` of g2p is: 0.027124048652822204
-The `avg CER` of text normalization is: 0.014229233983486172
+The `avg CER` of text normalization is: 0.0061629764893859846

View File

@@ -120,4 +120,6 @@ iPad Pro的秒控键盘这次也推出白色版本。|iPad Pro的秒控键盘这
今年有望超三百亿美元|今年有望超三百亿美元
就连一向看多的任志强|就连一向看多的任志强
近期也一反常态地发表看空言论|近期也一反常态地发表看空言论
985|九八五
+12~23|十二到二十三
+12-23|十二到二十三

View File

@@ -0,0 +1,90 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
from collections import defaultdict
from pathlib import Path

from praatio import tgio


def get_baker_data(root_dir):
    alignment_files = sorted(
        list((root_dir / "PhoneLabeling").rglob("*.interval")))
    text_file = root_dir / "ProsodyLabeling/000001-010000.txt"
    text_file = Path(text_file).expanduser()
    # filter out several files that have errors in annotation
    exclude = {'000611', '000662', '002365', '005107'}
    alignment_files = [f for f in alignment_files if f.stem not in exclude]
    data_dict = defaultdict(dict)
    for alignment_fp in alignment_files:
        alignment = tgio.openTextgrid(alignment_fp)
        # the first tier holds Baker's own annotation; its name encodes the utt id
        utt_id = alignment.tierNameList[0].split(".")[0]
        intervals = alignment.tierDict[alignment.tierNameList[0]].entryList
        phones = []
        for interval in intervals:
            label = interval.label
            phones.append(label)
        data_dict[utt_id]["phones"] = phones
    for line in open(text_file, "r"):
        if line.startswith("0"):
            utt_id, raw_text = line.strip().split()
            if utt_id in data_dict:
                data_dict[utt_id]['text'] = raw_text
        else:
            # the pinyin line follows its transcript line and reuses that utt_id
            pinyin = line.strip().split()
            if utt_id in data_dict:
                data_dict[utt_id]['pinyin'] = pinyin
    return data_dict


def get_g2p_phones(data_dict, frontend):
    for utt_id in data_dict:
        g2p_phones = frontend.get_phonemes(data_dict[utt_id]['text'])
        data_dict[utt_id]["g2p_phones"] = g2p_phones
    return data_dict


def main():
    parser = argparse.ArgumentParser(description="g2p example.")
    parser.add_argument(
        "--root-dir",
        default=None,
        type=str,
        help="directory of the baker dataset.")
    parser.add_argument(
        "--output-dir",
        default="data/g2p",
        type=str,
        help="directory to output.")
    args = parser.parse_args()
    root_dir = Path(args.root_dir).expanduser()
    output_dir = Path(args.output_dir).expanduser()
    output_dir.mkdir(parents=True, exist_ok=True)
    assert root_dir.is_dir()
    data_dict = get_baker_data(root_dir)
    raw_path = output_dir / "text"
    ref_path = output_dir / "text.ref"
    wf_raw = open(raw_path, "w")
    wf_ref = open(ref_path, "w")
    for utt_id in data_dict:
        wf_raw.write(utt_id + " " + data_dict[utt_id]['text'] + "\n")
        wf_ref.write(utt_id + " " + " ".join(data_dict[utt_id]['phones']) +
                     "\n")
    wf_raw.close()
    wf_ref.close()


if __name__ == "__main__":
    main()
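A quick interactive sketch of what the script extracts (assuming a local copy of the Baker corpus at the path `run.sh` uses below; the printed values depend on your copy of the dataset):

```python
from pathlib import Path

root_dir = Path("~/datasets/BZNSYP").expanduser()
data_dict = get_baker_data(root_dir)
utt_id = next(iter(data_dict))
# one line of data/g2p/text: "<utt_id> <raw_text>"
print(utt_id, data_dict[utt_id]["text"])
# one line of data/g2p/text.ref: "<utt_id> <phone1> <phone2> ..."
print(utt_id, " ".join(data_dict[utt_id]["phones"]))
```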

View File

@@ -0,0 +1,51 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
from pathlib import Path


def main():
    parser = argparse.ArgumentParser(description="text normalization example.")
    parser.add_argument(
        "--test-file",
        default="data/textnorm_test_cases.txt",
        type=str,
        help="path of the text normalization test file.")
    parser.add_argument(
        "--output-dir",
        default="data/textnorm",
        type=str,
        help="directory to output.")
    args = parser.parse_args()
    test_file = Path(args.test_file).expanduser()
    output_dir = Path(args.output_dir).expanduser()
    output_dir.mkdir(parents=True, exist_ok=True)
    raw_path = output_dir / "text"
    ref_path = output_dir / "text.ref"
    wf_raw = open(raw_path, "w")
    wf_ref = open(ref_path, "w")
    with open(test_file, "r") as rf:
        for i, line in enumerate(rf):
            raw_text, normed_text = line.strip().split("|")
            wf_raw.write("utt_" + str(i) + " " + raw_text + "\n")
            wf_ref.write("utt_" + str(i) + " " + normed_text + "\n")
    wf_raw.close()
    wf_ref.close()


if __name__ == "__main__":
    main()
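Each line of the test-case file is a `raw|normalized` pair; the script just re-keys the two halves so `test_textnorm.py` can pair them back up by id. A minimal sketch with one case from the test file above:

```python
line = "12~23|十二到二十三"  # taken from data/textnorm_test_cases.txt
raw_text, normed_text = line.strip().split("|")
print("utt_0 " + raw_text)     # written to data/textnorm/text
print("utt_0 " + normed_text)  # written to data/textnorm/text.ref
```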

View File

@@ -1,8 +1,14 @@
#!/bin/bash

# test g2p
+echo "Start get g2p test data."
+python3 get_g2p_data.py --root-dir=~/datasets/BZNSYP --output-dir=data/g2p
echo "Start test g2p."
-python3 test_g2p.py --root-dir=~/datasets/BZNSYP
+python3 test_g2p.py --input-dir=data/g2p --output-dir=exp/g2p

# test text normalization
+echo "Start get text normalization test data."
+python3 get_textnorm_data.py --test-file=data/textnorm_test_cases.txt --output-dir=data/textnorm
echo "Start test text normalization."
-python3 test_textnorm.py --test-file=data/textnorm_test_cases.txt
+python3 test_textnorm.py --input-dir=data/textnorm --output-dir=exp/textnorm

View File

@@ -14,12 +14,12 @@
import argparse
import re
from collections import defaultdict
from pathlib import Path

from parakeet.frontend.cn_frontend import Frontend as cnFrontend
-from parakeet.utils.error_rate import wer
-from praatio import tgio
+from parakeet.utils.error_rate import word_errors

+SILENCE_TOKENS = {"sp", "sil", "sp1", "spl"}

def text_cleaner(raw_text):
@@ -29,80 +29,68 @@ def text_cleaner(raw_text):
    return text

-def get_baker_data(root_dir):
-    alignment_files = sorted(
-        list((root_dir / "PhoneLabeling").rglob("*.interval")))
-    text_file = root_dir / "ProsodyLabeling/000001-010000.txt"
-    text_file = Path(text_file).expanduser()
-    data_dict = defaultdict(dict)
-    # filter out several files that have errors in annotation
-    exclude = {'000611', '000662', '002365', '005107'}
-    alignment_files = [f for f in alignment_files if f.stem not in exclude]
-    # biaobei has sil at the beginning and end of each utterance, but no sp in between
-    data_dict = defaultdict(dict)
-    for alignment_fp in alignment_files:
-        alignment = tgio.openTextgrid(alignment_fp)
-        # the first tier holds Baker's own annotation; its name encodes the utt id
-        utt_id = alignment.tierNameList[0].split(".")[0]
-        intervals = alignment.tierDict[alignment.tierNameList[0]].entryList
-        phones = []
-        for interval in intervals:
-            label = interval.label
-            # Baker has sp1 rather than sp
-            label = label.replace("sp1", "sp")
-            phones.append(label)
-        data_dict[utt_id]["phones"] = phones
-    for line in open(text_file, "r"):
-        if line.startswith("0"):
-            utt_id, raw_text = line.strip().split()
-            text = text_cleaner(raw_text)
-            if utt_id in data_dict:
-                data_dict[utt_id]['text'] = text
-        else:
-            pinyin = line.strip().split()
-            if utt_id in data_dict:
-                data_dict[utt_id]['pinyin'] = pinyin
-    return data_dict

-def get_g2p_phones(data_dict, frontend):
-    for utt_id in data_dict:
-        g2p_phones = frontend.get_phonemes(data_dict[utt_id]['text'])
-        data_dict[utt_id]["g2p_phones"] = g2p_phones
-    return data_dict

-def get_avg_wer(data_dict):
-    wer_list = []
-    for utt_id in data_dict:
-        g2p_phones = data_dict[utt_id]['g2p_phones']
-        # delete silence tokens in predicted phones
-        g2p_phones = [phn for phn in g2p_phones if phn not in {"sp", "sil"}]
-        gt_phones = data_dict[utt_id]['phones']
-        # delete silence tokens in baker phones
-        gt_phones = [phn for phn in gt_phones if phn not in {"sp", "sil"}]
-        gt_phones = " ".join(gt_phones)
-        g2p_phones = " ".join(g2p_phones)
-        single_wer = wer(gt_phones, g2p_phones)
-        wer_list.append(single_wer)
-    return sum(wer_list) / len(wer_list)

+def get_avg_wer(raw_dict, ref_dict, frontend, output_dir):
+    edit_distances = []
+    ref_lens = []
+    wf_g2p = open(output_dir / "text.g2p", "w")
+    wf_ref = open(output_dir / "text.ref.clean", "w")
+    for utt_id in raw_dict:
+        if utt_id not in ref_dict:
+            continue
+        raw_text = raw_dict[utt_id]
+        text = text_cleaner(raw_text)
+        g2p_phones = frontend.get_phonemes(text)
+        gt_phones = ref_dict[utt_id].split(" ")
+        # delete silence tokens in predicted phones and ground truth phones
+        g2p_phones = [phn for phn in g2p_phones if phn not in SILENCE_TOKENS]
+        gt_phones = [phn for phn in gt_phones if phn not in SILENCE_TOKENS]
+        gt_phones = " ".join(gt_phones)
+        g2p_phones = " ".join(g2p_phones)
+        wf_ref.write(utt_id + " " + gt_phones + "\n")
+        wf_g2p.write(utt_id + " " + g2p_phones + "\n")
+        edit_distance, ref_len = word_errors(gt_phones, g2p_phones)
+        edit_distances.append(edit_distance)
+        ref_lens.append(ref_len)
+    return sum(edit_distances) / sum(ref_lens)

def main():
    parser = argparse.ArgumentParser(description="g2p example.")
    parser.add_argument(
-        "--root-dir",
-        default=None,
+        "--input-dir",
+        default="data/g2p",
        type=str,
-        help="directory to baker dataset.")
+        help="directory to preprocessed test data.")
+    parser.add_argument(
+        "--output-dir",
+        default="exp/g2p",
+        type=str,
+        help="directory to save g2p results.")
    args = parser.parse_args()
-    root_dir = Path(args.root_dir).expanduser()
-    assert root_dir.is_dir()
+    input_dir = Path(args.input_dir).expanduser()
+    output_dir = Path(args.output_dir).expanduser()
+    output_dir.mkdir(parents=True, exist_ok=True)
+    assert input_dir.is_dir()
+    raw_dict, ref_dict = dict(), dict()
+    raw_path = input_dir / "text"
+    ref_path = input_dir / "text.ref"
+    with open(raw_path, "r") as rf:
+        for line in rf:
+            line = line.strip()
+            line_list = line.split(" ")
+            utt_id, raw_text = line_list[0], " ".join(line_list[1:])
+            raw_dict[utt_id] = raw_text
+    with open(ref_path, "r") as rf:
+        for line in rf:
+            line = line.strip()
+            line_list = line.split(" ")
+            utt_id, phones = line_list[0], " ".join(line_list[1:])
+            ref_dict[utt_id] = phones
    frontend = cnFrontend()
-    data_dict = get_baker_data(root_dir)
-    data_dict = get_g2p_phones(data_dict, frontend)
-    avg_wer = get_avg_wer(data_dict)
+    avg_wer = get_avg_wer(raw_dict, ref_dict, frontend, output_dir)
    print("The avg WER of g2p is:", avg_wer)

View File

@@ -17,7 +17,7 @@ import re
from pathlib import Path

from parakeet.frontend.cn_normalization.text_normlization import TextNormalizer
-from parakeet.utils.error_rate import cer
+from parakeet.utils.error_rate import char_errors

# delete english characters
@@ -29,31 +29,67 @@ def del_en_add_space(input: str):
    return output

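The body of `del_en_add_space` is outside this hunk; judging by its name and the comment above, a plausible sketch is (an assumption, not the committed implementation):

```python
import re

def del_en_add_space(input: str) -> str:
    # drop ASCII letters, then separate the remaining (CJK) characters
    # with single spaces so they can be scored as individual tokens
    no_en = re.sub('[a-zA-Z]', '', input)
    return " ".join(no_en)

# e.g. del_en_add_space("iPad的键盘") == "的 键 盘"
```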
-def get_avg_cer(test_file, text_normalizer):
-    cer_list = []
-    for line in open(test_file, "r"):
-        line = line.strip()
-        raw_text, gt_text = line.split("|")
-        textnorm_text = text_normalizer.normalize_sentence(raw_text)
-        gt_text = del_en_add_space(gt_text)
-        textnorm_text = del_en_add_space(textnorm_text)
-        single_cer = cer(gt_text, textnorm_text)
-        cer_list.append(single_cer)
-    return sum(cer_list) / len(cer_list)

+def get_avg_cer(raw_dict, ref_dict, text_normalizer, output_dir):
+    edit_distances = []
+    ref_lens = []
+    wf_ref = open(output_dir / "text.ref.clean", "w")
+    wf_tn = open(output_dir / "text.tn", "w")
+    for text_id in raw_dict:
+        if text_id not in ref_dict:
+            continue
+        raw_text = raw_dict[text_id]
+        gt_text = ref_dict[text_id]
+        textnorm_text = text_normalizer.normalize_sentence(raw_text)
+        gt_text = del_en_add_space(gt_text)
+        textnorm_text = del_en_add_space(textnorm_text)
+        wf_ref.write(text_id + " " + gt_text + "\n")
+        wf_tn.write(text_id + " " + textnorm_text + "\n")
+        edit_distance, ref_len = char_errors(gt_text, textnorm_text)
+        edit_distances.append(edit_distance)
+        ref_lens.append(ref_len)
+    return sum(edit_distances) / sum(ref_lens)

def main():
    parser = argparse.ArgumentParser(description="text normalization example.")
    parser.add_argument(
-        "--test-file",
-        default=None,
+        "--input-dir",
+        default="data/textnorm",
        type=str,
-        help="path of text normalization test file.")
+        help="directory to preprocessed test data.")
+    parser.add_argument(
+        "--output-dir",
+        default="exp/textnorm",
+        type=str,
+        help="directory to save textnorm results.")
    args = parser.parse_args()
-    test_file = Path(args.test_file).expanduser()
+    input_dir = Path(args.input_dir).expanduser()
+    output_dir = Path(args.output_dir).expanduser()
+    output_dir.mkdir(parents=True, exist_ok=True)
+    assert input_dir.is_dir()
+    raw_dict, ref_dict = dict(), dict()
+    raw_path = input_dir / "text"
+    ref_path = input_dir / "text.ref"
+    with open(raw_path, "r") as rf:
+        for line in rf:
+            line = line.strip()
+            line_list = line.split(" ")
+            text_id, raw_text = line_list[0], " ".join(line_list[1:])
+            raw_dict[text_id] = raw_text
+    with open(ref_path, "r") as rf:
+        for line in rf:
+            line = line.strip()
+            line_list = line.split(" ")
+            text_id, normed_text = line_list[0], " ".join(line_list[1:])
+            ref_dict[text_id] = normed_text
    text_normalizer = TextNormalizer()
-    avg_cer = get_avg_cer(test_file, text_normalizer)
+    avg_cer = get_avg_cer(raw_dict, ref_dict, text_normalizer, output_dir)
    print("The avg CER of text normalization is:", avg_cer)

View File

@@ -60,7 +60,7 @@ def replace_percentage(match: re.Match) -> str:
# integer expression
-# signed or unsigned integers, e.g. 12, -10
+# negative integers, e.g. -10
RE_INTEGER = re.compile(r'(-)' r'(\d+)')
@@ -116,7 +116,7 @@ def replace_number(match: re.Match) -> str:
# range expression
# e.g. 12-23, 12~23
-RE_RANGE = re.compile(r'(\d+)[~](\d+)')
+RE_RANGE = re.compile(r'(\d+)[-~](\d+)')
def replace_range(match: re.Match) -> str:
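Widening the character class from `[~]` to `[-~]` lets the same rule catch both of the new test cases above (`12~23` and `12-23`). A quick check of the updated pattern:

```python
import re

RE_RANGE = re.compile(r'(\d+)[-~](\d+)')

for text in ("12~23", "12-23"):
    m = RE_RANGE.match(text)
    print(m.group(1), m.group(2))  # both print: 12 23
```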

View File

@@ -55,11 +55,11 @@ class TextNormalizer():
        sentence = RE_DATE2.sub(replace_date2, sentence)
        sentence = RE_TIME.sub(replace_time, sentence)
        sentence = RE_TEMPERATURE.sub(replace_temperature, sentence)
+        sentence = RE_RANGE.sub(replace_range, sentence)
        sentence = RE_FRAC.sub(replace_frac, sentence)
        sentence = RE_PERCENTAGE.sub(replace_percentage, sentence)
        sentence = RE_MOBILE_PHONE.sub(replace_mobile, sentence)
        sentence = RE_TELEPHONE.sub(replace_phone, sentence)
-        sentence = RE_RANGE.sub(replace_range, sentence)
        sentence = RE_INTEGER.sub(replace_negative_num, sentence)
        sentence = RE_DECIMAL_NUM.sub(replace_number, sentence)
        sentence = RE_POSITIVE_QUANTIFIERS.sub(replace_positive_quantifier,
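The reordering matters because the substitutions run in sequence: now that `-` is a valid range separator (previous file), a hyphenated span such as `12-23` must be rewritten by `RE_RANGE` before the fraction, percentage, and phone-number rules get a chance to consume its digits. A minimal sketch of the order dependence, with a hypothetical stand-in for any later digit rule (not the library's actual pattern):

```python
import re

RE_RANGE = re.compile(r'(\d+)[-~](\d+)')
RE_LATER = re.compile(r'\d+-\d+')  # stand-in for a later hyphen-eating rule

text = "12-23"

# range rule first (new order): the span survives as a range
print(RE_LATER.sub("<later>", RE_RANGE.sub(r'\1到\2', text)))  # 12到23

# range rule last (old order): a later rule consumes the span first
print(RE_RANGE.sub(r'\1到\2', RE_LATER.sub("<later>", text)))  # <later>
```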