restructure frontend example
This commit is contained in:
parent
309228ddbf
commit
e8991c973c
|
@ -61,14 +61,18 @@ class Frontend():
|
||||||
phone = match.group(1)
|
phone = match.group(1)
|
||||||
tone = match.group(2)
|
tone = match.group(2)
|
||||||
# if the merged erhua not in the vocab
|
# if the merged erhua not in the vocab
|
||||||
|
# assume that the input is ['iaor3'] and 'iaor' not in self.vocab_phones, we split 'iaor' into ['iao','er']
|
||||||
|
# and the tones accordingly change from ['3'] to ['3','2'], while '2' is the tone of 'er2'
|
||||||
if len(phone) >= 2 and phone != "er" and phone[
|
if len(phone) >= 2 and phone != "er" and phone[
|
||||||
-1] == 'r' and phone not in self.vocab_phones and phone[:
|
-1] == 'r' and phone not in self.vocab_phones and phone[:
|
||||||
-1] in self.vocab_phones:
|
-1] in self.vocab_phones:
|
||||||
phones.append(phone[:-1])
|
phones.append(phone[:-1])
|
||||||
phones.append("er")
|
phones.append("er")
|
||||||
else:
|
|
||||||
tones.append(tone)
|
tones.append(tone)
|
||||||
tones.append("2")
|
tones.append("2")
|
||||||
|
else:
|
||||||
|
phones.append(phone)
|
||||||
|
tones.append(tone)
|
||||||
else:
|
else:
|
||||||
phones.append(full_phone)
|
phones.append(full_phone)
|
||||||
tones.append('0')
|
tones.append('0')
|
||||||
|
@ -76,9 +80,11 @@ class Frontend():
|
||||||
tone_ids = paddle.to_tensor(tone_ids)
|
tone_ids = paddle.to_tensor(tone_ids)
|
||||||
result["tone_ids"] = tone_ids
|
result["tone_ids"] = tone_ids
|
||||||
else:
|
else:
|
||||||
# if the merged erhua not in the vocab
|
|
||||||
phones = []
|
phones = []
|
||||||
for phone in phonemes:
|
for phone in phonemes:
|
||||||
|
# if the merged erhua not in the vocab
|
||||||
|
# assume that the input is ['iaor3'] and 'iaor' not in self.vocab_phones, change ['iaor3'] to ['iao3','er2']
|
||||||
if len(phone) >= 3 and phone[:-1] != "er" and phone[
|
if len(phone) >= 3 and phone[:-1] != "er" and phone[
|
||||||
-2] == 'r' and phone not in self.vocab_phones and (
|
-2] == 'r' and phone not in self.vocab_phones and (
|
||||||
phone[:-2] + phone[-1]) in self.vocab_phones:
|
phone[:-2] + phone[-1]) in self.vocab_phones:
|
||||||
|
|
|
@ -15,6 +15,6 @@ Run the command below to get the results of test.
|
||||||
```bash
|
```bash
|
||||||
./run.sh
|
./run.sh
|
||||||
```
|
```
|
||||||
The `avg WER` of g2p is: 0.02785753389811866
|
The `avg WER` of g2p is: 0.027124048652822204
|
||||||
|
|
||||||
The `avg CER` of text normalization is: 0.014229233983486172
|
The `avg CER` of text normalization is: 0.0061629764893859846
|
||||||
|
|
|
@ -121,3 +121,5 @@ iPad Pro的秒控键盘这次也推出白色版本。|iPad Pro的秒控键盘这
|
||||||
就连一向看多的任志强|就连一向看多的任志强
|
就连一向看多的任志强|就连一向看多的任志强
|
||||||
近期也一反常态地发表看空言论|近期也一反常态地发表看空言论
|
近期也一反常态地发表看空言论|近期也一反常态地发表看空言论
|
||||||
985|九八五
|
985|九八五
|
||||||
|
12~23|十二到二十三
|
||||||
|
12-23|十二到二十三
|
|
@ -0,0 +1,90 @@
|
||||||
|
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
from collections import defaultdict
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from praatio import tgio
|
||||||
|
|
||||||
|
|
||||||
|
def get_baker_data(root_dir):
|
||||||
|
alignment_files = sorted(
|
||||||
|
list((root_dir / "PhoneLabeling").rglob("*.interval")))
|
||||||
|
text_file = root_dir / "ProsodyLabeling/000001-010000.txt"
|
||||||
|
text_file = Path(text_file).expanduser()
|
||||||
|
# filter out several files that have errors in annotation
|
||||||
|
exclude = {'000611', '000662', '002365', '005107'}
|
||||||
|
alignment_files = [f for f in alignment_files if f.stem not in exclude]
|
||||||
|
data_dict = defaultdict(dict)
|
||||||
|
for alignment_fp in alignment_files:
|
||||||
|
alignment = tgio.openTextgrid(alignment_fp)
|
||||||
|
# only with baker's annotation
|
||||||
|
utt_id = alignment.tierNameList[0].split(".")[0]
|
||||||
|
intervals = alignment.tierDict[alignment.tierNameList[0]].entryList
|
||||||
|
phones = []
|
||||||
|
for interval in intervals:
|
||||||
|
label = interval.label
|
||||||
|
phones.append(label)
|
||||||
|
data_dict[utt_id]["phones"] = phones
|
||||||
|
for line in open(text_file, "r"):
|
||||||
|
if line.startswith("0"):
|
||||||
|
utt_id, raw_text = line.strip().split()
|
||||||
|
if utt_id in data_dict:
|
||||||
|
data_dict[utt_id]['text'] = raw_text
|
||||||
|
else:
|
||||||
|
pinyin = line.strip().split()
|
||||||
|
if utt_id in data_dict:
|
||||||
|
data_dict[utt_id]['pinyin'] = pinyin
|
||||||
|
return data_dict
|
||||||
|
|
||||||
|
|
||||||
|
def get_g2p_phones(data_dict, frontend):
|
||||||
|
for utt_id in data_dict:
|
||||||
|
g2p_phones = frontend.get_phonemes(data_dict[utt_id]['text'])
|
||||||
|
data_dict[utt_id]["g2p_phones"] = g2p_phones
|
||||||
|
return data_dict
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
parser = argparse.ArgumentParser(description="g2p example.")
|
||||||
|
parser.add_argument(
|
||||||
|
"--root-dir",
|
||||||
|
default=None,
|
||||||
|
type=str,
|
||||||
|
help="directory to baker dataset.")
|
||||||
|
parser.add_argument(
|
||||||
|
"--output-dir",
|
||||||
|
default="data/g2p",
|
||||||
|
type=str,
|
||||||
|
help="directory to output.")
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
root_dir = Path(args.root_dir).expanduser()
|
||||||
|
output_dir = Path(args.output_dir).expanduser()
|
||||||
|
output_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
assert root_dir.is_dir()
|
||||||
|
data_dict = get_baker_data(root_dir)
|
||||||
|
raw_path = output_dir / "text"
|
||||||
|
ref_path = output_dir / "text.ref"
|
||||||
|
wf_raw = open(raw_path, "w")
|
||||||
|
wf_ref = open(ref_path, "w")
|
||||||
|
for utt_id in data_dict:
|
||||||
|
wf_raw.write(utt_id + " " + data_dict[utt_id]['text'] + "\n")
|
||||||
|
wf_ref.write(utt_id + " " + " ".join(data_dict[utt_id]['phones']) +
|
||||||
|
"\n")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
|
@ -0,0 +1,51 @@
|
||||||
|
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import re
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
parser = argparse.ArgumentParser(description="text normalization example.")
|
||||||
|
parser.add_argument(
|
||||||
|
"--test-file",
|
||||||
|
default="data/textnorm_test_cases.txt",
|
||||||
|
type=str,
|
||||||
|
help="path of text normalization test file.")
|
||||||
|
parser.add_argument(
|
||||||
|
"--output-dir",
|
||||||
|
default="data/textnorm",
|
||||||
|
type=str,
|
||||||
|
help="directory to output.")
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
test_file = Path(args.test_file).expanduser()
|
||||||
|
output_dir = Path(args.output_dir).expanduser()
|
||||||
|
output_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
raw_path = output_dir / "text"
|
||||||
|
ref_path = output_dir / "text.ref"
|
||||||
|
wf_raw = open(raw_path, "w")
|
||||||
|
wf_ref = open(ref_path, "w")
|
||||||
|
|
||||||
|
with open(test_file, "r") as rf:
|
||||||
|
for i, line in enumerate(rf):
|
||||||
|
raw_text, normed_text = line.strip().split("|")
|
||||||
|
wf_raw.write("utt_" + str(i) + " " + raw_text + "\n")
|
||||||
|
wf_ref.write("utt_" + str(i) + " " + normed_text + "\n")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
|
@ -1,8 +1,14 @@
|
||||||
#!/bin/bash
|
#!/bin/bash
|
||||||
|
|
||||||
# test g2p
|
# test g2p
|
||||||
|
echo "Start get g2p test data."
|
||||||
|
python3 get_g2p_data.py --root-dir=~/datasets/BZNSYP --output-dir=data/g2p
|
||||||
echo "Start test g2p."
|
echo "Start test g2p."
|
||||||
python3 test_g2p.py --root-dir=~/datasets/BZNSYP
|
python3 test_g2p.py --input-dir=data/g2p --output-dir=exp/g2p
|
||||||
|
|
||||||
# test text normalization
|
# test text normalization
|
||||||
|
echo "Start get text normalization test data."
|
||||||
|
python3 get_textnorm_data.py --test-file=data/textnorm_test_cases.txt --output-dir=data/textnorm
|
||||||
echo "Start test text normalization."
|
echo "Start test text normalization."
|
||||||
python3 test_textnorm.py --test-file=data/textnorm_test_cases.txt
|
python3 test_textnorm.py --input-dir=data/textnorm --output-dir=exp/textnorm
|
||||||
|
|
||||||
|
|
|
@ -14,12 +14,12 @@
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import re
|
import re
|
||||||
from collections import defaultdict
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
from parakeet.frontend.cn_frontend import Frontend as cnFrontend
|
from parakeet.frontend.cn_frontend import Frontend as cnFrontend
|
||||||
from parakeet.utils.error_rate import wer
|
from parakeet.utils.error_rate import word_errors
|
||||||
from praatio import tgio
|
|
||||||
|
SILENCE_TOKENS = {"sp", "sil", "sp1", "spl"}
|
||||||
|
|
||||||
|
|
||||||
def text_cleaner(raw_text):
|
def text_cleaner(raw_text):
|
||||||
|
@ -29,80 +29,68 @@ def text_cleaner(raw_text):
|
||||||
return text
|
return text
|
||||||
|
|
||||||
|
|
||||||
def get_baker_data(root_dir):
|
def get_avg_wer(raw_dict, ref_dict, frontend, output_dir):
|
||||||
alignment_files = sorted(
|
edit_distances = []
|
||||||
list((root_dir / "PhoneLabeling").rglob("*.interval")))
|
ref_lens = []
|
||||||
text_file = root_dir / "ProsodyLabeling/000001-010000.txt"
|
wf_g2p = open(output_dir / "text.g2p", "w")
|
||||||
text_file = Path(text_file).expanduser()
|
wf_ref = open(output_dir / "text.ref.clean", "w")
|
||||||
data_dict = defaultdict(dict)
|
for utt_id in raw_dict:
|
||||||
# filter out several files that have errors in annotation
|
if utt_id not in ref_dict:
|
||||||
exclude = {'000611', '000662', '002365', '005107'}
|
continue
|
||||||
alignment_files = [f for f in alignment_files if f.stem not in exclude]
|
raw_text = raw_dict[utt_id]
|
||||||
# biaobei 前后有 sil ,中间没有 sp
|
text = text_cleaner(raw_text)
|
||||||
data_dict = defaultdict(dict)
|
g2p_phones = frontend.get_phonemes(text)
|
||||||
for alignment_fp in alignment_files:
|
gt_phones = ref_dict[utt_id].split(" ")
|
||||||
alignment = tgio.openTextgrid(alignment_fp)
|
# delete silence tokens in predicted phones and ground truth phones
|
||||||
# only with baker's annotation
|
g2p_phones = [phn for phn in g2p_phones if phn not in SILENCE_TOKENS]
|
||||||
utt_id = alignment.tierNameList[0].split(".")[0]
|
gt_phones = [phn for phn in gt_phones if phn not in SILENCE_TOKENS]
|
||||||
intervals = alignment.tierDict[alignment.tierNameList[0]].entryList
|
|
||||||
phones = []
|
|
||||||
for interval in intervals:
|
|
||||||
label = interval.label
|
|
||||||
# Baker has sp1 rather than sp
|
|
||||||
label = label.replace("sp1", "sp")
|
|
||||||
phones.append(label)
|
|
||||||
data_dict[utt_id]["phones"] = phones
|
|
||||||
for line in open(text_file, "r"):
|
|
||||||
if line.startswith("0"):
|
|
||||||
utt_id, raw_text = line.strip().split()
|
|
||||||
text = text_cleaner(raw_text)
|
|
||||||
if utt_id in data_dict:
|
|
||||||
data_dict[utt_id]['text'] = text
|
|
||||||
else:
|
|
||||||
pinyin = line.strip().split()
|
|
||||||
if utt_id in data_dict:
|
|
||||||
data_dict[utt_id]['pinyin'] = pinyin
|
|
||||||
return data_dict
|
|
||||||
|
|
||||||
|
|
||||||
def get_g2p_phones(data_dict, frontend):
|
|
||||||
for utt_id in data_dict:
|
|
||||||
g2p_phones = frontend.get_phonemes(data_dict[utt_id]['text'])
|
|
||||||
data_dict[utt_id]["g2p_phones"] = g2p_phones
|
|
||||||
return data_dict
|
|
||||||
|
|
||||||
|
|
||||||
def get_avg_wer(data_dict):
|
|
||||||
wer_list = []
|
|
||||||
for utt_id in data_dict:
|
|
||||||
g2p_phones = data_dict[utt_id]['g2p_phones']
|
|
||||||
# delete silence tokens in predicted phones
|
|
||||||
g2p_phones = [phn for phn in g2p_phones if phn not in {"sp", "sil"}]
|
|
||||||
gt_phones = data_dict[utt_id]['phones']
|
|
||||||
# delete silence tokens in baker phones
|
|
||||||
gt_phones = [phn for phn in gt_phones if phn not in {"sp", "sil"}]
|
|
||||||
gt_phones = " ".join(gt_phones)
|
gt_phones = " ".join(gt_phones)
|
||||||
g2p_phones = " ".join(g2p_phones)
|
g2p_phones = " ".join(g2p_phones)
|
||||||
single_wer = wer(gt_phones, g2p_phones)
|
wf_ref.write(utt_id + " " + gt_phones + "\n")
|
||||||
wer_list.append(single_wer)
|
wf_g2p.write(utt_id + " " + g2p_phones + "\n")
|
||||||
return sum(wer_list) / len(wer_list)
|
edit_distance, ref_len = word_errors(gt_phones, g2p_phones)
|
||||||
|
edit_distances.append(edit_distance)
|
||||||
|
ref_lens.append(ref_len)
|
||||||
|
|
||||||
|
return sum(edit_distances) / sum(ref_lens)
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
parser = argparse.ArgumentParser(description="g2p example.")
|
parser = argparse.ArgumentParser(description="g2p example.")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--root-dir",
|
"--input-dir",
|
||||||
default=None,
|
default="data/g2p",
|
||||||
type=str,
|
type=str,
|
||||||
help="directory to baker dataset.")
|
help="directory to preprocessed test data.")
|
||||||
|
parser.add_argument(
|
||||||
|
"--output-dir",
|
||||||
|
default="exp/g2p",
|
||||||
|
type=str,
|
||||||
|
help="directory to save g2p results.")
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
root_dir = Path(args.root_dir).expanduser()
|
input_dir = Path(args.input_dir).expanduser()
|
||||||
assert root_dir.is_dir()
|
output_dir = Path(args.output_dir).expanduser()
|
||||||
|
output_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
assert input_dir.is_dir()
|
||||||
|
raw_dict, ref_dict = dict(), dict()
|
||||||
|
raw_path = input_dir / "text"
|
||||||
|
ref_path = input_dir / "text.ref"
|
||||||
|
|
||||||
|
with open(raw_path, "r") as rf:
|
||||||
|
for line in rf:
|
||||||
|
line = line.strip()
|
||||||
|
line_list = line.split(" ")
|
||||||
|
utt_id, raw_text = line_list[0], " ".join(line_list[1:])
|
||||||
|
raw_dict[utt_id] = raw_text
|
||||||
|
with open(ref_path, "r") as rf:
|
||||||
|
for line in rf:
|
||||||
|
line = line.strip()
|
||||||
|
line_list = line.split(" ")
|
||||||
|
utt_id, phones = line_list[0], " ".join(line_list[1:])
|
||||||
|
ref_dict[utt_id] = phones
|
||||||
frontend = cnFrontend()
|
frontend = cnFrontend()
|
||||||
data_dict = get_baker_data(root_dir)
|
avg_wer = get_avg_wer(raw_dict, ref_dict, frontend, output_dir)
|
||||||
data_dict = get_g2p_phones(data_dict, frontend)
|
|
||||||
avg_wer = get_avg_wer(data_dict)
|
|
||||||
print("The avg WER of g2p is:", avg_wer)
|
print("The avg WER of g2p is:", avg_wer)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -17,7 +17,7 @@ import re
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
from parakeet.frontend.cn_normalization.text_normlization import TextNormalizer
|
from parakeet.frontend.cn_normalization.text_normlization import TextNormalizer
|
||||||
from parakeet.utils.error_rate import cer
|
from parakeet.utils.error_rate import char_errors
|
||||||
|
|
||||||
|
|
||||||
# delete english characters
|
# delete english characters
|
||||||
|
@ -29,31 +29,67 @@ def del_en_add_space(input: str):
|
||||||
return output
|
return output
|
||||||
|
|
||||||
|
|
||||||
def get_avg_cer(test_file, text_normalizer):
|
def get_avg_cer(raw_dict, ref_dict, text_normalizer, output_dir):
|
||||||
cer_list = []
|
edit_distances = []
|
||||||
for line in open(test_file, "r"):
|
ref_lens = []
|
||||||
line = line.strip()
|
wf_ref = open(output_dir / "text.ref.clean", "w")
|
||||||
raw_text, gt_text = line.split("|")
|
wf_tn = open(output_dir / "text.tn", "w")
|
||||||
|
for text_id in raw_dict:
|
||||||
|
if text_id not in ref_dict:
|
||||||
|
continue
|
||||||
|
raw_text = raw_dict[text_id]
|
||||||
|
gt_text = ref_dict[text_id]
|
||||||
textnorm_text = text_normalizer.normalize_sentence(raw_text)
|
textnorm_text = text_normalizer.normalize_sentence(raw_text)
|
||||||
|
|
||||||
gt_text = del_en_add_space(gt_text)
|
gt_text = del_en_add_space(gt_text)
|
||||||
textnorm_text = del_en_add_space(textnorm_text)
|
textnorm_text = del_en_add_space(textnorm_text)
|
||||||
single_cer = cer(gt_text, textnorm_text)
|
wf_ref.write(text_id + " " + gt_text + "\n")
|
||||||
cer_list.append(single_cer)
|
wf_tn.write(text_id + " " + textnorm_text + "\n")
|
||||||
return sum(cer_list) / len(cer_list)
|
edit_distance, ref_len = char_errors(gt_text, textnorm_text)
|
||||||
|
edit_distances.append(edit_distance)
|
||||||
|
ref_lens.append(ref_len)
|
||||||
|
|
||||||
|
return sum(edit_distances) / sum(ref_lens)
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
parser = argparse.ArgumentParser(description="text normalization example.")
|
parser = argparse.ArgumentParser(description="text normalization example.")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--test-file",
|
"--input-dir",
|
||||||
default=None,
|
default="data/textnorm",
|
||||||
type=str,
|
type=str,
|
||||||
help="path of text normalization test file.")
|
help="directory to preprocessed test data.")
|
||||||
|
parser.add_argument(
|
||||||
|
"--output-dir",
|
||||||
|
default="exp/textnorm",
|
||||||
|
type=str,
|
||||||
|
help="directory to save textnorm results.")
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
test_file = Path(args.test_file).expanduser()
|
input_dir = Path(args.input_dir).expanduser()
|
||||||
|
output_dir = Path(args.output_dir).expanduser()
|
||||||
|
output_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
assert input_dir.is_dir()
|
||||||
|
raw_dict, ref_dict = dict(), dict()
|
||||||
|
raw_path = input_dir / "text"
|
||||||
|
ref_path = input_dir / "text.ref"
|
||||||
|
|
||||||
|
with open(raw_path, "r") as rf:
|
||||||
|
for line in rf:
|
||||||
|
line = line.strip()
|
||||||
|
line_list = line.split(" ")
|
||||||
|
text_id, raw_text = line_list[0], " ".join(line_list[1:])
|
||||||
|
raw_dict[text_id] = raw_text
|
||||||
|
with open(ref_path, "r") as rf:
|
||||||
|
for line in rf:
|
||||||
|
line = line.strip()
|
||||||
|
line_list = line.split(" ")
|
||||||
|
text_id, normed_text = line_list[0], " ".join(line_list[1:])
|
||||||
|
ref_dict[text_id] = normed_text
|
||||||
|
|
||||||
text_normalizer = TextNormalizer()
|
text_normalizer = TextNormalizer()
|
||||||
avg_cer = get_avg_cer(test_file, text_normalizer)
|
|
||||||
|
avg_cer = get_avg_cer(raw_dict, ref_dict, text_normalizer, output_dir)
|
||||||
print("The avg CER of text normalization is:", avg_cer)
|
print("The avg CER of text normalization is:", avg_cer)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -60,7 +60,7 @@ def replace_percentage(match: re.Match) -> str:
|
||||||
|
|
||||||
|
|
||||||
# 整数表达式
|
# 整数表达式
|
||||||
# 带负号或者不带负号的整数 12, -10
|
# 带负号的整数 -10
|
||||||
RE_INTEGER = re.compile(r'(-)' r'(\d+)')
|
RE_INTEGER = re.compile(r'(-)' r'(\d+)')
|
||||||
|
|
||||||
|
|
||||||
|
@ -116,7 +116,7 @@ def replace_number(match: re.Match) -> str:
|
||||||
|
|
||||||
# 范围表达式
|
# 范围表达式
|
||||||
# 12-23, 12~23
|
# 12-23, 12~23
|
||||||
RE_RANGE = re.compile(r'(\d+)[~](\d+)')
|
RE_RANGE = re.compile(r'(\d+)[-~](\d+)')
|
||||||
|
|
||||||
|
|
||||||
def replace_range(match: re.Match) -> str:
|
def replace_range(match: re.Match) -> str:
|
||||||
|
|
|
@ -55,11 +55,11 @@ class TextNormalizer():
|
||||||
sentence = RE_DATE2.sub(replace_date2, sentence)
|
sentence = RE_DATE2.sub(replace_date2, sentence)
|
||||||
sentence = RE_TIME.sub(replace_time, sentence)
|
sentence = RE_TIME.sub(replace_time, sentence)
|
||||||
sentence = RE_TEMPERATURE.sub(replace_temperature, sentence)
|
sentence = RE_TEMPERATURE.sub(replace_temperature, sentence)
|
||||||
sentence = RE_RANGE.sub(replace_range, sentence)
|
|
||||||
sentence = RE_FRAC.sub(replace_frac, sentence)
|
sentence = RE_FRAC.sub(replace_frac, sentence)
|
||||||
sentence = RE_PERCENTAGE.sub(replace_percentage, sentence)
|
sentence = RE_PERCENTAGE.sub(replace_percentage, sentence)
|
||||||
sentence = RE_MOBILE_PHONE.sub(replace_mobile, sentence)
|
sentence = RE_MOBILE_PHONE.sub(replace_mobile, sentence)
|
||||||
sentence = RE_TELEPHONE.sub(replace_phone, sentence)
|
sentence = RE_TELEPHONE.sub(replace_phone, sentence)
|
||||||
|
sentence = RE_RANGE.sub(replace_range, sentence)
|
||||||
sentence = RE_INTEGER.sub(replace_negative_num, sentence)
|
sentence = RE_INTEGER.sub(replace_negative_num, sentence)
|
||||||
sentence = RE_DECIMAL_NUM.sub(replace_number, sentence)
|
sentence = RE_DECIMAL_NUM.sub(replace_number, sentence)
|
||||||
sentence = RE_POSITIVE_QUANTIFIERS.sub(replace_positive_quantifier,
|
sentence = RE_POSITIVE_QUANTIFIERS.sub(replace_positive_quantifier,
|
||||||
|
|
Loading…
Reference in New Issue