add text frontend example
This commit is contained in:
parent
3ac2e01263
commit
309228ddbf
|
@ -1,5 +1,3 @@
|
||||||
|
|
||||||
|
|
||||||
# FastSpeech2 with BZNSYP
|
# FastSpeech2 with BZNSYP
|
||||||
|
|
||||||
## Dataset
|
## Dataset
|
||||||
|
|
|
@ -12,10 +12,14 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
from yacs.config import CfgNode as Configuration
|
from yacs.config import CfgNode as Configuration
|
||||||
import yaml
|
import yaml
|
||||||
|
|
||||||
with open("conf/default.yaml", 'rt') as f:
|
config_path = (Path(__file__).parent / "conf" / "default.yaml").resolve()
|
||||||
|
|
||||||
|
with open(config_path, 'rt') as f:
|
||||||
_C = yaml.safe_load(f)
|
_C = yaml.safe_load(f)
|
||||||
_C = Configuration(_C)
|
_C = Configuration(_C)
|
||||||
|
|
||||||
|
|
|
@ -58,8 +58,17 @@ class Frontend():
|
||||||
# split tone from finals
|
# split tone from finals
|
||||||
match = re.match(r'^(\w+)([012345])$', full_phone)
|
match = re.match(r'^(\w+)([012345])$', full_phone)
|
||||||
if match:
|
if match:
|
||||||
phones.append(match.group(1))
|
phone = match.group(1)
|
||||||
tones.append(match.group(2))
|
tone = match.group(2)
|
||||||
|
# if the merged erhua not in the vocab
|
||||||
|
if len(phone) >= 2 and phone != "er" and phone[
|
||||||
|
-1] == 'r' and phone not in self.vocab_phones and phone[:
|
||||||
|
-1] in self.vocab_phones:
|
||||||
|
phones.append(phone[:-1])
|
||||||
|
phones.append("er")
|
||||||
|
else:
|
||||||
|
tones.append(tone)
|
||||||
|
tones.append("2")
|
||||||
else:
|
else:
|
||||||
phones.append(full_phone)
|
phones.append(full_phone)
|
||||||
tones.append('0')
|
tones.append('0')
|
||||||
|
@ -67,7 +76,17 @@ class Frontend():
|
||||||
tone_ids = paddle.to_tensor(tone_ids)
|
tone_ids = paddle.to_tensor(tone_ids)
|
||||||
result["tone_ids"] = tone_ids
|
result["tone_ids"] = tone_ids
|
||||||
else:
|
else:
|
||||||
phones = phonemes
|
# if the merged erhua not in the vocab
|
||||||
|
phones = []
|
||||||
|
for phone in phonemes:
|
||||||
|
if len(phone) >= 3 and phone[:-1] != "er" and phone[
|
||||||
|
-2] == 'r' and phone not in self.vocab_phones and (
|
||||||
|
phone[:-2] + phone[-1]) in self.vocab_phones:
|
||||||
|
phones.append((phone[:-2] + phone[-1]))
|
||||||
|
phones.append("er2")
|
||||||
|
else:
|
||||||
|
phones.append(phone)
|
||||||
|
|
||||||
phone_ids = self._p2id(phones)
|
phone_ids = self._p2id(phones)
|
||||||
phone_ids = paddle.to_tensor(phone_ids)
|
phone_ids = paddle.to_tensor(phone_ids)
|
||||||
result["phone_ids"] = phone_ids
|
result["phone_ids"] = phone_ids
|
||||||
|
|
|
@ -267,7 +267,7 @@ def main():
|
||||||
type=str,
|
type=str,
|
||||||
help="directory to baker dataset.")
|
help="directory to baker dataset.")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--dur-path",
|
"--dur-file",
|
||||||
default=None,
|
default=None,
|
||||||
type=str,
|
type=str,
|
||||||
help="path to baker durations.txt.")
|
help="path to baker durations.txt.")
|
||||||
|
@ -308,8 +308,13 @@ def main():
|
||||||
root_dir = Path(args.rootdir).expanduser()
|
root_dir = Path(args.rootdir).expanduser()
|
||||||
dumpdir = Path(args.dumpdir).expanduser()
|
dumpdir = Path(args.dumpdir).expanduser()
|
||||||
dumpdir.mkdir(parents=True, exist_ok=True)
|
dumpdir.mkdir(parents=True, exist_ok=True)
|
||||||
|
dur_file = Path(args.dur_file).expanduser()
|
||||||
|
|
||||||
|
assert root_dir.is_dir()
|
||||||
|
assert dur_file.is_file()
|
||||||
|
|
||||||
|
sentences = get_phn_dur(dur_file)
|
||||||
|
|
||||||
sentences = get_phn_dur(args.dur_path)
|
|
||||||
deal_silence(sentences)
|
deal_silence(sentences)
|
||||||
phone_id_map_path = dumpdir / "phone_id_map.txt"
|
phone_id_map_path = dumpdir / "phone_id_map.txt"
|
||||||
get_input_token(sentences, phone_id_map_path)
|
get_input_token(sentences, phone_id_map_path)
|
||||||
|
|
|
@ -4,7 +4,7 @@
|
||||||
python3 gen_duration_from_textgrid.py --inputdir ./baker_alignment_tone --output durations.txt
|
python3 gen_duration_from_textgrid.py --inputdir ./baker_alignment_tone --output durations.txt
|
||||||
|
|
||||||
# extract features
|
# extract features
|
||||||
python3 preprocess.py --rootdir=~/datasets/BZNSYP/ --dumpdir=dump --dur-path durations.txt --num-cpu 4 --cut-sil True
|
python3 preprocess.py --rootdir=~/datasets/BZNSYP/ --dumpdir=dump --dur-file durations.txt --num-cpu 4 --cut-sil True
|
||||||
|
|
||||||
# # get features' stats(mean and std)
|
# # get features' stats(mean and std)
|
||||||
python3 compute_statistics.py --metadata=dump/train/raw/metadata.jsonl --field-name="speech"
|
python3 compute_statistics.py --metadata=dump/train/raw/metadata.jsonl --field-name="speech"
|
||||||
|
|
|
@ -0,0 +1,20 @@
|
||||||
|
# Chinese Text Frontend Example
|
||||||
|
Here's an example for Chinese text frontend, including g2p and text normalization.
|
||||||
|
## G2P
|
||||||
|
For g2p, we use BZNSYP's phone label as the ground truth and we delete silence tokens in labels and predicted phones.
|
||||||
|
|
||||||
|
You should Download BZNSYP from it's [Official Website](https://test.data-baker.com/data/index/source) and extract it. Assume the path to the dataset is `~/datasets/BZNSYP`.
|
||||||
|
|
||||||
|
We use `WER` as evaluation criterion.
|
||||||
|
## Text Normalization
|
||||||
|
For text normalization, the test data is `data/textnorm_test_cases.txt`, we use `|` as the separator of raw_data and normed_data.
|
||||||
|
|
||||||
|
We use `CER` as evaluation criterion.
|
||||||
|
## Start
|
||||||
|
Run the command below to get the results of test.
|
||||||
|
```bash
|
||||||
|
./run.sh
|
||||||
|
```
|
||||||
|
The `avg WER` of g2p is: 0.02785753389811866
|
||||||
|
|
||||||
|
The `avg CER` of text normalization is: 0.014229233983486172
|
|
@ -0,0 +1,123 @@
|
||||||
|
今天的最低气温达到-10°C.|今天的最低气温达到零下十度.
|
||||||
|
只要有33/4的人同意,就可以通过决议。|只要有四分之三十三的人同意,就可以通过决议。
|
||||||
|
1945年5月2日,苏联士兵在德国国会大厦上升起了胜利旗,象征着攻占柏林并战胜了纳粹德国。|一九四五年五月二日,苏联士兵在德国国会大厦上升起了胜利旗,象征着攻占柏林并战胜了纳粹德国。
|
||||||
|
4月16日,清晨的战斗以炮击揭幕,数以千计的大炮和喀秋莎火箭炮开始炮轰德军阵地,炮击持续了数天之久。|四月十六日,清晨的战斗以炮击揭幕,数以千计的大炮和喀秋莎火箭炮开始炮轰德军阵地,炮击持续了数天之久。
|
||||||
|
如果剩下的30.6%是过去,那么还有69.4%.|如果剩下的百分之三十点六是过去,那么还有百分之六十九点四.
|
||||||
|
事情发生在2020/03/31的上午8:00.|事情发生在二零二零年三月三十一日的上午八点.
|
||||||
|
警方正在找一支.22口径的手枪。|警方正在找一支点二二口径的手枪。
|
||||||
|
欢迎致电中国联通,北京2022年冬奥会官方合作伙伴为您服务|欢迎致电中国联通,北京二零二二年冬奥会官方合作伙伴为您服务
|
||||||
|
充值缴费请按1,查询话费及余量请按2,跳过本次提醒请按井号键。|充值缴费请按一,查询话费及余量请按二,跳过本次提醒请按井号键。
|
||||||
|
快速解除流量封顶请按星号键,腾讯王卡产品介绍、使用说明、特权及活动请按9,查询话费、套餐余量、积分及活动返款请按1,手机上网流量开通及取消请按2,查询本机号码及本号所使用套餐请按4,密码修改及重置请按5,紧急开机请按6,挂失请按7,查询充值记录请按8,其它自助服务及人工服务请按0|快速解除流量封顶请按星号键,腾讯王卡产品介绍、使用说明、特权及活动请按九,查询话费、套餐余量、积分及活动返款请按一,手机上网流量开通及取消请按二,查询本机号码及本号所使用套餐请按四,密码修改及重置请按五,紧急开机请按六,挂失请按七,查询充值记录请按八,其它自助服务及人工服务请按零
|
||||||
|
智能客服助理快速查话费、查流量请按9,了解北京联通业务请按1,宽带IPTV新装、查询请按2,障碍报修请按3,充值缴费请按4,投诉建议请按5,政企业务请按7,人工服务请按0,for english severice press star key|智能客服助理快速查话费、查流量请按九,了解北京联通业务请按一,宽带IPTV新装、查询请按二,障碍报修请按三,充值缴费请按四,投诉建议请按五,政企业务请按七,人工服务请按零,for english severice press star key
|
||||||
|
您的帐户当前可用余额为63.89元,本月消费为2.17元。您的消费、套餐余量和其它信息将以短信形式下发,请您注意查收。谢谢使用,再见!。|您的帐户当前可用余额为六十三点八九元,本月消费为二点一七元。您的消费、套餐余量和其它信息将以短信形式下发,请您注意查收。谢谢使用,再见!。
|
||||||
|
您的帐户当前可用余额为负15.5元,本月消费为59.6元。您的消费、套餐余量和其它信息将以短信形式下发,请您注意查收。谢谢使用,再见!。|您的帐户当前可用余额为负十五点五元,本月消费为五十九点六元。您的消费、套餐余量和其它信息将以短信形式下发,请您注意查收。谢谢使用,再见!。
|
||||||
|
尊敬的客户,您目前的话费余额为负14.60元,已低于10元,为保证您的通信畅通,请及时缴纳费用。|尊敬的客户,您目前的话费余额为负十四点六元,已低于十元,为保证您的通信畅通,请及时缴纳费用。
|
||||||
|
您的流量已用完,为避免您产生额外费用,建议您根据需求开通一个流量包以作补充。|您的流量已用完,为避免您产生额外费用,建议您根据需求开通一个流量包以作补充。
|
||||||
|
您可以直接说,查询话费及余量、开通流量包、缴费,您也可以说出其它需求,请问有什么可以帮您?|您可以直接说,查询话费及余量、开通流量包、缴费,您也可以说出其它需求,请问有什么可以帮您?
|
||||||
|
您的账户当前可用余额为负36.00元,本月消费36.00元。|您的账户当前可用余额为负三十六元,本月消费三十六元。
|
||||||
|
请问你是电话13985608526的机主吗?|请问你是电话一三九八五六零八五二六的机主吗?
|
||||||
|
如您对处理结果不满意,可拨打中国联通集团投诉电话10015进行投诉,按本地通话费收费,返回自助服务请按井号键|如您对处理结果不满意,可拨打中国联通集团投诉电话一零零一五进行投诉,按本地通话费收费,返回自助服务请按井号键
|
||||||
|
“26314”号VIP客服代表为您服务。|“二六三一四”号VIP客服代表为您服务。
|
||||||
|
尊敬的5G用户,欢迎您致电中国联通|尊敬的五G用户,欢迎您致电中国联通
|
||||||
|
首先是应用了M1芯片的iPad Pro,新款的iPad Pro支持5G,这也是苹果的第二款5G产品线。|首先是应用了M一芯片的iPad Pro,新款的iPad Pro支持五G,这也是苹果的第二款五G产品线。
|
||||||
|
除此之外,摄像头方面再次升级,增加了前摄全新超广角摄像头,支持人物居中功能,搭配超广角可实现视频中始终让人物居中效果。|除此之外,摄像头方面再次升级,增加了前摄全新超广角摄像头,支持人物居中功能,搭配超广角可实现视频中始终让人物居中效果。
|
||||||
|
屏幕方面,iPad Pro 12.9版本支持XDR体验的Mini-LEDS显示屏,支持HDR10、杜比视界,还支持杜比全景声。|屏幕方面,iPad Pro 十二点九版本支持XDR体验的Mini-LEDS显示屏,支持HDR十、杜比视界,还支持杜比全景声。
|
||||||
|
iPad Pro的秒控键盘这次也推出白色版本。|iPad Pro的秒控键盘这次也推出白色版本。
|
||||||
|
售价方面,11英寸版本售价799美元起,12.9英寸售价1099美元起。|售价方面,十一英寸版本售价七百九十九美元起,十二点九英寸售价一千零九十九美元起。
|
||||||
|
这块黄金重达324.75克|这块黄金重达三百二十四点七五克
|
||||||
|
她出生于86年8月18日,她弟弟出生于1995年3月1日|她出生于八六年八月十八日,她弟弟出生于一九九五年三月一日
|
||||||
|
电影中梁朝伟扮演的陈永仁的编号27149|电影中梁朝伟扮演的陈永仁的编号二七一四九
|
||||||
|
现场有7/12的观众投出了赞成票|现场有十二分之七的观众投出了赞成票
|
||||||
|
随便来几个价格12块5,34.5元,20.1万|随便来几个价格十二块五,三十四点五元,二十点一万
|
||||||
|
明天有62%的概率降雨|明天有百分之六十二的概率降雨
|
||||||
|
这是固话0421-33441122|这是固话零四二一三三四四一一二二
|
||||||
|
这是手机+86 18544139121|这是手机八六一八五四四一三九一二一
|
||||||
|
小王的身高是153.5cm,梦想是打篮球!我觉得有0.1%的可能性。|小王的身高是一百五十三点五cm,梦想是打篮球!我觉得有百分之零点一的可能性。
|
||||||
|
不管三七二十一|不管三七二十一
|
||||||
|
九九八十一难|九九八十一难
|
||||||
|
2018年5月23号上午10点10分|二零一八年五月二十三号上午十点十分
|
||||||
|
10076|一零零七六
|
||||||
|
32.68%|百分之三十二点六八
|
||||||
|
比分测试17:16|比分测试十七比十六
|
||||||
|
比分测试37:16|比分测试三十七比十六
|
||||||
|
1.1|一点一
|
||||||
|
一点一滴|一点一滴
|
||||||
|
八九十|八九十
|
||||||
|
1个人一定要|一个人一定要
|
||||||
|
10000棵树|一万棵树
|
||||||
|
1234个人|一千二百三十四个人
|
||||||
|
35553座楼|三万五千五百五十三座楼
|
||||||
|
15873690|一五八七三六九零
|
||||||
|
27930122|二七九三零一二二
|
||||||
|
85307499|八五三零七四九九
|
||||||
|
26149787|二六一四九七八七
|
||||||
|
15964862|一五九六四八六二
|
||||||
|
45698723|四五六九八七二三
|
||||||
|
48615964|四八六一五九六四
|
||||||
|
17864589|一七八六四五八九
|
||||||
|
123加456|一百二十三加四百五十六
|
||||||
|
9786加3384|九千七百八十六加三千三百八十四
|
||||||
|
发电站每天发电30029度电|发电站每天发电三万零二十九度电
|
||||||
|
银行月交易总额七千九百零三亿元|银行月交易总额七千九百零三亿元
|
||||||
|
深圳每月平均工资在13000元|深圳每月平均工资在一万三千元
|
||||||
|
每月房租要交1500元|每月房租要交一千五百元
|
||||||
|
我每月交通费用在400元左右|我每月交通费用在四百元左右
|
||||||
|
本月开销费用是51328元|本月开销费用是五万一千三百二十八元
|
||||||
|
如果你中了五千万元奖金会分我一半吗|如果你中了五千万元奖金会分我一半吗
|
||||||
|
这个月工资我发了3529元|这个月工资我发了三千五百二十九元
|
||||||
|
学会了这个技能你至少可以涨薪5000元|学会了这个技能你至少可以涨薪五千元
|
||||||
|
我们的会议时间定在9点25分开始|我们的会议时间定在九点二十五分开始
|
||||||
|
上课时间是8点15分请不要迟到|上课时间是八点十五分请不要迟到
|
||||||
|
昨天你9点21分才到教室|昨天你九点二十一分才到教室
|
||||||
|
今天是2019年1月31号|今天是二零一九年一月三十一号
|
||||||
|
今年的除夕夜是2019年2月4号|今年的除夕夜是二零一九年二月四号
|
||||||
|
这根水管的长度不超过35米|这根水管的长度不超过三十五米
|
||||||
|
400米是最短的长跑距离|四百米是最短的长跑距离
|
||||||
|
最高的撑杆跳为11米|最高的撑杆跳为十一米
|
||||||
|
等会请在12:05请通知我|等会请在十二点零五分请通知我
|
||||||
|
23点15分开始|二十三点十五分开始
|
||||||
|
你生日那天我会送你999朵玫瑰|你生日那天我会送你九百九十九朵玫瑰
|
||||||
|
给我1双鞋我可以跳96米远|给我一双鞋我可以跳九十六米远
|
||||||
|
虽然我们的身高相差356毫米也不影响我们交往|虽然我们的身高相差三百五十六毫米也不影响我们交往
|
||||||
|
我们班的最高总分为583分|我们班的最高总分为五百八十三分
|
||||||
|
今天考试老师多扣了我21分|今天考试老师多扣了我二十一分
|
||||||
|
我量过这张桌子总长为1.37米|我量过这张桌子总长为一点三七米
|
||||||
|
乘务员身高必须超过185公分|乘务员身高必须超过一百八十五公分
|
||||||
|
这台电脑分辨率为1024|这台电脑分辨率为一零二四
|
||||||
|
手机价格不超过1500元|手机价格不超过一千五百元
|
||||||
|
101.23|一百零一点二三
|
||||||
|
123.116|一百二十三点一一六
|
||||||
|
456.147|四百五十六点一四七
|
||||||
|
0.1594|零点一五九四
|
||||||
|
3.1415|三点一四一五
|
||||||
|
0.112233|零点一一二二三三
|
||||||
|
0.1|零点一
|
||||||
|
40001.987|四万零一点九八七
|
||||||
|
56.878|五十六点八七八
|
||||||
|
0.00123|零点零零一二三
|
||||||
|
0.0001|零点零零零一
|
||||||
|
0.92015|零点九二零一五
|
||||||
|
999.0001|九百九十九点零零零一
|
||||||
|
10000.123|一万点一二三
|
||||||
|
666.555|六百六十六点五五五
|
||||||
|
444.789|四百四十四点七八九
|
||||||
|
789.666|七百八十九点六六六
|
||||||
|
0.12345|零点一二三四五
|
||||||
|
1.05649|一点零五六四九
|
||||||
|
环比上调1.86%|环比上调百分之一点八六
|
||||||
|
环比分别下跌3.46%及微涨0.70%|环比分别下跌百分之三点四六及微涨百分之零点七
|
||||||
|
单价在30000元的二手房购房个案当中|单价在三万元的二手房购房个案当中
|
||||||
|
6月仍有7%单价在30000元的房源|六月仍有百分之七单价在三万元的房源
|
||||||
|
最终也只是以总积分1分之差屈居第2|最终也只是以总积分一分之差屈居第二
|
||||||
|
中新网8月29日电今日|中新网八月二十九日电今日
|
||||||
|
自6月底呼和浩特市率先宣布取消限购后|自六月底呼和浩特市率先宣布取消限购后
|
||||||
|
仅1个多月的时间里|仅一个多月的时间里
|
||||||
|
除了北京上海广州深圳4个一线城市和三亚之外|除了北京上海广州深圳四个一线城市和三亚之外
|
||||||
|
46个限购城市当中|四十六个限购城市当中
|
||||||
|
41个已正式取消或变相放松了限购|四十一个已正式取消或变相放松了限购
|
||||||
|
其中包括对拥有一套住房并已结清相应购房贷款的家庭|其中包括对拥有一套住房并已结清相应购房贷款的家庭
|
||||||
|
这个后来被称为930新政策的措施|这个后来被称为九三零新政策的措施
|
||||||
|
今年有望超三百亿美元|今年有望超三百亿美元
|
||||||
|
就连一向看多的任志强|就连一向看多的任志强
|
||||||
|
近期也一反常态地发表看空言论|近期也一反常态地发表看空言论
|
||||||
|
985|九八五
|
|
@ -0,0 +1,8 @@
|
||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
# test g2p
|
||||||
|
echo "Start test g2p."
|
||||||
|
python3 test_g2p.py --root-dir=~/datasets/BZNSYP
|
||||||
|
# test text normalization
|
||||||
|
echo "Start test text normalization."
|
||||||
|
python3 test_textnorm.py --test-file=data/textnorm_test_cases.txt
|
|
@ -0,0 +1,110 @@
|
||||||
|
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import re
|
||||||
|
from collections import defaultdict
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from parakeet.frontend.cn_frontend import Frontend as cnFrontend
|
||||||
|
from parakeet.utils.error_rate import wer
|
||||||
|
from praatio import tgio
|
||||||
|
|
||||||
|
|
||||||
|
def text_cleaner(raw_text):
|
||||||
|
text = re.sub('#[1-4]|“|”|(|)', '', raw_text)
|
||||||
|
text = text.replace("…。", "。")
|
||||||
|
text = re.sub(':|;|——|……|、|…|—', ',', text)
|
||||||
|
return text
|
||||||
|
|
||||||
|
|
||||||
|
def get_baker_data(root_dir):
|
||||||
|
alignment_files = sorted(
|
||||||
|
list((root_dir / "PhoneLabeling").rglob("*.interval")))
|
||||||
|
text_file = root_dir / "ProsodyLabeling/000001-010000.txt"
|
||||||
|
text_file = Path(text_file).expanduser()
|
||||||
|
data_dict = defaultdict(dict)
|
||||||
|
# filter out several files that have errors in annotation
|
||||||
|
exclude = {'000611', '000662', '002365', '005107'}
|
||||||
|
alignment_files = [f for f in alignment_files if f.stem not in exclude]
|
||||||
|
# biaobei 前后有 sil ,中间没有 sp
|
||||||
|
data_dict = defaultdict(dict)
|
||||||
|
for alignment_fp in alignment_files:
|
||||||
|
alignment = tgio.openTextgrid(alignment_fp)
|
||||||
|
# only with baker's annotation
|
||||||
|
utt_id = alignment.tierNameList[0].split(".")[0]
|
||||||
|
intervals = alignment.tierDict[alignment.tierNameList[0]].entryList
|
||||||
|
phones = []
|
||||||
|
for interval in intervals:
|
||||||
|
label = interval.label
|
||||||
|
# Baker has sp1 rather than sp
|
||||||
|
label = label.replace("sp1", "sp")
|
||||||
|
phones.append(label)
|
||||||
|
data_dict[utt_id]["phones"] = phones
|
||||||
|
for line in open(text_file, "r"):
|
||||||
|
if line.startswith("0"):
|
||||||
|
utt_id, raw_text = line.strip().split()
|
||||||
|
text = text_cleaner(raw_text)
|
||||||
|
if utt_id in data_dict:
|
||||||
|
data_dict[utt_id]['text'] = text
|
||||||
|
else:
|
||||||
|
pinyin = line.strip().split()
|
||||||
|
if utt_id in data_dict:
|
||||||
|
data_dict[utt_id]['pinyin'] = pinyin
|
||||||
|
return data_dict
|
||||||
|
|
||||||
|
|
||||||
|
def get_g2p_phones(data_dict, frontend):
|
||||||
|
for utt_id in data_dict:
|
||||||
|
g2p_phones = frontend.get_phonemes(data_dict[utt_id]['text'])
|
||||||
|
data_dict[utt_id]["g2p_phones"] = g2p_phones
|
||||||
|
return data_dict
|
||||||
|
|
||||||
|
|
||||||
|
def get_avg_wer(data_dict):
|
||||||
|
wer_list = []
|
||||||
|
for utt_id in data_dict:
|
||||||
|
g2p_phones = data_dict[utt_id]['g2p_phones']
|
||||||
|
# delete silence tokens in predicted phones
|
||||||
|
g2p_phones = [phn for phn in g2p_phones if phn not in {"sp", "sil"}]
|
||||||
|
gt_phones = data_dict[utt_id]['phones']
|
||||||
|
# delete silence tokens in baker phones
|
||||||
|
gt_phones = [phn for phn in gt_phones if phn not in {"sp", "sil"}]
|
||||||
|
gt_phones = " ".join(gt_phones)
|
||||||
|
g2p_phones = " ".join(g2p_phones)
|
||||||
|
single_wer = wer(gt_phones, g2p_phones)
|
||||||
|
wer_list.append(single_wer)
|
||||||
|
return sum(wer_list) / len(wer_list)
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
parser = argparse.ArgumentParser(description="g2p example.")
|
||||||
|
parser.add_argument(
|
||||||
|
"--root-dir",
|
||||||
|
default=None,
|
||||||
|
type=str,
|
||||||
|
help="directory to baker dataset.")
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
root_dir = Path(args.root_dir).expanduser()
|
||||||
|
assert root_dir.is_dir()
|
||||||
|
frontend = cnFrontend()
|
||||||
|
data_dict = get_baker_data(root_dir)
|
||||||
|
data_dict = get_g2p_phones(data_dict, frontend)
|
||||||
|
avg_wer = get_avg_wer(data_dict)
|
||||||
|
print("The avg WER of g2p is:", avg_wer)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
|
@ -0,0 +1,61 @@
|
||||||
|
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import re
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from parakeet.frontend.cn_normalization.text_normlization import TextNormalizer
|
||||||
|
from parakeet.utils.error_rate import cer
|
||||||
|
|
||||||
|
|
||||||
|
# delete english characters
|
||||||
|
# e.g. "你好aBC" -> "你 好"
|
||||||
|
def del_en_add_space(input: str):
|
||||||
|
output = re.sub('[a-zA-Z]', '', input)
|
||||||
|
output = [char + " " for char in output]
|
||||||
|
output = "".join(output).strip()
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
def get_avg_cer(test_file, text_normalizer):
|
||||||
|
cer_list = []
|
||||||
|
for line in open(test_file, "r"):
|
||||||
|
line = line.strip()
|
||||||
|
raw_text, gt_text = line.split("|")
|
||||||
|
textnorm_text = text_normalizer.normalize_sentence(raw_text)
|
||||||
|
gt_text = del_en_add_space(gt_text)
|
||||||
|
textnorm_text = del_en_add_space(textnorm_text)
|
||||||
|
single_cer = cer(gt_text, textnorm_text)
|
||||||
|
cer_list.append(single_cer)
|
||||||
|
return sum(cer_list) / len(cer_list)
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
parser = argparse.ArgumentParser(description="text normalization example.")
|
||||||
|
parser.add_argument(
|
||||||
|
"--test-file",
|
||||||
|
default=None,
|
||||||
|
type=str,
|
||||||
|
help="path of text normalization test file.")
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
test_file = Path(args.test_file).expanduser()
|
||||||
|
text_normalizer = TextNormalizer()
|
||||||
|
avg_cer = get_avg_cer(test_file, text_normalizer)
|
||||||
|
print("The avg CER of text normalization is:", avg_cer)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
|
@ -35,6 +35,14 @@ class Frontend():
|
||||||
self.g2pM_model = G2pM()
|
self.g2pM_model = G2pM()
|
||||||
self.pinyin2phone = generate_lexicon(
|
self.pinyin2phone = generate_lexicon(
|
||||||
with_tone=True, with_erhua=False)
|
with_tone=True, with_erhua=False)
|
||||||
|
self.must_erhua = {"小院儿", "胡同儿", "范儿", "老汉儿", "撒欢儿", "寻老礼儿", "妥妥儿"}
|
||||||
|
self.not_erhua = {
|
||||||
|
"虐儿", "为儿", "护儿", "瞒儿", "救儿", "替儿", "有儿", "一儿", "我儿", "俺儿", "妻儿",
|
||||||
|
"拐儿", "聋儿", "乞儿", "患儿", "幼儿", "孤儿", "婴儿", "婴幼儿", "连体儿", "脑瘫儿",
|
||||||
|
"流浪儿", "体弱儿", "混血儿", "蜜雪儿", "舫儿", "祖儿", "美儿", "应采儿", "可儿", "侄儿",
|
||||||
|
"孙儿", "侄孙儿", "女儿", "男儿", "红孩儿", "花儿", "虫儿", "马儿", "鸟儿", "猪儿", "猫儿",
|
||||||
|
"狗儿"
|
||||||
|
}
|
||||||
|
|
||||||
def _get_initials_finals(self, word):
|
def _get_initials_finals(self, word):
|
||||||
initials = []
|
initials = []
|
||||||
|
@ -71,26 +79,31 @@ class Frontend():
|
||||||
return initials, finals
|
return initials, finals
|
||||||
|
|
||||||
# if merge_sentences, merge all sentences into one phone sequence
|
# if merge_sentences, merge all sentences into one phone sequence
|
||||||
def _g2p(self, sentences, merge_sentences=True):
|
def _g2p(self, sentences, merge_sentences=True, with_erhua=True):
|
||||||
segments = sentences
|
segments = sentences
|
||||||
phones_list = []
|
phones_list = []
|
||||||
for seg in segments:
|
for seg in segments:
|
||||||
phones = []
|
phones = []
|
||||||
seg = psg.lcut(seg)
|
seg_cut = psg.lcut(seg)
|
||||||
initials = []
|
initials = []
|
||||||
finals = []
|
finals = []
|
||||||
seg = self.tone_modifier.pre_merge_for_modify(seg)
|
seg_cut = self.tone_modifier.pre_merge_for_modify(seg_cut)
|
||||||
for word, pos in seg:
|
for word, pos in seg_cut:
|
||||||
if pos == 'eng':
|
if pos == 'eng':
|
||||||
continue
|
continue
|
||||||
sub_initials, sub_finals = self._get_initials_finals(word)
|
sub_initials, sub_finals = self._get_initials_finals(word)
|
||||||
|
|
||||||
sub_finals = self.tone_modifier.modified_tone(word, pos,
|
sub_finals = self.tone_modifier.modified_tone(word, pos,
|
||||||
sub_finals)
|
sub_finals)
|
||||||
|
if with_erhua:
|
||||||
|
sub_initials, sub_finals = self._merge_erhua(
|
||||||
|
sub_initials, sub_finals, word, pos)
|
||||||
initials.append(sub_initials)
|
initials.append(sub_initials)
|
||||||
finals.append(sub_finals)
|
finals.append(sub_finals)
|
||||||
# assert len(sub_initials) == len(sub_finals) == len(word)
|
# assert len(sub_initials) == len(sub_finals) == len(word)
|
||||||
initials = sum(initials, [])
|
initials = sum(initials, [])
|
||||||
finals = sum(finals, [])
|
finals = sum(finals, [])
|
||||||
|
|
||||||
for c, v in zip(initials, finals):
|
for c, v in zip(initials, finals):
|
||||||
# NOTE: post process for pypinyin outputs
|
# NOTE: post process for pypinyin outputs
|
||||||
# we discriminate i, ii and iii
|
# we discriminate i, ii and iii
|
||||||
|
@ -106,7 +119,24 @@ class Frontend():
|
||||||
phones_list = sum(phones_list, [])
|
phones_list = sum(phones_list, [])
|
||||||
return phones_list
|
return phones_list
|
||||||
|
|
||||||
def get_phonemes(self, sentence):
|
def _merge_erhua(self, initials, finals, word, pos):
|
||||||
|
if word not in self.must_erhua and (word in self.not_erhua or
|
||||||
|
pos in {"a", "j", "nr"}):
|
||||||
|
return initials, finals
|
||||||
|
new_initials = []
|
||||||
|
new_finals = []
|
||||||
|
assert len(finals) == len(word)
|
||||||
|
for i, phn in enumerate(finals):
|
||||||
|
if i == len(finals) - 1 and word[i] == "儿" and phn in {
|
||||||
|
"er2", "er5"
|
||||||
|
} and word[-2:] not in self.not_erhua and new_finals:
|
||||||
|
new_finals[-1] = new_finals[-1][:-1] + "r" + new_finals[-1][-1]
|
||||||
|
else:
|
||||||
|
new_finals.append(phn)
|
||||||
|
new_initials.append(initials[i])
|
||||||
|
return new_initials, new_finals
|
||||||
|
|
||||||
|
def get_phonemes(self, sentence, with_erhua=True):
|
||||||
sentences = self.text_normalizer.normalize(sentence)
|
sentences = self.text_normalizer.normalize(sentence)
|
||||||
phonemes = self._g2p(sentences)
|
phonemes = self._g2p(sentences, with_erhua=with_erhua)
|
||||||
return phonemes
|
return phonemes
|
||||||
|
|
|
@ -29,6 +29,8 @@ UNITS = OrderedDict({
|
||||||
8: '亿',
|
8: '亿',
|
||||||
})
|
})
|
||||||
|
|
||||||
|
COM_QUANTIFIERS = '(匹|张|座|回|场|尾|条|个|首|阙|阵|网|炮|顶|丘|棵|只|支|袭|辆|挑|担|颗|壳|窠|曲|墙|群|腔|砣|座|客|贯|扎|捆|刀|令|打|手|罗|坡|山|岭|江|溪|钟|队|单|双|对|出|口|头|脚|板|跳|枝|件|贴|针|线|管|名|位|身|堂|课|本|页|家|户|层|丝|毫|厘|分|钱|两|斤|担|铢|石|钧|锱|忽|(千|毫|微)克|毫|厘|分|寸|尺|丈|里|寻|常|铺|程|(千|分|厘|毫|微)米|撮|勺|合|升|斗|石|盘|碗|碟|叠|桶|笼|盆|盒|杯|钟|斛|锅|簋|篮|盘|桶|罐|瓶|壶|卮|盏|箩|箱|煲|啖|袋|钵|年|月|日|季|刻|时|周|天|秒|分|旬|纪|岁|世|更|夜|春|夏|秋|冬|代|伏|辈|丸|泡|粒|颗|幢|堆|条|根|支|道|面|片|张|颗|块|元|(亿|千万|百万|万|千|百)|(亿|千万|百万|万|千|百|美|)元|(亿|千万|百万|万|千|百|)块|角|毛|分)'
|
||||||
|
|
||||||
# 分数表达式
|
# 分数表达式
|
||||||
RE_FRAC = re.compile(r'(-?)(\d+)/(\d+)')
|
RE_FRAC = re.compile(r'(-?)(\d+)/(\d+)')
|
||||||
|
|
||||||
|
@ -59,7 +61,17 @@ def replace_percentage(match: re.Match) -> str:
|
||||||
|
|
||||||
# 整数表达式
|
# 整数表达式
|
||||||
# 带负号或者不带负号的整数 12, -10
|
# 带负号或者不带负号的整数 12, -10
|
||||||
RE_INTEGER = re.compile(r'(-?)' r'(\d+)')
|
RE_INTEGER = re.compile(r'(-)' r'(\d+)')
|
||||||
|
|
||||||
|
|
||||||
|
def replace_negative_num(match: re.Match) -> str:
|
||||||
|
sign = match.group(1)
|
||||||
|
number = match.group(2)
|
||||||
|
sign: str = "负" if sign else ""
|
||||||
|
number: str = num2str(number)
|
||||||
|
result = f"{sign}{number}"
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
# 编号-无符号整形
|
# 编号-无符号整形
|
||||||
# 00078
|
# 00078
|
||||||
|
@ -72,12 +84,23 @@ def replace_default_num(match: re.Match):
|
||||||
|
|
||||||
|
|
||||||
# 数字表达式
|
# 数字表达式
|
||||||
# 1. 整数: -10, 10;
|
# 纯小数
|
||||||
# 2. 浮点数: 10.2, -0.3
|
RE_DECIMAL_NUM = re.compile(r'(-?)((\d+)(\.\d+))' r'|(\.(\d+))')
|
||||||
# 3. 不带符号和整数部分的纯浮点数: .22, .38
|
# 正整数 + 量词
|
||||||
|
RE_POSITIVE_QUANTIFIERS = re.compile(r"(\d+)([多余几])?" + COM_QUANTIFIERS)
|
||||||
RE_NUMBER = re.compile(r'(-?)((\d+)(\.\d+)?)' r'|(\.(\d+))')
|
RE_NUMBER = re.compile(r'(-?)((\d+)(\.\d+)?)' r'|(\.(\d+))')
|
||||||
|
|
||||||
|
|
||||||
|
def replace_positive_quantifier(match: re.Match) -> str:
|
||||||
|
number = match.group(1)
|
||||||
|
match_2 = match.group(2)
|
||||||
|
match_2: str = match_2 if match_2 else ""
|
||||||
|
quantifiers: str = match.group(3)
|
||||||
|
number: str = num2str(number)
|
||||||
|
result = f"{number}{match_2}{quantifiers}"
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
def replace_number(match: re.Match) -> str:
|
def replace_number(match: re.Match) -> str:
|
||||||
sign = match.group(1)
|
sign = match.group(1)
|
||||||
number = match.group(2)
|
number = match.group(2)
|
||||||
|
@ -93,7 +116,7 @@ def replace_number(match: re.Match) -> str:
|
||||||
|
|
||||||
# 范围表达式
|
# 范围表达式
|
||||||
# 12-23, 12~23
|
# 12-23, 12~23
|
||||||
RE_RANGE = re.compile(r'(\d+)[-~](\d+)')
|
RE_RANGE = re.compile(r'(\d+)[~](\d+)')
|
||||||
|
|
||||||
|
|
||||||
def replace_range(match: re.Match) -> str:
|
def replace_range(match: re.Match) -> str:
|
||||||
|
|
|
@ -25,7 +25,7 @@ from .num import verbalize_digit
|
||||||
RE_MOBILE_PHONE = re.compile(
|
RE_MOBILE_PHONE = re.compile(
|
||||||
r"(?<!\d)((\+?86 ?)?1([38]\d|5[0-35-9]|7[678]|9[89])\d{8})(?!\d)")
|
r"(?<!\d)((\+?86 ?)?1([38]\d|5[0-35-9]|7[678]|9[89])\d{8})(?!\d)")
|
||||||
RE_TELEPHONE = re.compile(
|
RE_TELEPHONE = re.compile(
|
||||||
r"(?<!\d)((0(10|2[1-3]|[3-9]\d{2})-?)?[1-9]\d{6,7})(?!\d)")
|
r"(?<!\d)((0(10|2[1-3]|[3-9]\d{2})-?)?[1-9]\d{7,8})(?!\d)")
|
||||||
|
|
||||||
|
|
||||||
def phone2str(phone_string: str, mobile=True) -> str:
|
def phone2str(phone_string: str, mobile=True) -> str:
|
||||||
|
@ -44,4 +44,8 @@ def phone2str(phone_string: str, mobile=True) -> str:
|
||||||
|
|
||||||
|
|
||||||
def replace_phone(match: re.Match) -> str:
|
def replace_phone(match: re.Match) -> str:
|
||||||
|
return phone2str(match.group(0), mobile=False)
|
||||||
|
|
||||||
|
|
||||||
|
def replace_mobile(match: re.Match) -> str:
|
||||||
return phone2str(match.group(0))
|
return phone2str(match.group(0))
|
||||||
|
|
|
@ -12,16 +12,15 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import opencc
|
|
||||||
import re
|
import re
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
from .chronology import RE_TIME, RE_DATE, RE_DATE2
|
from .chronology import RE_TIME, RE_DATE, RE_DATE2
|
||||||
from .chronology import replace_time, replace_date, replace_date2
|
from .chronology import replace_time, replace_date, replace_date2
|
||||||
from .constants import F2H_ASCII_LETTERS, F2H_DIGITS, F2H_SPACE
|
from .constants import F2H_ASCII_LETTERS, F2H_DIGITS, F2H_SPACE
|
||||||
from .num import RE_NUMBER, RE_FRAC, RE_PERCENTAGE, RE_RANGE, RE_INTEGER, RE_DEFAULT_NUM
|
from .num import RE_NUMBER, RE_FRAC, RE_PERCENTAGE, RE_RANGE, RE_INTEGER, RE_DEFAULT_NUM, RE_DECIMAL_NUM, RE_POSITIVE_QUANTIFIERS
|
||||||
from .num import replace_number, replace_frac, replace_percentage, replace_range, replace_default_num
|
from .num import replace_number, replace_frac, replace_percentage, replace_range, replace_default_num, replace_negative_num, replace_positive_quantifier
|
||||||
from .phonecode import RE_MOBILE_PHONE, RE_TELEPHONE, replace_phone
|
from .phonecode import RE_MOBILE_PHONE, RE_TELEPHONE, replace_phone, replace_mobile
|
||||||
from .quantifier import RE_TEMPERATURE
|
from .quantifier import RE_TEMPERATURE
|
||||||
from .quantifier import replace_temperature
|
from .quantifier import replace_temperature
|
||||||
|
|
||||||
|
@ -29,8 +28,6 @@ from .quantifier import replace_temperature
|
||||||
class TextNormalizer():
|
class TextNormalizer():
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.SENTENCE_SPLITOR = re.compile(r'([:,;。?!,;?!][”’]?)')
|
self.SENTENCE_SPLITOR = re.compile(r'([:,;。?!,;?!][”’]?)')
|
||||||
self._t2s_converter = opencc.OpenCC("t2s.json")
|
|
||||||
self._s2t_converter = opencc.OpenCC('s2t.json')
|
|
||||||
|
|
||||||
def _split(self, text: str) -> List[str]:
|
def _split(self, text: str) -> List[str]:
|
||||||
"""Split long text into sentences with sentence-splitting punctuations.
|
"""Split long text into sentences with sentence-splitting punctuations.
|
||||||
|
@ -48,15 +45,8 @@ class TextNormalizer():
|
||||||
sentences = [sentence.strip() for sentence in re.split(r'\n+', text)]
|
sentences = [sentence.strip() for sentence in re.split(r'\n+', text)]
|
||||||
return sentences
|
return sentences
|
||||||
|
|
||||||
def _tranditional_to_simplified(self, text: str) -> str:
|
|
||||||
return self._t2s_converter.convert(text)
|
|
||||||
|
|
||||||
def _simplified_to_traditional(self, text: str) -> str:
|
|
||||||
return self._s2t_converter.convert(text)
|
|
||||||
|
|
||||||
def normalize_sentence(self, sentence):
|
def normalize_sentence(self, sentence):
|
||||||
# basic character conversions
|
# basic character conversions
|
||||||
sentence = self._tranditional_to_simplified(sentence)
|
|
||||||
sentence = sentence.translate(F2H_ASCII_LETTERS).translate(
|
sentence = sentence.translate(F2H_ASCII_LETTERS).translate(
|
||||||
F2H_DIGITS).translate(F2H_SPACE)
|
F2H_DIGITS).translate(F2H_SPACE)
|
||||||
|
|
||||||
|
@ -68,8 +58,12 @@ class TextNormalizer():
|
||||||
sentence = RE_RANGE.sub(replace_range, sentence)
|
sentence = RE_RANGE.sub(replace_range, sentence)
|
||||||
sentence = RE_FRAC.sub(replace_frac, sentence)
|
sentence = RE_FRAC.sub(replace_frac, sentence)
|
||||||
sentence = RE_PERCENTAGE.sub(replace_percentage, sentence)
|
sentence = RE_PERCENTAGE.sub(replace_percentage, sentence)
|
||||||
sentence = RE_MOBILE_PHONE.sub(replace_phone, sentence)
|
sentence = RE_MOBILE_PHONE.sub(replace_mobile, sentence)
|
||||||
sentence = RE_TELEPHONE.sub(replace_phone, sentence)
|
sentence = RE_TELEPHONE.sub(replace_phone, sentence)
|
||||||
|
sentence = RE_INTEGER.sub(replace_negative_num, sentence)
|
||||||
|
sentence = RE_DECIMAL_NUM.sub(replace_number, sentence)
|
||||||
|
sentence = RE_POSITIVE_QUANTIFIERS.sub(replace_positive_quantifier,
|
||||||
|
sentence)
|
||||||
sentence = RE_DEFAULT_NUM.sub(replace_default_num, sentence)
|
sentence = RE_DEFAULT_NUM.sub(replace_default_num, sentence)
|
||||||
sentence = RE_NUMBER.sub(replace_number, sentence)
|
sentence = RE_NUMBER.sub(replace_number, sentence)
|
||||||
|
|
||||||
|
|
|
@ -56,7 +56,14 @@ class ToneSandhi():
|
||||||
'凑合', '凉快', '冷战', '冤枉', '冒失', '养活', '关系', '先生', '兄弟', '便宜', '使唤',
|
'凑合', '凉快', '冷战', '冤枉', '冒失', '养活', '关系', '先生', '兄弟', '便宜', '使唤',
|
||||||
'佩服', '作坊', '体面', '位置', '似的', '伙计', '休息', '什么', '人家', '亲戚', '亲家',
|
'佩服', '作坊', '体面', '位置', '似的', '伙计', '休息', '什么', '人家', '亲戚', '亲家',
|
||||||
'交情', '云彩', '事情', '买卖', '主意', '丫头', '丧气', '两口', '东西', '东家', '世故',
|
'交情', '云彩', '事情', '买卖', '主意', '丫头', '丧气', '两口', '东西', '东家', '世故',
|
||||||
'不由', '不在', '下水', '下巴', '上头', '上司', '丈夫', '丈人', '一辈', '那个'
|
'不由', '不在', '下水', '下巴', '上头', '上司', '丈夫', '丈人', '一辈', '那个', '菩萨',
|
||||||
|
'父亲', '母亲', '咕噜', '邋遢', '费用', '冤家', '甜头', '介绍', '荒唐', '大人', '泥鳅',
|
||||||
|
'幸福', '熟悉', '计划', '扑腾', '蜡烛', '姥爷', '照顾', '喉咙', '吉他', '弄堂', '蚂蚱',
|
||||||
|
'凤凰', '拖沓', '寒碜', '糟蹋', '倒腾', '报复', '逻辑', '盘缠', '喽啰', '牢骚', '咖喱',
|
||||||
|
'扫把', '惦记'
|
||||||
|
}
|
||||||
|
self.must_not_neural_tone_words = {
|
||||||
|
"男子", "女子", "分子", "原子", "量子", "莲子", "石子", "瓜子", "电子"
|
||||||
}
|
}
|
||||||
|
|
||||||
# the meaning of jieba pos tag: https://blog.csdn.net/weixin_44174352/article/details/113731041
|
# the meaning of jieba pos tag: https://blog.csdn.net/weixin_44174352/article/details/113731041
|
||||||
|
@ -66,71 +73,90 @@ class ToneSandhi():
|
||||||
# finals: ['ia1', 'i3']
|
# finals: ['ia1', 'i3']
|
||||||
def _neural_sandhi(self, word: str, pos: str,
|
def _neural_sandhi(self, word: str, pos: str,
|
||||||
finals: List[str]) -> List[str]:
|
finals: List[str]) -> List[str]:
|
||||||
|
|
||||||
|
# reduplication words for n. and v. e.g. 奶奶, 试试, 旺旺
|
||||||
|
for j, item in enumerate(word):
|
||||||
|
if j - 1 >= 0 and item == word[j - 1] and pos[
|
||||||
|
0] in {"n", "v", "a"}:
|
||||||
|
finals[j] = finals[j][:-1] + "5"
|
||||||
ge_idx = word.find("个")
|
ge_idx = word.find("个")
|
||||||
if len(word) == 1 and word in "吧呢啊嘛" and pos == 'y':
|
if len(word) >= 1 and word[-1] in "吧呢哈啊呐噻嘛吖嗨呐哦哒额滴哩哟喽啰耶喔诶":
|
||||||
finals[-1] = finals[-1][:-1] + "5"
|
finals[-1] = finals[-1][:-1] + "5"
|
||||||
elif len(word) == 1 and word in "的地得" and pos in {"ud", "uj", "uv"}:
|
elif len(word) >= 1 and word[-1] in "的地得":
|
||||||
finals[-1] = finals[-1][:-1] + "5"
|
finals[-1] = finals[-1][:-1] + "5"
|
||||||
# e.g. 走了, 看着, 去过
|
# e.g. 走了, 看着, 去过
|
||||||
elif len(word) == 1 and word in "了着过" and pos in {"ul", "uz", "ug"}:
|
elif len(word) == 1 and word in "了着过" and pos in {"ul", "uz", "ug"}:
|
||||||
finals[-1] = finals[-1][:-1] + "5"
|
finals[-1] = finals[-1][:-1] + "5"
|
||||||
elif len(word) > 1 and word[-1] in "们子" and pos in {"r", "n"}:
|
elif len(word) > 1 and word[-1] in "们子" and pos in {
|
||||||
|
"r", "n"
|
||||||
|
} and word not in self.must_not_neural_tone_words:
|
||||||
finals[-1] = finals[-1][:-1] + "5"
|
finals[-1] = finals[-1][:-1] + "5"
|
||||||
# e.g. 桌上, 地下, 家里
|
# e.g. 桌上, 地下, 家里
|
||||||
elif len(word) > 1 and word[-1] in "上下里" and pos in {"s", "l", "f"}:
|
elif len(word) > 1 and word[-1] in "上下里" and pos in {"s", "l", "f"}:
|
||||||
finals[-1] = finals[-1][:-1] + "5"
|
finals[-1] = finals[-1][:-1] + "5"
|
||||||
# e.g. 上来, 下去
|
# e.g. 上来, 下去
|
||||||
elif len(word) > 1 and word[-1] in "来去" and pos[0] in {"v"}:
|
elif len(word) > 1 and word[-1] in "来去" and word[-2] in "上下进出回过起开":
|
||||||
finals[-1] = finals[-1][:-1] + "5"
|
finals[-1] = finals[-1][:-1] + "5"
|
||||||
# 个做量词
|
# 个做量词
|
||||||
elif ge_idx >= 1 and word[ge_idx - 1].isnumeric():
|
elif (ge_idx >= 1 and
|
||||||
|
(word[ge_idx - 1].isnumeric() or
|
||||||
|
word[ge_idx - 1] in "几有两半多各整每做是")) or word == '个':
|
||||||
finals[ge_idx] = finals[ge_idx][:-1] + "5"
|
finals[ge_idx] = finals[ge_idx][:-1] + "5"
|
||||||
# reduplication words for n. and v. e.g. 奶奶, 试试
|
else:
|
||||||
elif len(word) >= 2 and word[-1] == word[-2] and pos[0] in {"n", "v"}:
|
if word in self.must_neural_tone_words or word[
|
||||||
finals[-1] = finals[-1][:-1] + "5"
|
|
||||||
# conventional tone5 in Chinese
|
|
||||||
elif word in self.must_neural_tone_words or word[
|
|
||||||
-2:] in self.must_neural_tone_words:
|
-2:] in self.must_neural_tone_words:
|
||||||
finals[-1] = finals[-1][:-1] + "5"
|
finals[-1] = finals[-1][:-1] + "5"
|
||||||
|
|
||||||
|
word_list = self._split_word(word)
|
||||||
|
finals_list = [finals[:len(word_list[0])], finals[len(word_list[0]):]]
|
||||||
|
for i, word in enumerate(word_list):
|
||||||
|
# conventional neural in Chinese
|
||||||
|
if word in self.must_neural_tone_words or word[
|
||||||
|
-2:] in self.must_neural_tone_words:
|
||||||
|
finals_list[i][-1] = finals_list[i][-1][:-1] + "5"
|
||||||
|
finals = sum(finals_list, [])
|
||||||
|
|
||||||
return finals
|
return finals
|
||||||
|
|
||||||
def _bu_sandhi(self, word: str, finals: List[str]) -> List[str]:
|
def _bu_sandhi(self, word: str, finals: List[str]) -> List[str]:
|
||||||
# "不" before tone4 should be bu2, e.g. 不怕
|
|
||||||
if len(word) > 1 and word[0] == "不" and finals[1][-1] == "4":
|
|
||||||
finals[0] = finals[0][:-1] + "2"
|
|
||||||
# e.g. 看不懂
|
# e.g. 看不懂
|
||||||
elif len(word) == 3 and word[1] == "不":
|
if len(word) == 3 and word[1] == "不":
|
||||||
finals[1] = finals[1][:-1] + "5"
|
finals[1] = finals[1][:-1] + "5"
|
||||||
|
else:
|
||||||
|
for i, char in enumerate(word):
|
||||||
|
# "不" before tone4 should be bu2, e.g. 不怕
|
||||||
|
if char == "不" and i + 1 < len(word) and finals[i + 1][
|
||||||
|
-1] == "4":
|
||||||
|
finals[i] = finals[i][:-1] + "2"
|
||||||
return finals
|
return finals
|
||||||
|
|
||||||
def _yi_sandhi(self, word: str, finals: List[str]) -> List[str]:
|
def _yi_sandhi(self, word: str, finals: List[str]) -> List[str]:
|
||||||
# "一" in number sequences, e.g. 一零零
|
# "一" in number sequences, e.g. 一零零, 二一零
|
||||||
if len(word) > 1 and word[0] == "一" and all(
|
if word.find("一") != -1 and all(
|
||||||
[item.isnumeric() for item in word]):
|
[item.isnumeric() for item in word if item != "一"]):
|
||||||
return finals
|
return finals
|
||||||
# "一" before tone4 should be yi2, e.g. 一段
|
|
||||||
elif len(word) > 1 and word[0] == "一" and finals[1][-1] == "4":
|
|
||||||
finals[0] = finals[0][:-1] + "2"
|
|
||||||
# "一" before non-tone4 should be yi4, e.g. 一天
|
|
||||||
elif len(word) > 1 and word[0] == "一" and finals[1][-1] != "4":
|
|
||||||
finals[0] = finals[0][:-1] + "4"
|
|
||||||
# "一" between reduplication words shold be yi5, e.g. 看一看
|
# "一" between reduplication words shold be yi5, e.g. 看一看
|
||||||
elif len(word) == 3 and word[1] == "一" and word[0] == word[-1]:
|
elif len(word) == 3 and word[1] == "一" and word[0] == word[-1]:
|
||||||
finals[1] = finals[1][:-1] + "5"
|
finals[1] = finals[1][:-1] + "5"
|
||||||
# when "一" is ordinal word, it should be yi1
|
# when "一" is ordinal word, it should be yi1
|
||||||
elif word.startswith("第一"):
|
elif word.startswith("第一"):
|
||||||
finals[1] = finals[1][:-1] + "1"
|
finals[1] = finals[1][:-1] + "1"
|
||||||
|
else:
|
||||||
|
for i, char in enumerate(word):
|
||||||
|
if char == "一" and i + 1 < len(word):
|
||||||
|
# "一" before tone4 should be yi2, e.g. 一段
|
||||||
|
if finals[i + 1][-1] == "4":
|
||||||
|
finals[i] = finals[i][:-1] + "2"
|
||||||
|
# "一" before non-tone4 should be yi4, e.g. 一天
|
||||||
|
else:
|
||||||
|
finals[i] = finals[i][:-1] + "4"
|
||||||
return finals
|
return finals
|
||||||
|
|
||||||
def _three_sandhi(self, word: str, finals: List[str]) -> List[str]:
|
def _split_word(self, word):
|
||||||
if len(word) == 2 and self._all_tone_three(finals):
|
|
||||||
finals[0] = finals[0][:-1] + "2"
|
|
||||||
elif len(word) == 3:
|
|
||||||
word_list = jieba.cut_for_search(word)
|
word_list = jieba.cut_for_search(word)
|
||||||
word_list = sorted(word_list, key=lambda i: len(i), reverse=False)
|
word_list = sorted(word_list, key=lambda i: len(i), reverse=False)
|
||||||
new_word_list = []
|
new_word_list = []
|
||||||
|
|
||||||
first_subword = word_list[0]
|
first_subword = word_list[0]
|
||||||
first_begin_idx = word.find(first_subword)
|
first_begin_idx = word.find(first_subword)
|
||||||
if first_begin_idx == 0:
|
if first_begin_idx == 0:
|
||||||
|
@ -138,20 +164,25 @@ class ToneSandhi():
|
||||||
new_word_list = [first_subword, second_subword]
|
new_word_list = [first_subword, second_subword]
|
||||||
else:
|
else:
|
||||||
second_subword = word[:-len(first_subword)]
|
second_subword = word[:-len(first_subword)]
|
||||||
|
|
||||||
new_word_list = [second_subword, first_subword]
|
new_word_list = [second_subword, first_subword]
|
||||||
|
return new_word_list
|
||||||
|
|
||||||
|
def _three_sandhi(self, word: str, finals: List[str]) -> List[str]:
|
||||||
|
if len(word) == 2 and self._all_tone_three(finals):
|
||||||
|
finals[0] = finals[0][:-1] + "2"
|
||||||
|
elif len(word) == 3:
|
||||||
|
word_list = self._split_word(word)
|
||||||
if self._all_tone_three(finals):
|
if self._all_tone_three(finals):
|
||||||
# disyllabic + monosyllabic, e.g. 蒙古/包
|
# disyllabic + monosyllabic, e.g. 蒙古/包
|
||||||
if len(new_word_list[0]) == 2:
|
if len(word_list[0]) == 2:
|
||||||
finals[0] = finals[0][:-1] + "2"
|
finals[0] = finals[0][:-1] + "2"
|
||||||
finals[1] = finals[1][:-1] + "2"
|
finals[1] = finals[1][:-1] + "2"
|
||||||
# monosyllabic + disyllabic, e.g. 纸/老虎
|
# monosyllabic + disyllabic, e.g. 纸/老虎
|
||||||
elif len(new_word_list[0]) == 1:
|
elif len(word_list[0]) == 1:
|
||||||
finals[1] = finals[1][:-1] + "2"
|
finals[1] = finals[1][:-1] + "2"
|
||||||
else:
|
else:
|
||||||
finals_list = [
|
finals_list = [
|
||||||
finals[:len(new_word_list[0])],
|
finals[:len(word_list[0])], finals[len(word_list[0]):]
|
||||||
finals[len(new_word_list[0]):]
|
|
||||||
]
|
]
|
||||||
if len(finals_list) == 2:
|
if len(finals_list) == 2:
|
||||||
for i, sub in enumerate(finals_list):
|
for i, sub in enumerate(finals_list):
|
||||||
|
@ -192,8 +223,7 @@ class ToneSandhi():
|
||||||
if last_word == "不":
|
if last_word == "不":
|
||||||
new_seg.append((last_word, 'd'))
|
new_seg.append((last_word, 'd'))
|
||||||
last_word = ""
|
last_word = ""
|
||||||
seg = new_seg
|
return new_seg
|
||||||
return seg
|
|
||||||
|
|
||||||
# function 1: merge "一" and reduplication words in it's left and right, e.g. "听","一","听" ->"听一听"
|
# function 1: merge "一" and reduplication words in it's left and right, e.g. "听","一","听" ->"听一听"
|
||||||
# function 2: merge single "一" and the word behind it
|
# function 2: merge single "一" and the word behind it
|
||||||
|
@ -222,9 +252,9 @@ class ToneSandhi():
|
||||||
new_seg[-1][0] = new_seg[-1][0] + word
|
new_seg[-1][0] = new_seg[-1][0] + word
|
||||||
else:
|
else:
|
||||||
new_seg.append([word, pos])
|
new_seg.append([word, pos])
|
||||||
seg = new_seg
|
return new_seg
|
||||||
return seg
|
|
||||||
|
|
||||||
|
# the first and the second words are all_tone_three
|
||||||
def _merge_continuous_three_tones(
|
def _merge_continuous_three_tones(
|
||||||
self, seg: List[Tuple[str, str]]) -> List[Tuple[str, str]]:
|
self, seg: List[Tuple[str, str]]) -> List[Tuple[str, str]]:
|
||||||
new_seg = []
|
new_seg = []
|
||||||
|
@ -239,21 +269,73 @@ class ToneSandhi():
|
||||||
if i - 1 >= 0 and self._all_tone_three(sub_finals_list[
|
if i - 1 >= 0 and self._all_tone_three(sub_finals_list[
|
||||||
i - 1]) and self._all_tone_three(sub_finals_list[
|
i - 1]) and self._all_tone_three(sub_finals_list[
|
||||||
i]) and not merge_last[i - 1]:
|
i]) and not merge_last[i - 1]:
|
||||||
if len(seg[i - 1][0]) + len(seg[i][0]) <= 3:
|
# if the last word is reduplication, not merge, because reduplication need to be _neural_sandhi
|
||||||
|
if not self._is_reduplication(seg[i - 1][0]) and len(seg[
|
||||||
|
i - 1][0]) + len(seg[i][0]) <= 3:
|
||||||
new_seg[-1][0] = new_seg[-1][0] + seg[i][0]
|
new_seg[-1][0] = new_seg[-1][0] + seg[i][0]
|
||||||
merge_last[i] = True
|
merge_last[i] = True
|
||||||
else:
|
else:
|
||||||
new_seg.append([word, pos])
|
new_seg.append([word, pos])
|
||||||
else:
|
else:
|
||||||
new_seg.append([word, pos])
|
new_seg.append([word, pos])
|
||||||
seg = new_seg
|
|
||||||
return seg
|
return new_seg
|
||||||
|
|
||||||
|
def _is_reduplication(self, word):
|
||||||
|
return len(word) == 2 and word[0] == word[1]
|
||||||
|
|
||||||
|
# the last char of first word and the first char of second word is tone_three
|
||||||
|
def _merge_continuous_three_tones_2(
|
||||||
|
self, seg: List[Tuple[str, str]]) -> List[Tuple[str, str]]:
|
||||||
|
new_seg = []
|
||||||
|
sub_finals_list = [
|
||||||
|
lazy_pinyin(
|
||||||
|
word, neutral_tone_with_five=True, style=Style.FINALS_TONE3)
|
||||||
|
for (word, pos) in seg
|
||||||
|
]
|
||||||
|
assert len(sub_finals_list) == len(seg)
|
||||||
|
merge_last = [False] * len(seg)
|
||||||
|
for i, (word, pos) in enumerate(seg):
|
||||||
|
if i - 1 >= 0 and sub_finals_list[i - 1][-1][-1] == "3" and sub_finals_list[i][0][-1] == "3" and not \
|
||||||
|
merge_last[i - 1]:
|
||||||
|
# if the last word is reduplication, not merge, because reduplication need to be _neural_sandhi
|
||||||
|
if not self._is_reduplication(seg[i - 1][0]) and len(seg[
|
||||||
|
i - 1][0]) + len(seg[i][0]) <= 3:
|
||||||
|
new_seg[-1][0] = new_seg[-1][0] + seg[i][0]
|
||||||
|
merge_last[i] = True
|
||||||
|
else:
|
||||||
|
new_seg.append([word, pos])
|
||||||
|
else:
|
||||||
|
new_seg.append([word, pos])
|
||||||
|
return new_seg
|
||||||
|
|
||||||
|
def _merge_er(self, seg: List[Tuple[str, str]]) -> List[Tuple[str, str]]:
|
||||||
|
new_seg = []
|
||||||
|
for i, (word, pos) in enumerate(seg):
|
||||||
|
if i - 1 >= 0 and word == "儿":
|
||||||
|
new_seg[-1][0] = new_seg[-1][0] + seg[i][0]
|
||||||
|
else:
|
||||||
|
new_seg.append([word, pos])
|
||||||
|
return new_seg
|
||||||
|
|
||||||
|
def _merge_reduplication(
|
||||||
|
self, seg: List[Tuple[str, str]]) -> List[Tuple[str, str]]:
|
||||||
|
new_seg = []
|
||||||
|
for i, (word, pos) in enumerate(seg):
|
||||||
|
if new_seg and word == new_seg[-1][0]:
|
||||||
|
new_seg[-1][0] = new_seg[-1][0] + seg[i][0]
|
||||||
|
else:
|
||||||
|
new_seg.append([word, pos])
|
||||||
|
return new_seg
|
||||||
|
|
||||||
def pre_merge_for_modify(
|
def pre_merge_for_modify(
|
||||||
self, seg: List[Tuple[str, str]]) -> List[Tuple[str, str]]:
|
self, seg: List[Tuple[str, str]]) -> List[Tuple[str, str]]:
|
||||||
seg = self._merge_bu(seg)
|
seg = self._merge_bu(seg)
|
||||||
seg = self._merge_yi(seg)
|
seg = self._merge_yi(seg)
|
||||||
|
seg = self._merge_reduplication(seg)
|
||||||
seg = self._merge_continuous_three_tones(seg)
|
seg = self._merge_continuous_three_tones(seg)
|
||||||
|
seg = self._merge_continuous_three_tones_2(seg)
|
||||||
|
seg = self._merge_er(seg)
|
||||||
return seg
|
return seg
|
||||||
|
|
||||||
def modified_tone(self, word: str, pos: str,
|
def modified_tone(self, word: str, pos: str,
|
||||||
|
|
|
@ -0,0 +1,239 @@
|
||||||
|
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""This module provides functions to calculate error rate in different level.
|
||||||
|
e.g. wer for word-level, cer for char-level.
|
||||||
|
"""
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
__all__ = ['word_errors', 'char_errors', 'wer', 'cer']
|
||||||
|
|
||||||
|
|
||||||
|
def _levenshtein_distance(ref, hyp):
|
||||||
|
"""Levenshtein distance is a string metric for measuring the difference
|
||||||
|
between two sequences. Informally, the levenshtein disctance is defined as
|
||||||
|
the minimum number of single-character edits (substitutions, insertions or
|
||||||
|
deletions) required to change one word into the other. We can naturally
|
||||||
|
extend the edits to word level when calculate levenshtein disctance for
|
||||||
|
two sentences.
|
||||||
|
"""
|
||||||
|
m = len(ref)
|
||||||
|
n = len(hyp)
|
||||||
|
|
||||||
|
# special case
|
||||||
|
if ref == hyp:
|
||||||
|
return 0
|
||||||
|
if m == 0:
|
||||||
|
return n
|
||||||
|
if n == 0:
|
||||||
|
return m
|
||||||
|
|
||||||
|
if m < n:
|
||||||
|
ref, hyp = hyp, ref
|
||||||
|
m, n = n, m
|
||||||
|
|
||||||
|
# use O(min(m, n)) space
|
||||||
|
distance = np.zeros((2, n + 1), dtype=np.int32)
|
||||||
|
|
||||||
|
# initialize distance matrix
|
||||||
|
for j in range(n + 1):
|
||||||
|
distance[0][j] = j
|
||||||
|
|
||||||
|
# calculate levenshtein distance
|
||||||
|
for i in range(1, m + 1):
|
||||||
|
prev_row_idx = (i - 1) % 2
|
||||||
|
cur_row_idx = i % 2
|
||||||
|
distance[cur_row_idx][0] = i
|
||||||
|
for j in range(1, n + 1):
|
||||||
|
if ref[i - 1] == hyp[j - 1]:
|
||||||
|
distance[cur_row_idx][j] = distance[prev_row_idx][j - 1]
|
||||||
|
else:
|
||||||
|
s_num = distance[prev_row_idx][j - 1] + 1
|
||||||
|
i_num = distance[cur_row_idx][j - 1] + 1
|
||||||
|
d_num = distance[prev_row_idx][j] + 1
|
||||||
|
distance[cur_row_idx][j] = min(s_num, i_num, d_num)
|
||||||
|
|
||||||
|
return distance[m % 2][n]
|
||||||
|
|
||||||
|
|
||||||
|
def word_errors(reference, hypothesis, ignore_case=False, delimiter=' '):
|
||||||
|
"""Compute the levenshtein distance between reference sequence and
|
||||||
|
hypothesis sequence in word-level.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
reference : str
|
||||||
|
The reference sentence.
|
||||||
|
hypothesis : str
|
||||||
|
The hypothesis sentence.
|
||||||
|
ignore_case : bool
|
||||||
|
Whether case-sensitive or not.
|
||||||
|
delimiter : char(str)
|
||||||
|
Delimiter of input sentences.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
----------
|
||||||
|
list
|
||||||
|
Levenshtein distance and word number of reference sentence.
|
||||||
|
"""
|
||||||
|
if ignore_case:
|
||||||
|
reference = reference.lower()
|
||||||
|
hypothesis = hypothesis.lower()
|
||||||
|
|
||||||
|
ref_words = list(filter(None, reference.split(delimiter)))
|
||||||
|
hyp_words = list(filter(None, hypothesis.split(delimiter)))
|
||||||
|
|
||||||
|
edit_distance = _levenshtein_distance(ref_words, hyp_words)
|
||||||
|
return float(edit_distance), len(ref_words)
|
||||||
|
|
||||||
|
|
||||||
|
def char_errors(reference, hypothesis, ignore_case=False, remove_space=False):
|
||||||
|
"""Compute the levenshtein distance between reference sequence and
|
||||||
|
hypothesis sequence in char-level.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
reference: str
|
||||||
|
The reference sentence.
|
||||||
|
hypothesis: str
|
||||||
|
The hypothesis sentence.
|
||||||
|
ignore_case: bool
|
||||||
|
Whether case-sensitive or not.
|
||||||
|
remove_space: bool
|
||||||
|
Whether remove internal space characters
|
||||||
|
|
||||||
|
Returns
|
||||||
|
----------
|
||||||
|
list
|
||||||
|
Levenshtein distance and length of reference sentence.
|
||||||
|
"""
|
||||||
|
if ignore_case:
|
||||||
|
reference = reference.lower()
|
||||||
|
hypothesis = hypothesis.lower()
|
||||||
|
|
||||||
|
join_char = ' '
|
||||||
|
if remove_space:
|
||||||
|
join_char = ''
|
||||||
|
|
||||||
|
reference = join_char.join(list(filter(None, reference.split(' '))))
|
||||||
|
hypothesis = join_char.join(list(filter(None, hypothesis.split(' '))))
|
||||||
|
|
||||||
|
edit_distance = _levenshtein_distance(reference, hypothesis)
|
||||||
|
return float(edit_distance), len(reference)
|
||||||
|
|
||||||
|
|
||||||
|
def wer(reference, hypothesis, ignore_case=False, delimiter=' '):
|
||||||
|
"""Calculate word error rate (WER). WER compares reference text and
|
||||||
|
hypothesis text in word-level. WER is defined as:
|
||||||
|
.. math::
|
||||||
|
WER = (Sw + Dw + Iw) / Nw
|
||||||
|
where
|
||||||
|
.. code-block:: text
|
||||||
|
Sw is the number of words subsituted,
|
||||||
|
Dw is the number of words deleted,
|
||||||
|
Iw is the number of words inserted,
|
||||||
|
Nw is the number of words in the reference
|
||||||
|
We can use levenshtein distance to calculate WER. Please draw an attention
|
||||||
|
that empty items will be removed when splitting sentences by delimiter.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
reference: str
|
||||||
|
The reference sentence.
|
||||||
|
|
||||||
|
hypothesis: str
|
||||||
|
The hypothesis sentence.
|
||||||
|
ignore_case: bool
|
||||||
|
Whether case-sensitive or not.
|
||||||
|
delimiter: char
|
||||||
|
Delimiter of input sentences.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
----------
|
||||||
|
float
|
||||||
|
Word error rate.
|
||||||
|
|
||||||
|
Raises
|
||||||
|
----------
|
||||||
|
ValueError
|
||||||
|
If word number of reference is zero.
|
||||||
|
"""
|
||||||
|
edit_distance, ref_len = word_errors(reference, hypothesis, ignore_case,
|
||||||
|
delimiter)
|
||||||
|
|
||||||
|
if ref_len == 0:
|
||||||
|
raise ValueError("Reference's word number should be greater than 0.")
|
||||||
|
|
||||||
|
wer = float(edit_distance) / ref_len
|
||||||
|
return wer
|
||||||
|
|
||||||
|
|
||||||
|
def cer(reference, hypothesis, ignore_case=False, remove_space=False):
|
||||||
|
"""Calculate charactor error rate (CER). CER compares reference text and
|
||||||
|
hypothesis text in char-level. CER is defined as:
|
||||||
|
.. math::
|
||||||
|
CER = (Sc + Dc + Ic) / Nc
|
||||||
|
where
|
||||||
|
.. code-block:: text
|
||||||
|
Sc is the number of characters substituted,
|
||||||
|
Dc is the number of characters deleted,
|
||||||
|
Ic is the number of characters inserted
|
||||||
|
Nc is the number of characters in the reference
|
||||||
|
We can use levenshtein distance to calculate CER. Chinese input should be
|
||||||
|
encoded to unicode. Please draw an attention that the leading and tailing
|
||||||
|
space characters will be truncated and multiple consecutive space
|
||||||
|
characters in a sentence will be replaced by one space character.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
reference: str
|
||||||
|
The reference sentence.
|
||||||
|
hypothesis: str
|
||||||
|
The hypothesis sentence.
|
||||||
|
ignore_case: bool
|
||||||
|
Whether case-sensitive or not.
|
||||||
|
remove_space: bool
|
||||||
|
Whether remove internal space characters
|
||||||
|
|
||||||
|
Returns
|
||||||
|
----------
|
||||||
|
float
|
||||||
|
Character error rate.
|
||||||
|
|
||||||
|
Raises
|
||||||
|
----------
|
||||||
|
ValueError
|
||||||
|
If the reference length is zero.
|
||||||
|
"""
|
||||||
|
edit_distance, ref_len = char_errors(reference, hypothesis, ignore_case,
|
||||||
|
remove_space)
|
||||||
|
|
||||||
|
if ref_len == 0:
|
||||||
|
raise ValueError("Length of reference should be greater than 0.")
|
||||||
|
|
||||||
|
cer = float(edit_distance) / ref_len
|
||||||
|
return cer
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
reference = [
|
||||||
|
'j', 'iou4', 'zh', 'e4', 'iang5', 'x', 'v2', 'b', 'o1', 'k', 'ai1',
|
||||||
|
'sh', 'iii3', 'l', 'e5', 'b', 'ei3', 'p', 'iao1', 'sh', 'eng1', 'ia2'
|
||||||
|
]
|
||||||
|
hypothesis = [
|
||||||
|
'j', 'iou4', 'zh', 'e4', 'iang4', 'x', 'v2', 'b', 'o1', 'k', 'ai1',
|
||||||
|
'sh', 'iii3', 'l', 'e5', 'b', 'ei3', 'p', 'iao1', 'sh', 'eng1', 'ia2'
|
||||||
|
]
|
||||||
|
reference = " ".join(reference)
|
||||||
|
hypothesis = " ".join(hypothesis)
|
||||||
|
print(wer(reference, hypothesis))
|
5
setup.py
5
setup.py
|
@ -71,10 +71,13 @@ setup_info = dict(
|
||||||
'pypinyin',
|
'pypinyin',
|
||||||
'webrtcvad',
|
'webrtcvad',
|
||||||
'g2pM',
|
'g2pM',
|
||||||
'praatio',
|
'praatio~=4.1',
|
||||||
"h5py",
|
"h5py",
|
||||||
"timer",
|
"timer",
|
||||||
'jsonlines',
|
'jsonlines',
|
||||||
|
'pyworld',
|
||||||
|
'typeguard',
|
||||||
|
'jieba',
|
||||||
],
|
],
|
||||||
extras_require={'doc': ["sphinx", "sphinx-rtd-theme", "numpydoc"], },
|
extras_require={'doc': ["sphinx", "sphinx-rtd-theme", "numpydoc"], },
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue