diff --git a/examples/fastspeech2/baker/README.md b/examples/fastspeech2/baker/README.md index 339484a..dac6a40 100644 --- a/examples/fastspeech2/baker/README.md +++ b/examples/fastspeech2/baker/README.md @@ -1,5 +1,3 @@ - - # FastSpeech2 with BZNSYP ## Dataset diff --git a/examples/fastspeech2/baker/config.py b/examples/fastspeech2/baker/config.py index cc937b9..7cf3d95 100644 --- a/examples/fastspeech2/baker/config.py +++ b/examples/fastspeech2/baker/config.py @@ -12,10 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. +from pathlib import Path + from yacs.config import CfgNode as Configuration import yaml -with open("conf/default.yaml", 'rt') as f: +config_path = (Path(__file__).parent / "conf" / "default.yaml").resolve() + +with open(config_path, 'rt') as f: _C = yaml.safe_load(f) _C = Configuration(_C) diff --git a/examples/fastspeech2/baker/frontend.py b/examples/fastspeech2/baker/frontend.py index 8d2c1f1..b69d15d 100644 --- a/examples/fastspeech2/baker/frontend.py +++ b/examples/fastspeech2/baker/frontend.py @@ -58,8 +58,17 @@ class Frontend(): # split tone from finals match = re.match(r'^(\w+)([012345])$', full_phone) if match: - phones.append(match.group(1)) - tones.append(match.group(2)) + phone = match.group(1) + tone = match.group(2) + # if the merged erhua not in the vocab + if len(phone) >= 2 and phone != "er" and phone[ + -1] == 'r' and phone not in self.vocab_phones and phone[: + -1] in self.vocab_phones: + phones.append(phone[:-1]) + phones.append("er") + else: + tones.append(tone) + tones.append("2") else: phones.append(full_phone) tones.append('0') @@ -67,7 +76,17 @@ class Frontend(): tone_ids = paddle.to_tensor(tone_ids) result["tone_ids"] = tone_ids else: - phones = phonemes + # if the merged erhua not in the vocab + phones = [] + for phone in phonemes: + if len(phone) >= 3 and phone[:-1] != "er" and phone[ + -2] == 'r' and phone not in self.vocab_phones and ( + phone[:-2] + phone[-1]) in self.vocab_phones: + phones.append((phone[:-2] + phone[-1])) + phones.append("er2") + else: + phones.append(phone) + phone_ids = self._p2id(phones) phone_ids = paddle.to_tensor(phone_ids) result["phone_ids"] = phone_ids diff --git a/examples/fastspeech2/baker/preprocess.py b/examples/fastspeech2/baker/preprocess.py index c079715..40dadb1 100644 --- a/examples/fastspeech2/baker/preprocess.py +++ b/examples/fastspeech2/baker/preprocess.py @@ -32,12 +32,12 @@ def get_phn_dur(file_name): read MFA duration.txt Parameters ---------- - file_name : str or Path - path of gen_duration_from_textgrid.py's result + file_name : str or Path + path of gen_duration_from_textgrid.py's result Returns ---------- - Dict - sentence: {'utt': ([char], [int])} + Dict + sentence: {'utt': ([char], [int])} ''' f = open(file_name, 'r') sentence = {} @@ -58,8 +58,8 @@ def deal_silence(sentence): merge silences, set Parameters ---------- - sentence : Dict - sentence: {'utt': ([char], [int])} + sentence : Dict + sentence: {'utt': ([char], [int])} ''' for utt in sentence: cur_phn, cur_dur = sentence[utt] @@ -91,10 +91,10 @@ def get_input_token(sentence, output_path): get phone set from training data and save it Parameters ---------- - sentence : Dict - sentence: {'utt': ([char], [int])} - output_path : str or path - path to save phone_id_map + sentence : Dict + sentence: {'utt': ([char], [int])} + output_path : str or path + path to save phone_id_map ''' phn_token = set() for utt in sentence: @@ -117,12 +117,12 @@ def compare_duration_and_mel_length(sentences, utt, mel): check duration error, correct sentences[utt] if possible, else pop sentences[utt] Parameters ---------- - sentences : Dict - sentences[utt] = [phones_list ,durations_list] - utt : str - utt_id - mel : np.ndarry - features (num_frames, n_mels) + sentences : Dict + sentences[utt] = [phones_list ,durations_list] + utt : str + utt_id + mel : np.ndarry + features (num_frames, n_mels) ''' if utt in sentences: @@ -267,7 +267,7 @@ def main(): type=str, help="directory to baker dataset.") parser.add_argument( - "--dur-path", + "--dur-file", default=None, type=str, help="path to baker durations.txt.") @@ -308,8 +308,13 @@ def main(): root_dir = Path(args.rootdir).expanduser() dumpdir = Path(args.dumpdir).expanduser() dumpdir.mkdir(parents=True, exist_ok=True) + dur_file = Path(args.dur_file).expanduser() + + assert root_dir.is_dir() + assert dur_file.is_file() + + sentences = get_phn_dur(dur_file) - sentences = get_phn_dur(args.dur_path) deal_silence(sentences) phone_id_map_path = dumpdir / "phone_id_map.txt" get_input_token(sentences, phone_id_map_path) diff --git a/examples/fastspeech2/baker/preprocess.sh b/examples/fastspeech2/baker/preprocess.sh index 7247cab..e149b3f 100755 --- a/examples/fastspeech2/baker/preprocess.sh +++ b/examples/fastspeech2/baker/preprocess.sh @@ -4,7 +4,7 @@ python3 gen_duration_from_textgrid.py --inputdir ./baker_alignment_tone --output durations.txt # extract features -python3 preprocess.py --rootdir=~/datasets/BZNSYP/ --dumpdir=dump --dur-path durations.txt --num-cpu 4 --cut-sil True +python3 preprocess.py --rootdir=~/datasets/BZNSYP/ --dumpdir=dump --dur-file durations.txt --num-cpu 4 --cut-sil True # # get features' stats(mean and std) python3 compute_statistics.py --metadata=dump/train/raw/metadata.jsonl --field-name="speech" diff --git a/examples/text_frontend/README.md b/examples/text_frontend/README.md new file mode 100644 index 0000000..f5d511e --- /dev/null +++ b/examples/text_frontend/README.md @@ -0,0 +1,20 @@ +# Chinese Text Frontend Example +Here's an example for Chinese text frontend, including g2p and text normalization. +## G2P +For g2p, we use BZNSYP's phone label as the ground truth and we delete silence tokens in labels and predicted phones. + +You should Download BZNSYP from it's [Official Website](https://test.data-baker.com/data/index/source) and extract it. Assume the path to the dataset is `~/datasets/BZNSYP`. + +We use `WER` as evaluation criterion. +## Text Normalization +For text normalization, the test data is `data/textnorm_test_cases.txt`, we use `|` as the separator of raw_data and normed_data. + +We use `CER` as evaluation criterion. +## Start +Run the command below to get the results of test. +```bash +./run.sh +``` +The `avg WER` of g2p is: 0.02785753389811866 + +The `avg CER` of text normalization is: 0.014229233983486172 diff --git a/examples/text_frontend/data/textnorm_test_cases.txt b/examples/text_frontend/data/textnorm_test_cases.txt new file mode 100644 index 0000000..d06c9fa --- /dev/null +++ b/examples/text_frontend/data/textnorm_test_cases.txt @@ -0,0 +1,123 @@ +今天的最低气温达到-10°C.|今天的最低气温达到零下十度. +只要有33/4的人同意,就可以通过决议。|只要有四分之三十三的人同意,就可以通过决议。 +1945年5月2日,苏联士兵在德国国会大厦上升起了胜利旗,象征着攻占柏林并战胜了纳粹德国。|一九四五年五月二日,苏联士兵在德国国会大厦上升起了胜利旗,象征着攻占柏林并战胜了纳粹德国。 +4月16日,清晨的战斗以炮击揭幕,数以千计的大炮和喀秋莎火箭炮开始炮轰德军阵地,炮击持续了数天之久。|四月十六日,清晨的战斗以炮击揭幕,数以千计的大炮和喀秋莎火箭炮开始炮轰德军阵地,炮击持续了数天之久。 +如果剩下的30.6%是过去,那么还有69.4%.|如果剩下的百分之三十点六是过去,那么还有百分之六十九点四. +事情发生在2020/03/31的上午8:00.|事情发生在二零二零年三月三十一日的上午八点. +警方正在找一支.22口径的手枪。|警方正在找一支点二二口径的手枪。 +欢迎致电中国联通,北京2022年冬奥会官方合作伙伴为您服务|欢迎致电中国联通,北京二零二二年冬奥会官方合作伙伴为您服务 +充值缴费请按1,查询话费及余量请按2,跳过本次提醒请按井号键。|充值缴费请按一,查询话费及余量请按二,跳过本次提醒请按井号键。 +快速解除流量封顶请按星号键,腾讯王卡产品介绍、使用说明、特权及活动请按9,查询话费、套餐余量、积分及活动返款请按1,手机上网流量开通及取消请按2,查询本机号码及本号所使用套餐请按4,密码修改及重置请按5,紧急开机请按6,挂失请按7,查询充值记录请按8,其它自助服务及人工服务请按0|快速解除流量封顶请按星号键,腾讯王卡产品介绍、使用说明、特权及活动请按九,查询话费、套餐余量、积分及活动返款请按一,手机上网流量开通及取消请按二,查询本机号码及本号所使用套餐请按四,密码修改及重置请按五,紧急开机请按六,挂失请按七,查询充值记录请按八,其它自助服务及人工服务请按零 +智能客服助理快速查话费、查流量请按9,了解北京联通业务请按1,宽带IPTV新装、查询请按2,障碍报修请按3,充值缴费请按4,投诉建议请按5,政企业务请按7,人工服务请按0,for english severice press star key|智能客服助理快速查话费、查流量请按九,了解北京联通业务请按一,宽带IPTV新装、查询请按二,障碍报修请按三,充值缴费请按四,投诉建议请按五,政企业务请按七,人工服务请按零,for english severice press star key +您的帐户当前可用余额为63.89元,本月消费为2.17元。您的消费、套餐余量和其它信息将以短信形式下发,请您注意查收。谢谢使用,再见!。|您的帐户当前可用余额为六十三点八九元,本月消费为二点一七元。您的消费、套餐余量和其它信息将以短信形式下发,请您注意查收。谢谢使用,再见!。 +您的帐户当前可用余额为负15.5元,本月消费为59.6元。您的消费、套餐余量和其它信息将以短信形式下发,请您注意查收。谢谢使用,再见!。|您的帐户当前可用余额为负十五点五元,本月消费为五十九点六元。您的消费、套餐余量和其它信息将以短信形式下发,请您注意查收。谢谢使用,再见!。 +尊敬的客户,您目前的话费余额为负14.60元,已低于10元,为保证您的通信畅通,请及时缴纳费用。|尊敬的客户,您目前的话费余额为负十四点六元,已低于十元,为保证您的通信畅通,请及时缴纳费用。 +您的流量已用完,为避免您产生额外费用,建议您根据需求开通一个流量包以作补充。|您的流量已用完,为避免您产生额外费用,建议您根据需求开通一个流量包以作补充。 +您可以直接说,查询话费及余量、开通流量包、缴费,您也可以说出其它需求,请问有什么可以帮您?|您可以直接说,查询话费及余量、开通流量包、缴费,您也可以说出其它需求,请问有什么可以帮您? +您的账户当前可用余额为负36.00元,本月消费36.00元。|您的账户当前可用余额为负三十六元,本月消费三十六元。 +请问你是电话13985608526的机主吗?|请问你是电话一三九八五六零八五二六的机主吗? +如您对处理结果不满意,可拨打中国联通集团投诉电话10015进行投诉,按本地通话费收费,返回自助服务请按井号键|如您对处理结果不满意,可拨打中国联通集团投诉电话一零零一五进行投诉,按本地通话费收费,返回自助服务请按井号键 +“26314”号VIP客服代表为您服务。|“二六三一四”号VIP客服代表为您服务。 +尊敬的5G用户,欢迎您致电中国联通|尊敬的五G用户,欢迎您致电中国联通 +首先是应用了M1芯片的iPad Pro,新款的iPad Pro支持5G,这也是苹果的第二款5G产品线。|首先是应用了M一芯片的iPad Pro,新款的iPad Pro支持五G,这也是苹果的第二款五G产品线。 +除此之外,摄像头方面再次升级,增加了前摄全新超广角摄像头,支持人物居中功能,搭配超广角可实现视频中始终让人物居中效果。|除此之外,摄像头方面再次升级,增加了前摄全新超广角摄像头,支持人物居中功能,搭配超广角可实现视频中始终让人物居中效果。 +屏幕方面,iPad Pro 12.9版本支持XDR体验的Mini-LEDS显示屏,支持HDR10、杜比视界,还支持杜比全景声。|屏幕方面,iPad Pro 十二点九版本支持XDR体验的Mini-LEDS显示屏,支持HDR十、杜比视界,还支持杜比全景声。 +iPad Pro的秒控键盘这次也推出白色版本。|iPad Pro的秒控键盘这次也推出白色版本。 +售价方面,11英寸版本售价799美元起,12.9英寸售价1099美元起。|售价方面,十一英寸版本售价七百九十九美元起,十二点九英寸售价一千零九十九美元起。 +这块黄金重达324.75克|这块黄金重达三百二十四点七五克 +她出生于86年8月18日,她弟弟出生于1995年3月1日|她出生于八六年八月十八日,她弟弟出生于一九九五年三月一日 +电影中梁朝伟扮演的陈永仁的编号27149|电影中梁朝伟扮演的陈永仁的编号二七一四九 +现场有7/12的观众投出了赞成票|现场有十二分之七的观众投出了赞成票 +随便来几个价格12块5,34.5元,20.1万|随便来几个价格十二块五,三十四点五元,二十点一万 +明天有62%的概率降雨|明天有百分之六十二的概率降雨 +这是固话0421-33441122|这是固话零四二一三三四四一一二二 +这是手机+86 18544139121|这是手机八六一八五四四一三九一二一 +小王的身高是153.5cm,梦想是打篮球!我觉得有0.1%的可能性。|小王的身高是一百五十三点五cm,梦想是打篮球!我觉得有百分之零点一的可能性。 +不管三七二十一|不管三七二十一 +九九八十一难|九九八十一难 +2018年5月23号上午10点10分|二零一八年五月二十三号上午十点十分 +10076|一零零七六 +32.68%|百分之三十二点六八 +比分测试17:16|比分测试十七比十六 +比分测试37:16|比分测试三十七比十六 +1.1|一点一 +一点一滴|一点一滴 +八九十|八九十 +1个人一定要|一个人一定要 +10000棵树|一万棵树 +1234个人|一千二百三十四个人 +35553座楼|三万五千五百五十三座楼 +15873690|一五八七三六九零 +27930122|二七九三零一二二 +85307499|八五三零七四九九 +26149787|二六一四九七八七 +15964862|一五九六四八六二 +45698723|四五六九八七二三 +48615964|四八六一五九六四 +17864589|一七八六四五八九 +123加456|一百二十三加四百五十六 +9786加3384|九千七百八十六加三千三百八十四 +发电站每天发电30029度电|发电站每天发电三万零二十九度电 +银行月交易总额七千九百零三亿元|银行月交易总额七千九百零三亿元 +深圳每月平均工资在13000元|深圳每月平均工资在一万三千元 +每月房租要交1500元|每月房租要交一千五百元 +我每月交通费用在400元左右|我每月交通费用在四百元左右 +本月开销费用是51328元|本月开销费用是五万一千三百二十八元 +如果你中了五千万元奖金会分我一半吗|如果你中了五千万元奖金会分我一半吗 +这个月工资我发了3529元|这个月工资我发了三千五百二十九元 +学会了这个技能你至少可以涨薪5000元|学会了这个技能你至少可以涨薪五千元 +我们的会议时间定在9点25分开始|我们的会议时间定在九点二十五分开始 +上课时间是8点15分请不要迟到|上课时间是八点十五分请不要迟到 +昨天你9点21分才到教室|昨天你九点二十一分才到教室 +今天是2019年1月31号|今天是二零一九年一月三十一号 +今年的除夕夜是2019年2月4号|今年的除夕夜是二零一九年二月四号 +这根水管的长度不超过35米|这根水管的长度不超过三十五米 +400米是最短的长跑距离|四百米是最短的长跑距离 +最高的撑杆跳为11米|最高的撑杆跳为十一米 +等会请在12:05请通知我|等会请在十二点零五分请通知我 +23点15分开始|二十三点十五分开始 +你生日那天我会送你999朵玫瑰|你生日那天我会送你九百九十九朵玫瑰 +给我1双鞋我可以跳96米远|给我一双鞋我可以跳九十六米远 +虽然我们的身高相差356毫米也不影响我们交往|虽然我们的身高相差三百五十六毫米也不影响我们交往 +我们班的最高总分为583分|我们班的最高总分为五百八十三分 +今天考试老师多扣了我21分|今天考试老师多扣了我二十一分 +我量过这张桌子总长为1.37米|我量过这张桌子总长为一点三七米 +乘务员身高必须超过185公分|乘务员身高必须超过一百八十五公分 +这台电脑分辨率为1024|这台电脑分辨率为一零二四 +手机价格不超过1500元|手机价格不超过一千五百元 +101.23|一百零一点二三 +123.116|一百二十三点一一六 +456.147|四百五十六点一四七 +0.1594|零点一五九四 +3.1415|三点一四一五 +0.112233|零点一一二二三三 +0.1|零点一 +40001.987|四万零一点九八七 +56.878|五十六点八七八 +0.00123|零点零零一二三 +0.0001|零点零零零一 +0.92015|零点九二零一五 +999.0001|九百九十九点零零零一 +10000.123|一万点一二三 +666.555|六百六十六点五五五 +444.789|四百四十四点七八九 +789.666|七百八十九点六六六 +0.12345|零点一二三四五 +1.05649|一点零五六四九 +环比上调1.86%|环比上调百分之一点八六 +环比分别下跌3.46%及微涨0.70%|环比分别下跌百分之三点四六及微涨百分之零点七 +单价在30000元的二手房购房个案当中|单价在三万元的二手房购房个案当中 +6月仍有7%单价在30000元的房源|六月仍有百分之七单价在三万元的房源 +最终也只是以总积分1分之差屈居第2|最终也只是以总积分一分之差屈居第二 +中新网8月29日电今日|中新网八月二十九日电今日 +自6月底呼和浩特市率先宣布取消限购后|自六月底呼和浩特市率先宣布取消限购后 +仅1个多月的时间里|仅一个多月的时间里 +除了北京上海广州深圳4个一线城市和三亚之外|除了北京上海广州深圳四个一线城市和三亚之外 +46个限购城市当中|四十六个限购城市当中 +41个已正式取消或变相放松了限购|四十一个已正式取消或变相放松了限购 +其中包括对拥有一套住房并已结清相应购房贷款的家庭|其中包括对拥有一套住房并已结清相应购房贷款的家庭 +这个后来被称为930新政策的措施|这个后来被称为九三零新政策的措施 +今年有望超三百亿美元|今年有望超三百亿美元 +就连一向看多的任志强|就连一向看多的任志强 +近期也一反常态地发表看空言论|近期也一反常态地发表看空言论 +985|九八五 \ No newline at end of file diff --git a/examples/text_frontend/run.sh b/examples/text_frontend/run.sh new file mode 100755 index 0000000..01b0d72 --- /dev/null +++ b/examples/text_frontend/run.sh @@ -0,0 +1,8 @@ +#!/bin/bash + +# test g2p +echo "Start test g2p." +python3 test_g2p.py --root-dir=~/datasets/BZNSYP +# test text normalization +echo "Start test text normalization." +python3 test_textnorm.py --test-file=data/textnorm_test_cases.txt \ No newline at end of file diff --git a/examples/text_frontend/test_g2p.py b/examples/text_frontend/test_g2p.py new file mode 100644 index 0000000..ae4eb0b --- /dev/null +++ b/examples/text_frontend/test_g2p.py @@ -0,0 +1,110 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import re +from collections import defaultdict +from pathlib import Path + +from parakeet.frontend.cn_frontend import Frontend as cnFrontend +from parakeet.utils.error_rate import wer +from praatio import tgio + + +def text_cleaner(raw_text): + text = re.sub('#[1-4]|“|”|(|)', '', raw_text) + text = text.replace("…。", "。") + text = re.sub(':|;|——|……|、|…|—', ',', text) + return text + + +def get_baker_data(root_dir): + alignment_files = sorted( + list((root_dir / "PhoneLabeling").rglob("*.interval"))) + text_file = root_dir / "ProsodyLabeling/000001-010000.txt" + text_file = Path(text_file).expanduser() + data_dict = defaultdict(dict) + # filter out several files that have errors in annotation + exclude = {'000611', '000662', '002365', '005107'} + alignment_files = [f for f in alignment_files if f.stem not in exclude] + # biaobei 前后有 sil ,中间没有 sp + data_dict = defaultdict(dict) + for alignment_fp in alignment_files: + alignment = tgio.openTextgrid(alignment_fp) + # only with baker's annotation + utt_id = alignment.tierNameList[0].split(".")[0] + intervals = alignment.tierDict[alignment.tierNameList[0]].entryList + phones = [] + for interval in intervals: + label = interval.label + # Baker has sp1 rather than sp + label = label.replace("sp1", "sp") + phones.append(label) + data_dict[utt_id]["phones"] = phones + for line in open(text_file, "r"): + if line.startswith("0"): + utt_id, raw_text = line.strip().split() + text = text_cleaner(raw_text) + if utt_id in data_dict: + data_dict[utt_id]['text'] = text + else: + pinyin = line.strip().split() + if utt_id in data_dict: + data_dict[utt_id]['pinyin'] = pinyin + return data_dict + + +def get_g2p_phones(data_dict, frontend): + for utt_id in data_dict: + g2p_phones = frontend.get_phonemes(data_dict[utt_id]['text']) + data_dict[utt_id]["g2p_phones"] = g2p_phones + return data_dict + + +def get_avg_wer(data_dict): + wer_list = [] + for utt_id in data_dict: + g2p_phones = data_dict[utt_id]['g2p_phones'] + # delete silence tokens in predicted phones + g2p_phones = [phn for phn in g2p_phones if phn not in {"sp", "sil"}] + gt_phones = data_dict[utt_id]['phones'] + # delete silence tokens in baker phones + gt_phones = [phn for phn in gt_phones if phn not in {"sp", "sil"}] + gt_phones = " ".join(gt_phones) + g2p_phones = " ".join(g2p_phones) + single_wer = wer(gt_phones, g2p_phones) + wer_list.append(single_wer) + return sum(wer_list) / len(wer_list) + + +def main(): + parser = argparse.ArgumentParser(description="g2p example.") + parser.add_argument( + "--root-dir", + default=None, + type=str, + help="directory to baker dataset.") + + args = parser.parse_args() + root_dir = Path(args.root_dir).expanduser() + assert root_dir.is_dir() + frontend = cnFrontend() + data_dict = get_baker_data(root_dir) + data_dict = get_g2p_phones(data_dict, frontend) + avg_wer = get_avg_wer(data_dict) + print("The avg WER of g2p is:", avg_wer) + + +if __name__ == "__main__": + main() diff --git a/examples/text_frontend/test_textnorm.py b/examples/text_frontend/test_textnorm.py new file mode 100644 index 0000000..7242a3e --- /dev/null +++ b/examples/text_frontend/test_textnorm.py @@ -0,0 +1,61 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import re +from pathlib import Path + +from parakeet.frontend.cn_normalization.text_normlization import TextNormalizer +from parakeet.utils.error_rate import cer + + +# delete english characters +# e.g. "你好aBC" -> "你 好" +def del_en_add_space(input: str): + output = re.sub('[a-zA-Z]', '', input) + output = [char + " " for char in output] + output = "".join(output).strip() + return output + + +def get_avg_cer(test_file, text_normalizer): + cer_list = [] + for line in open(test_file, "r"): + line = line.strip() + raw_text, gt_text = line.split("|") + textnorm_text = text_normalizer.normalize_sentence(raw_text) + gt_text = del_en_add_space(gt_text) + textnorm_text = del_en_add_space(textnorm_text) + single_cer = cer(gt_text, textnorm_text) + cer_list.append(single_cer) + return sum(cer_list) / len(cer_list) + + +def main(): + parser = argparse.ArgumentParser(description="text normalization example.") + parser.add_argument( + "--test-file", + default=None, + type=str, + help="path of text normalization test file.") + + args = parser.parse_args() + test_file = Path(args.test_file).expanduser() + text_normalizer = TextNormalizer() + avg_cer = get_avg_cer(test_file, text_normalizer) + print("The avg CER of text normalization is:", avg_cer) + + +if __name__ == "__main__": + main() diff --git a/parakeet/frontend/cn_frontend.py b/parakeet/frontend/cn_frontend.py index 52624e0..12b2b84 100644 --- a/parakeet/frontend/cn_frontend.py +++ b/parakeet/frontend/cn_frontend.py @@ -35,6 +35,14 @@ class Frontend(): self.g2pM_model = G2pM() self.pinyin2phone = generate_lexicon( with_tone=True, with_erhua=False) + self.must_erhua = {"小院儿", "胡同儿", "范儿", "老汉儿", "撒欢儿", "寻老礼儿", "妥妥儿"} + self.not_erhua = { + "虐儿", "为儿", "护儿", "瞒儿", "救儿", "替儿", "有儿", "一儿", "我儿", "俺儿", "妻儿", + "拐儿", "聋儿", "乞儿", "患儿", "幼儿", "孤儿", "婴儿", "婴幼儿", "连体儿", "脑瘫儿", + "流浪儿", "体弱儿", "混血儿", "蜜雪儿", "舫儿", "祖儿", "美儿", "应采儿", "可儿", "侄儿", + "孙儿", "侄孙儿", "女儿", "男儿", "红孩儿", "花儿", "虫儿", "马儿", "鸟儿", "猪儿", "猫儿", + "狗儿" + } def _get_initials_finals(self, word): initials = [] @@ -71,26 +79,31 @@ class Frontend(): return initials, finals # if merge_sentences, merge all sentences into one phone sequence - def _g2p(self, sentences, merge_sentences=True): + def _g2p(self, sentences, merge_sentences=True, with_erhua=True): segments = sentences phones_list = [] for seg in segments: phones = [] - seg = psg.lcut(seg) + seg_cut = psg.lcut(seg) initials = [] finals = [] - seg = self.tone_modifier.pre_merge_for_modify(seg) - for word, pos in seg: + seg_cut = self.tone_modifier.pre_merge_for_modify(seg_cut) + for word, pos in seg_cut: if pos == 'eng': continue sub_initials, sub_finals = self._get_initials_finals(word) + sub_finals = self.tone_modifier.modified_tone(word, pos, sub_finals) + if with_erhua: + sub_initials, sub_finals = self._merge_erhua( + sub_initials, sub_finals, word, pos) initials.append(sub_initials) finals.append(sub_finals) # assert len(sub_initials) == len(sub_finals) == len(word) initials = sum(initials, []) finals = sum(finals, []) + for c, v in zip(initials, finals): # NOTE: post process for pypinyin outputs # we discriminate i, ii and iii @@ -106,7 +119,24 @@ class Frontend(): phones_list = sum(phones_list, []) return phones_list - def get_phonemes(self, sentence): + def _merge_erhua(self, initials, finals, word, pos): + if word not in self.must_erhua and (word in self.not_erhua or + pos in {"a", "j", "nr"}): + return initials, finals + new_initials = [] + new_finals = [] + assert len(finals) == len(word) + for i, phn in enumerate(finals): + if i == len(finals) - 1 and word[i] == "儿" and phn in { + "er2", "er5" + } and word[-2:] not in self.not_erhua and new_finals: + new_finals[-1] = new_finals[-1][:-1] + "r" + new_finals[-1][-1] + else: + new_finals.append(phn) + new_initials.append(initials[i]) + return new_initials, new_finals + + def get_phonemes(self, sentence, with_erhua=True): sentences = self.text_normalizer.normalize(sentence) - phonemes = self._g2p(sentences) + phonemes = self._g2p(sentences, with_erhua=with_erhua) return phonemes diff --git a/parakeet/frontend/cn_normalization/num.py b/parakeet/frontend/cn_normalization/num.py index 459d871..66932af 100644 --- a/parakeet/frontend/cn_normalization/num.py +++ b/parakeet/frontend/cn_normalization/num.py @@ -29,6 +29,8 @@ UNITS = OrderedDict({ 8: '亿', }) +COM_QUANTIFIERS = '(匹|张|座|回|场|尾|条|个|首|阙|阵|网|炮|顶|丘|棵|只|支|袭|辆|挑|担|颗|壳|窠|曲|墙|群|腔|砣|座|客|贯|扎|捆|刀|令|打|手|罗|坡|山|岭|江|溪|钟|队|单|双|对|出|口|头|脚|板|跳|枝|件|贴|针|线|管|名|位|身|堂|课|本|页|家|户|层|丝|毫|厘|分|钱|两|斤|担|铢|石|钧|锱|忽|(千|毫|微)克|毫|厘|分|寸|尺|丈|里|寻|常|铺|程|(千|分|厘|毫|微)米|撮|勺|合|升|斗|石|盘|碗|碟|叠|桶|笼|盆|盒|杯|钟|斛|锅|簋|篮|盘|桶|罐|瓶|壶|卮|盏|箩|箱|煲|啖|袋|钵|年|月|日|季|刻|时|周|天|秒|分|旬|纪|岁|世|更|夜|春|夏|秋|冬|代|伏|辈|丸|泡|粒|颗|幢|堆|条|根|支|道|面|片|张|颗|块|元|(亿|千万|百万|万|千|百)|(亿|千万|百万|万|千|百|美|)元|(亿|千万|百万|万|千|百|)块|角|毛|分)' + # 分数表达式 RE_FRAC = re.compile(r'(-?)(\d+)/(\d+)') @@ -59,7 +61,17 @@ def replace_percentage(match: re.Match) -> str: # 整数表达式 # 带负号或者不带负号的整数 12, -10 -RE_INTEGER = re.compile(r'(-?)' r'(\d+)') +RE_INTEGER = re.compile(r'(-)' r'(\d+)') + + +def replace_negative_num(match: re.Match) -> str: + sign = match.group(1) + number = match.group(2) + sign: str = "负" if sign else "" + number: str = num2str(number) + result = f"{sign}{number}" + return result + # 编号-无符号整形 # 00078 @@ -72,12 +84,23 @@ def replace_default_num(match: re.Match): # 数字表达式 -# 1. 整数: -10, 10; -# 2. 浮点数: 10.2, -0.3 -# 3. 不带符号和整数部分的纯浮点数: .22, .38 +# 纯小数 +RE_DECIMAL_NUM = re.compile(r'(-?)((\d+)(\.\d+))' r'|(\.(\d+))') +# 正整数 + 量词 +RE_POSITIVE_QUANTIFIERS = re.compile(r"(\d+)([多余几])?" + COM_QUANTIFIERS) RE_NUMBER = re.compile(r'(-?)((\d+)(\.\d+)?)' r'|(\.(\d+))') +def replace_positive_quantifier(match: re.Match) -> str: + number = match.group(1) + match_2 = match.group(2) + match_2: str = match_2 if match_2 else "" + quantifiers: str = match.group(3) + number: str = num2str(number) + result = f"{number}{match_2}{quantifiers}" + return result + + def replace_number(match: re.Match) -> str: sign = match.group(1) number = match.group(2) @@ -93,7 +116,7 @@ def replace_number(match: re.Match) -> str: # 范围表达式 # 12-23, 12~23 -RE_RANGE = re.compile(r'(\d+)[-~](\d+)') +RE_RANGE = re.compile(r'(\d+)[~](\d+)') def replace_range(match: re.Match) -> str: diff --git a/parakeet/frontend/cn_normalization/phonecode.py b/parakeet/frontend/cn_normalization/phonecode.py index 7539555..354e463 100644 --- a/parakeet/frontend/cn_normalization/phonecode.py +++ b/parakeet/frontend/cn_normalization/phonecode.py @@ -25,7 +25,7 @@ from .num import verbalize_digit RE_MOBILE_PHONE = re.compile( r"(? str: @@ -44,4 +44,8 @@ def phone2str(phone_string: str, mobile=True) -> str: def replace_phone(match: re.Match) -> str: + return phone2str(match.group(0), mobile=False) + + +def replace_mobile(match: re.Match) -> str: return phone2str(match.group(0)) diff --git a/parakeet/frontend/cn_normalization/text_normlization.py b/parakeet/frontend/cn_normalization/text_normlization.py index 56583b3..fbae106 100644 --- a/parakeet/frontend/cn_normalization/text_normlization.py +++ b/parakeet/frontend/cn_normalization/text_normlization.py @@ -12,16 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. -import opencc import re from typing import List from .chronology import RE_TIME, RE_DATE, RE_DATE2 from .chronology import replace_time, replace_date, replace_date2 from .constants import F2H_ASCII_LETTERS, F2H_DIGITS, F2H_SPACE -from .num import RE_NUMBER, RE_FRAC, RE_PERCENTAGE, RE_RANGE, RE_INTEGER, RE_DEFAULT_NUM -from .num import replace_number, replace_frac, replace_percentage, replace_range, replace_default_num -from .phonecode import RE_MOBILE_PHONE, RE_TELEPHONE, replace_phone +from .num import RE_NUMBER, RE_FRAC, RE_PERCENTAGE, RE_RANGE, RE_INTEGER, RE_DEFAULT_NUM, RE_DECIMAL_NUM, RE_POSITIVE_QUANTIFIERS +from .num import replace_number, replace_frac, replace_percentage, replace_range, replace_default_num, replace_negative_num, replace_positive_quantifier +from .phonecode import RE_MOBILE_PHONE, RE_TELEPHONE, replace_phone, replace_mobile from .quantifier import RE_TEMPERATURE from .quantifier import replace_temperature @@ -29,8 +28,6 @@ from .quantifier import replace_temperature class TextNormalizer(): def __init__(self): self.SENTENCE_SPLITOR = re.compile(r'([:,;。?!,;?!][”’]?)') - self._t2s_converter = opencc.OpenCC("t2s.json") - self._s2t_converter = opencc.OpenCC('s2t.json') def _split(self, text: str) -> List[str]: """Split long text into sentences with sentence-splitting punctuations. @@ -48,15 +45,8 @@ class TextNormalizer(): sentences = [sentence.strip() for sentence in re.split(r'\n+', text)] return sentences - def _tranditional_to_simplified(self, text: str) -> str: - return self._t2s_converter.convert(text) - - def _simplified_to_traditional(self, text: str) -> str: - return self._s2t_converter.convert(text) - def normalize_sentence(self, sentence): # basic character conversions - sentence = self._tranditional_to_simplified(sentence) sentence = sentence.translate(F2H_ASCII_LETTERS).translate( F2H_DIGITS).translate(F2H_SPACE) @@ -68,8 +58,12 @@ class TextNormalizer(): sentence = RE_RANGE.sub(replace_range, sentence) sentence = RE_FRAC.sub(replace_frac, sentence) sentence = RE_PERCENTAGE.sub(replace_percentage, sentence) - sentence = RE_MOBILE_PHONE.sub(replace_phone, sentence) + sentence = RE_MOBILE_PHONE.sub(replace_mobile, sentence) sentence = RE_TELEPHONE.sub(replace_phone, sentence) + sentence = RE_INTEGER.sub(replace_negative_num, sentence) + sentence = RE_DECIMAL_NUM.sub(replace_number, sentence) + sentence = RE_POSITIVE_QUANTIFIERS.sub(replace_positive_quantifier, + sentence) sentence = RE_DEFAULT_NUM.sub(replace_default_num, sentence) sentence = RE_NUMBER.sub(replace_number, sentence) diff --git a/parakeet/frontend/tone_sandhi.py b/parakeet/frontend/tone_sandhi.py index a03989c..9dc3917 100644 --- a/parakeet/frontend/tone_sandhi.py +++ b/parakeet/frontend/tone_sandhi.py @@ -56,102 +56,133 @@ class ToneSandhi(): '凑合', '凉快', '冷战', '冤枉', '冒失', '养活', '关系', '先生', '兄弟', '便宜', '使唤', '佩服', '作坊', '体面', '位置', '似的', '伙计', '休息', '什么', '人家', '亲戚', '亲家', '交情', '云彩', '事情', '买卖', '主意', '丫头', '丧气', '两口', '东西', '东家', '世故', - '不由', '不在', '下水', '下巴', '上头', '上司', '丈夫', '丈人', '一辈', '那个' + '不由', '不在', '下水', '下巴', '上头', '上司', '丈夫', '丈人', '一辈', '那个', '菩萨', + '父亲', '母亲', '咕噜', '邋遢', '费用', '冤家', '甜头', '介绍', '荒唐', '大人', '泥鳅', + '幸福', '熟悉', '计划', '扑腾', '蜡烛', '姥爷', '照顾', '喉咙', '吉他', '弄堂', '蚂蚱', + '凤凰', '拖沓', '寒碜', '糟蹋', '倒腾', '报复', '逻辑', '盘缠', '喽啰', '牢骚', '咖喱', + '扫把', '惦记' + } + self.must_not_neural_tone_words = { + "男子", "女子", "分子", "原子", "量子", "莲子", "石子", "瓜子", "电子" } # the meaning of jieba pos tag: https://blog.csdn.net/weixin_44174352/article/details/113731041 # e.g. # word: "家里" - # pos: "s" + # pos: "s" # finals: ['ia1', 'i3'] def _neural_sandhi(self, word: str, pos: str, finals: List[str]) -> List[str]: + + # reduplication words for n. and v. e.g. 奶奶, 试试, 旺旺 + for j, item in enumerate(word): + if j - 1 >= 0 and item == word[j - 1] and pos[ + 0] in {"n", "v", "a"}: + finals[j] = finals[j][:-1] + "5" ge_idx = word.find("个") - if len(word) == 1 and word in "吧呢啊嘛" and pos == 'y': + if len(word) >= 1 and word[-1] in "吧呢哈啊呐噻嘛吖嗨呐哦哒额滴哩哟喽啰耶喔诶": finals[-1] = finals[-1][:-1] + "5" - elif len(word) == 1 and word in "的地得" and pos in {"ud", "uj", "uv"}: + elif len(word) >= 1 and word[-1] in "的地得": finals[-1] = finals[-1][:-1] + "5" # e.g. 走了, 看着, 去过 elif len(word) == 1 and word in "了着过" and pos in {"ul", "uz", "ug"}: finals[-1] = finals[-1][:-1] + "5" - elif len(word) > 1 and word[-1] in "们子" and pos in {"r", "n"}: + elif len(word) > 1 and word[-1] in "们子" and pos in { + "r", "n" + } and word not in self.must_not_neural_tone_words: finals[-1] = finals[-1][:-1] + "5" # e.g. 桌上, 地下, 家里 elif len(word) > 1 and word[-1] in "上下里" and pos in {"s", "l", "f"}: finals[-1] = finals[-1][:-1] + "5" # e.g. 上来, 下去 - elif len(word) > 1 and word[-1] in "来去" and pos[0] in {"v"}: + elif len(word) > 1 and word[-1] in "来去" and word[-2] in "上下进出回过起开": finals[-1] = finals[-1][:-1] + "5" # 个做量词 - elif ge_idx >= 1 and word[ge_idx - 1].isnumeric(): + elif (ge_idx >= 1 and + (word[ge_idx - 1].isnumeric() or + word[ge_idx - 1] in "几有两半多各整每做是")) or word == '个': finals[ge_idx] = finals[ge_idx][:-1] + "5" - # reduplication words for n. and v. e.g. 奶奶, 试试 - elif len(word) >= 2 and word[-1] == word[-2] and pos[0] in {"n", "v"}: - finals[-1] = finals[-1][:-1] + "5" - # conventional tone5 in Chinese - elif word in self.must_neural_tone_words or word[ - -2:] in self.must_neural_tone_words: - finals[-1] = finals[-1][:-1] + "5" + else: + if word in self.must_neural_tone_words or word[ + -2:] in self.must_neural_tone_words: + finals[-1] = finals[-1][:-1] + "5" + + word_list = self._split_word(word) + finals_list = [finals[:len(word_list[0])], finals[len(word_list[0]):]] + for i, word in enumerate(word_list): + # conventional neural in Chinese + if word in self.must_neural_tone_words or word[ + -2:] in self.must_neural_tone_words: + finals_list[i][-1] = finals_list[i][-1][:-1] + "5" + finals = sum(finals_list, []) return finals def _bu_sandhi(self, word: str, finals: List[str]) -> List[str]: - # "不" before tone4 should be bu2, e.g. 不怕 - if len(word) > 1 and word[0] == "不" and finals[1][-1] == "4": - finals[0] = finals[0][:-1] + "2" # e.g. 看不懂 - elif len(word) == 3 and word[1] == "不": + if len(word) == 3 and word[1] == "不": finals[1] = finals[1][:-1] + "5" - + else: + for i, char in enumerate(word): + # "不" before tone4 should be bu2, e.g. 不怕 + if char == "不" and i + 1 < len(word) and finals[i + 1][ + -1] == "4": + finals[i] = finals[i][:-1] + "2" return finals def _yi_sandhi(self, word: str, finals: List[str]) -> List[str]: - # "一" in number sequences, e.g. 一零零 - if len(word) > 1 and word[0] == "一" and all( - [item.isnumeric() for item in word]): + # "一" in number sequences, e.g. 一零零, 二一零 + if word.find("一") != -1 and all( + [item.isnumeric() for item in word if item != "一"]): return finals - # "一" before tone4 should be yi2, e.g. 一段 - elif len(word) > 1 and word[0] == "一" and finals[1][-1] == "4": - finals[0] = finals[0][:-1] + "2" - # "一" before non-tone4 should be yi4, e.g. 一天 - elif len(word) > 1 and word[0] == "一" and finals[1][-1] != "4": - finals[0] = finals[0][:-1] + "4" # "一" between reduplication words shold be yi5, e.g. 看一看 elif len(word) == 3 and word[1] == "一" and word[0] == word[-1]: finals[1] = finals[1][:-1] + "5" # when "一" is ordinal word, it should be yi1 elif word.startswith("第一"): finals[1] = finals[1][:-1] + "1" + else: + for i, char in enumerate(word): + if char == "一" and i + 1 < len(word): + # "一" before tone4 should be yi2, e.g. 一段 + if finals[i + 1][-1] == "4": + finals[i] = finals[i][:-1] + "2" + # "一" before non-tone4 should be yi4, e.g. 一天 + else: + finals[i] = finals[i][:-1] + "4" return finals + def _split_word(self, word): + word_list = jieba.cut_for_search(word) + word_list = sorted(word_list, key=lambda i: len(i), reverse=False) + new_word_list = [] + + first_subword = word_list[0] + first_begin_idx = word.find(first_subword) + if first_begin_idx == 0: + second_subword = word[len(first_subword):] + new_word_list = [first_subword, second_subword] + else: + second_subword = word[:-len(first_subword)] + new_word_list = [second_subword, first_subword] + return new_word_list + def _three_sandhi(self, word: str, finals: List[str]) -> List[str]: if len(word) == 2 and self._all_tone_three(finals): finals[0] = finals[0][:-1] + "2" elif len(word) == 3: - word_list = jieba.cut_for_search(word) - word_list = sorted(word_list, key=lambda i: len(i), reverse=False) - new_word_list = [] - first_subword = word_list[0] - first_begin_idx = word.find(first_subword) - if first_begin_idx == 0: - second_subword = word[len(first_subword):] - new_word_list = [first_subword, second_subword] - else: - second_subword = word[:-len(first_subword)] - - new_word_list = [second_subword, first_subword] + word_list = self._split_word(word) if self._all_tone_three(finals): # disyllabic + monosyllabic, e.g. 蒙古/包 - if len(new_word_list[0]) == 2: + if len(word_list[0]) == 2: finals[0] = finals[0][:-1] + "2" finals[1] = finals[1][:-1] + "2" # monosyllabic + disyllabic, e.g. 纸/老虎 - elif len(new_word_list[0]) == 1: + elif len(word_list[0]) == 1: finals[1] = finals[1][:-1] + "2" else: finals_list = [ - finals[:len(new_word_list[0])], - finals[len(new_word_list[0]):] + finals[:len(word_list[0])], finals[len(word_list[0]):] ] if len(finals_list) == 2: for i, sub in enumerate(finals_list): @@ -192,11 +223,10 @@ class ToneSandhi(): if last_word == "不": new_seg.append((last_word, 'd')) last_word = "" - seg = new_seg - return seg + return new_seg # function 1: merge "一" and reduplication words in it's left and right, e.g. "听","一","听" ->"听一听" - # function 2: merge single "一" and the word behind it + # function 2: merge single "一" and the word behind it # if don't merge, "一" sometimes appears alone according to jieba, which may occur sandhi error # e.g. # input seg: [('听', 'v'), ('一', 'm'), ('听', 'v')] @@ -222,9 +252,9 @@ class ToneSandhi(): new_seg[-1][0] = new_seg[-1][0] + word else: new_seg.append([word, pos]) - seg = new_seg - return seg + return new_seg + # the first and the second words are all_tone_three def _merge_continuous_three_tones( self, seg: List[Tuple[str, str]]) -> List[Tuple[str, str]]: new_seg = [] @@ -239,21 +269,73 @@ class ToneSandhi(): if i - 1 >= 0 and self._all_tone_three(sub_finals_list[ i - 1]) and self._all_tone_three(sub_finals_list[ i]) and not merge_last[i - 1]: - if len(seg[i - 1][0]) + len(seg[i][0]) <= 3: + # if the last word is reduplication, not merge, because reduplication need to be _neural_sandhi + if not self._is_reduplication(seg[i - 1][0]) and len(seg[ + i - 1][0]) + len(seg[i][0]) <= 3: new_seg[-1][0] = new_seg[-1][0] + seg[i][0] merge_last[i] = True else: new_seg.append([word, pos]) else: new_seg.append([word, pos]) - seg = new_seg - return seg + + return new_seg + + def _is_reduplication(self, word): + return len(word) == 2 and word[0] == word[1] + + # the last char of first word and the first char of second word is tone_three + def _merge_continuous_three_tones_2( + self, seg: List[Tuple[str, str]]) -> List[Tuple[str, str]]: + new_seg = [] + sub_finals_list = [ + lazy_pinyin( + word, neutral_tone_with_five=True, style=Style.FINALS_TONE3) + for (word, pos) in seg + ] + assert len(sub_finals_list) == len(seg) + merge_last = [False] * len(seg) + for i, (word, pos) in enumerate(seg): + if i - 1 >= 0 and sub_finals_list[i - 1][-1][-1] == "3" and sub_finals_list[i][0][-1] == "3" and not \ + merge_last[i - 1]: + # if the last word is reduplication, not merge, because reduplication need to be _neural_sandhi + if not self._is_reduplication(seg[i - 1][0]) and len(seg[ + i - 1][0]) + len(seg[i][0]) <= 3: + new_seg[-1][0] = new_seg[-1][0] + seg[i][0] + merge_last[i] = True + else: + new_seg.append([word, pos]) + else: + new_seg.append([word, pos]) + return new_seg + + def _merge_er(self, seg: List[Tuple[str, str]]) -> List[Tuple[str, str]]: + new_seg = [] + for i, (word, pos) in enumerate(seg): + if i - 1 >= 0 and word == "儿": + new_seg[-1][0] = new_seg[-1][0] + seg[i][0] + else: + new_seg.append([word, pos]) + return new_seg + + def _merge_reduplication( + self, seg: List[Tuple[str, str]]) -> List[Tuple[str, str]]: + new_seg = [] + for i, (word, pos) in enumerate(seg): + if new_seg and word == new_seg[-1][0]: + new_seg[-1][0] = new_seg[-1][0] + seg[i][0] + else: + new_seg.append([word, pos]) + return new_seg def pre_merge_for_modify( self, seg: List[Tuple[str, str]]) -> List[Tuple[str, str]]: seg = self._merge_bu(seg) seg = self._merge_yi(seg) + seg = self._merge_reduplication(seg) seg = self._merge_continuous_three_tones(seg) + seg = self._merge_continuous_three_tones_2(seg) + seg = self._merge_er(seg) return seg def modified_tone(self, word: str, pos: str, diff --git a/parakeet/utils/error_rate.py b/parakeet/utils/error_rate.py new file mode 100644 index 0000000..7a9fe5a --- /dev/null +++ b/parakeet/utils/error_rate.py @@ -0,0 +1,239 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""This module provides functions to calculate error rate in different level. +e.g. wer for word-level, cer for char-level. +""" +import numpy as np + +__all__ = ['word_errors', 'char_errors', 'wer', 'cer'] + + +def _levenshtein_distance(ref, hyp): + """Levenshtein distance is a string metric for measuring the difference + between two sequences. Informally, the levenshtein disctance is defined as + the minimum number of single-character edits (substitutions, insertions or + deletions) required to change one word into the other. We can naturally + extend the edits to word level when calculate levenshtein disctance for + two sentences. + """ + m = len(ref) + n = len(hyp) + + # special case + if ref == hyp: + return 0 + if m == 0: + return n + if n == 0: + return m + + if m < n: + ref, hyp = hyp, ref + m, n = n, m + + # use O(min(m, n)) space + distance = np.zeros((2, n + 1), dtype=np.int32) + + # initialize distance matrix + for j in range(n + 1): + distance[0][j] = j + + # calculate levenshtein distance + for i in range(1, m + 1): + prev_row_idx = (i - 1) % 2 + cur_row_idx = i % 2 + distance[cur_row_idx][0] = i + for j in range(1, n + 1): + if ref[i - 1] == hyp[j - 1]: + distance[cur_row_idx][j] = distance[prev_row_idx][j - 1] + else: + s_num = distance[prev_row_idx][j - 1] + 1 + i_num = distance[cur_row_idx][j - 1] + 1 + d_num = distance[prev_row_idx][j] + 1 + distance[cur_row_idx][j] = min(s_num, i_num, d_num) + + return distance[m % 2][n] + + +def word_errors(reference, hypothesis, ignore_case=False, delimiter=' '): + """Compute the levenshtein distance between reference sequence and + hypothesis sequence in word-level. + + Parameters + ---------- + reference : str + The reference sentence. + hypothesis : str + The hypothesis sentence. + ignore_case : bool + Whether case-sensitive or not. + delimiter : char(str) + Delimiter of input sentences. + + Returns + ---------- + list + Levenshtein distance and word number of reference sentence. + """ + if ignore_case: + reference = reference.lower() + hypothesis = hypothesis.lower() + + ref_words = list(filter(None, reference.split(delimiter))) + hyp_words = list(filter(None, hypothesis.split(delimiter))) + + edit_distance = _levenshtein_distance(ref_words, hyp_words) + return float(edit_distance), len(ref_words) + + +def char_errors(reference, hypothesis, ignore_case=False, remove_space=False): + """Compute the levenshtein distance between reference sequence and + hypothesis sequence in char-level. + + Parameters + ---------- + reference: str + The reference sentence. + hypothesis: str + The hypothesis sentence. + ignore_case: bool + Whether case-sensitive or not. + remove_space: bool + Whether remove internal space characters + + Returns + ---------- + list + Levenshtein distance and length of reference sentence. + """ + if ignore_case: + reference = reference.lower() + hypothesis = hypothesis.lower() + + join_char = ' ' + if remove_space: + join_char = '' + + reference = join_char.join(list(filter(None, reference.split(' ')))) + hypothesis = join_char.join(list(filter(None, hypothesis.split(' ')))) + + edit_distance = _levenshtein_distance(reference, hypothesis) + return float(edit_distance), len(reference) + + +def wer(reference, hypothesis, ignore_case=False, delimiter=' '): + """Calculate word error rate (WER). WER compares reference text and + hypothesis text in word-level. WER is defined as: + .. math:: + WER = (Sw + Dw + Iw) / Nw + where + .. code-block:: text + Sw is the number of words subsituted, + Dw is the number of words deleted, + Iw is the number of words inserted, + Nw is the number of words in the reference + We can use levenshtein distance to calculate WER. Please draw an attention + that empty items will be removed when splitting sentences by delimiter. + + Parameters + ---------- + reference: str + The reference sentence. + + hypothesis: str + The hypothesis sentence. + ignore_case: bool + Whether case-sensitive or not. + delimiter: char + Delimiter of input sentences. + + Returns + ---------- + float + Word error rate. + + Raises + ---------- + ValueError + If word number of reference is zero. + """ + edit_distance, ref_len = word_errors(reference, hypothesis, ignore_case, + delimiter) + + if ref_len == 0: + raise ValueError("Reference's word number should be greater than 0.") + + wer = float(edit_distance) / ref_len + return wer + + +def cer(reference, hypothesis, ignore_case=False, remove_space=False): + """Calculate charactor error rate (CER). CER compares reference text and + hypothesis text in char-level. CER is defined as: + .. math:: + CER = (Sc + Dc + Ic) / Nc + where + .. code-block:: text + Sc is the number of characters substituted, + Dc is the number of characters deleted, + Ic is the number of characters inserted + Nc is the number of characters in the reference + We can use levenshtein distance to calculate CER. Chinese input should be + encoded to unicode. Please draw an attention that the leading and tailing + space characters will be truncated and multiple consecutive space + characters in a sentence will be replaced by one space character. + + Parameters + ---------- + reference: str + The reference sentence. + hypothesis: str + The hypothesis sentence. + ignore_case: bool + Whether case-sensitive or not. + remove_space: bool + Whether remove internal space characters + + Returns + ---------- + float + Character error rate. + + Raises + ---------- + ValueError + If the reference length is zero. + """ + edit_distance, ref_len = char_errors(reference, hypothesis, ignore_case, + remove_space) + + if ref_len == 0: + raise ValueError("Length of reference should be greater than 0.") + + cer = float(edit_distance) / ref_len + return cer + + +if __name__ == "__main__": + reference = [ + 'j', 'iou4', 'zh', 'e4', 'iang5', 'x', 'v2', 'b', 'o1', 'k', 'ai1', + 'sh', 'iii3', 'l', 'e5', 'b', 'ei3', 'p', 'iao1', 'sh', 'eng1', 'ia2' + ] + hypothesis = [ + 'j', 'iou4', 'zh', 'e4', 'iang4', 'x', 'v2', 'b', 'o1', 'k', 'ai1', + 'sh', 'iii3', 'l', 'e5', 'b', 'ei3', 'p', 'iao1', 'sh', 'eng1', 'ia2' + ] + reference = " ".join(reference) + hypothesis = " ".join(hypothesis) + print(wer(reference, hypothesis)) diff --git a/setup.py b/setup.py index b7cb4da..8b48a8d 100644 --- a/setup.py +++ b/setup.py @@ -71,10 +71,13 @@ setup_info = dict( 'pypinyin', 'webrtcvad', 'g2pM', - 'praatio', + 'praatio~=4.1', "h5py", "timer", 'jsonlines', + 'pyworld', + 'typeguard', + 'jieba', ], extras_require={'doc': ["sphinx", "sphinx-rtd-theme", "numpydoc"], },