Merge pull request #145 from yt605155624/add_typehint

add traditional and simplified Chinese conversion and add typehint fo…
Hui Zhang 2021-08-19 05:30:23 -05:00 committed by GitHub
commit b4b9171250
6 changed files with 79 additions and 16 deletions

File 1 of 6

@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import re
+from typing import Dict
+from typing import List
 
 import numpy as np
 import paddle
@@ -35,7 +37,7 @@ class Frontend():
         for tone, id in tone_id:
             self.vocab_tones[tone] = int(id)
 
-    def _p2id(self, phonemes):
+    def _p2id(self, phonemes: List[str]) -> np.array:
         # replace unk phone with sp
         phonemes = [
             phn if phn in self.vocab_phones else "sp" for phn in phonemes
@@ -43,13 +45,14 @@ class Frontend():
         phone_ids = [self.vocab_phones[item] for item in phonemes]
         return np.array(phone_ids, np.int64)
 
-    def _t2id(self, tones):
+    def _t2id(self, tones: List[str]) -> np.array:
         # replace unk phone with sp
         tones = [tone if tone in self.vocab_tones else "0" for tone in tones]
         tone_ids = [self.vocab_tones[item] for item in tones]
         return np.array(tone_ids, np.int64)
 
-    def _get_phone_tone(self, phonemes, get_tone_ids=False):
+    def _get_phone_tone(self, phonemes: List[str],
+                        get_tone_ids: bool=False) -> List[List[str]]:
         phones = []
         tones = []
         if get_tone_ids and self.vocab_tones:
@@ -88,7 +91,11 @@ class Frontend():
             phones.append(phone)
         return phones, tones
 
-    def get_input_ids(self, sentence, merge_sentences=True, get_tone_ids=False):
+    def get_input_ids(
+            self,
+            sentence: str,
+            merge_sentences: bool=True,
+            get_tone_ids: bool=False) -> Dict[str, List[paddle.Tensor]]:
         phonemes = self.frontend.get_phonemes(
             sentence, merge_sentences=merge_sentences)
         result = {}
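
As a standalone illustration of the `_p2id` fallback annotated above, a minimal runnable sketch; the `vocab_phones` table here is a toy stand-in for the file-loaded vocabulary:

```python
from typing import List

import numpy as np

# Toy phone vocabulary; the real one is loaded from a vocab file.
vocab_phones = {"sp": 0, "n": 1, "i3": 2, "h": 3, "ao3": 4}

def p2id(phonemes: List[str]) -> np.ndarray:
    # Unknown phones fall back to "sp" (silence/pause) before lookup,
    # exactly as in the _p2id method above.
    phonemes = [phn if phn in vocab_phones else "sp" for phn in phonemes]
    return np.array([vocab_phones[p] for p in phonemes], np.int64)

print(p2id(["n", "i3", "h", "ao3", "???"]))  # [1 2 3 4 0]
```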

File 2 of 6

@@ -15,6 +15,6 @@ Run the command below to get the results of test.
 ```bash
 ./run.sh
 ```
 
-The `avg WER` of g2p is: 0.027124048652822204
+The `avg WER` of g2p is: 0.027495061517943988
 The `avg CER` of text normalization is: 0.0061629764893859846
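
For context, `avg WER` and `avg CER` are Levenshtein-based error rates (CER is the same computation over characters instead of words); a minimal sketch of the standard calculation, not necessarily the exact script `./run.sh` invokes:

```python
from typing import List

def wer(ref: List[str], hyp: List[str]) -> float:
    """Word error rate: Levenshtein distance over words, divided by len(ref)."""
    d = [[0] * (len(hyp) + 1) for _ in range(len(ref) + 1)]
    for i in range(len(ref) + 1):
        d[i][0] = i  # delete everything remaining in ref
    for j in range(len(hyp) + 1):
        d[0][j] = j  # insert everything remaining in hyp
    for i in range(1, len(ref) + 1):
        for j in range(1, len(hyp) + 1):
            sub = d[i - 1][j - 1] + (ref[i - 1] != hyp[j - 1])
            d[i][j] = min(d[i - 1][j] + 1, d[i][j - 1] + 1, sub)
    return d[-1][-1] / len(ref)

# One substitution out of four words -> 0.25
print(wer("ni3 hao3 shi4 jie4".split(), "ni3 hao3 si4 jie4".split()))
```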

File 3 of 6

@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import re
+from typing import List
 
 import jieba.posseg as psg
 from g2pM import G2pM
@@ -43,7 +44,7 @@ class Frontend():
             "狗儿"
         }
 
-    def _get_initials_finals(self, word):
+    def _get_initials_finals(self, word: str) -> List[List[str]]:
         initials = []
         finals = []
         if self.g2p_model == "pypinyin":
@@ -78,7 +79,10 @@ class Frontend():
         return initials, finals
 
     # if merge_sentences, merge all sentences into one phone sequence
-    def _g2p(self, sentences, merge_sentences=True, with_erhua=True):
+    def _g2p(self,
+             sentences: List[str],
+             merge_sentences: bool=True,
+             with_erhua: bool=True) -> List[List[str]]:
         segments = sentences
         phones_list = []
         for seg in segments:
@@ -120,7 +124,11 @@ class Frontend():
             phones_list.append(merge_list)
         return phones_list
 
-    def _merge_erhua(self, initials, finals, word, pos):
+    def _merge_erhua(self,
+                     initials: List[str],
+                     finals: List[str],
+                     word: str,
+                     pos: str) -> List[List[str]]:
         if word not in self.must_erhua and (word in self.not_erhua or
                                             pos in {"a", "j", "nr"}):
             return initials, finals
@@ -137,7 +145,10 @@ class Frontend():
                 new_initials.append(initials[i])
         return new_initials, new_finals
 
-    def get_phonemes(self, sentence, merge_sentences=True, with_erhua=True):
+    def get_phonemes(self,
+                     sentence: str,
+                     merge_sentences: bool=True,
+                     with_erhua: bool=True) -> List[List[str]]:
         sentences = self.text_normalizer.normalize(sentence)
         phonemes = self._g2p(
             sentences, merge_sentences=merge_sentences, with_erhua=with_erhua)
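
Since `_get_initials_finals` dispatches on `self.g2p_model == "pypinyin"`, here is a simplified sketch of what the pypinyin-backed path can look like; the `Style` choices are assumptions, and the real method also covers the g2pM backend and erhua handling, which this omits:

```python
from pypinyin import Style, lazy_pinyin

# Split a word into initials and tone-numbered finals with pypinyin.
word = "中国"
initials = lazy_pinyin(word, style=Style.INITIALS, strict=False)
finals = lazy_pinyin(word, style=Style.FINALS_TONE3, strict=False)
print(initials, finals)  # ['zh', 'g'] ['ong1', 'uo2']
```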

File 4 of 6: diff suppressed because one or more lines are too long

File 5 of 6

@@ -14,6 +14,7 @@
 import re
 from typing import List
 
+from .char_convert import tranditional_to_simplified
 from .chronology import RE_DATE
 from .chronology import RE_DATE2
 from .chronology import RE_TIME
@@ -66,8 +67,9 @@ class TextNormalizer():
         sentences = [sentence.strip() for sentence in re.split(r'\n+', text)]
         return sentences
 
-    def normalize_sentence(self, sentence):
+    def normalize_sentence(self, sentence: str) -> str:
         # basic character conversions
+        sentence = tranditional_to_simplified(sentence)
         sentence = sentence.translate(F2H_ASCII_LETTERS).translate(
             F2H_DIGITS).translate(F2H_SPACE)
 
@@ -90,7 +92,7 @@ class TextNormalizer():
         return sentence
 
-    def normalize(self, text):
+    def normalize(self, text: str) -> List[str]:
         sentences = self._split(text)
         sentences = [self.normalize_sentence(sent) for sent in sentences]
         return sentences
 
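
The body of `char_convert` is presumably the suppressed file above (a long character table); a minimal sketch of one common table-driven implementation style, with a deliberately tiny toy table standing in for the real data:

```python
# Toy excerpt standing in for the full traditional->simplified table;
# the real module's data is much longer.
_TRAD = "發體雲"
_SIMP = "发体云"
_T2S = str.maketrans(_TRAD, _SIMP)

def tranditional_to_simplified(sentence: str) -> str:
    # Per-character translation; characters outside the table pass through.
    return sentence.translate(_T2S)

print(tranditional_to_simplified("雲端發佈"))  # 云端发佈 (佈 missing from toy table)
```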

File 6 of 6

@@ -114,7 +114,6 @@ class ToneSandhi():
                     -2:] in self.must_neural_tone_words:
                 finals_list[i][-1] = finals_list[i][-1][:-1] + "5"
         finals = sum(finals_list, [])
-
         return finals
 
     def _bu_sandhi(self, word: str, finals: List[str]) -> List[str]:
@@ -151,11 +150,9 @@ class ToneSandhi():
             finals[i] = finals[i][:-1] + "4"
         return finals
 
-    def _split_word(self, word):
+    def _split_word(self, word: str) -> List[str]:
         word_list = jieba.cut_for_search(word)
         word_list = sorted(word_list, key=lambda i: len(i), reverse=False)
-        new_word_list = []
-
         first_subword = word_list[0]
         first_begin_idx = word.find(first_subword)
         if first_begin_idx == 0:
@@ -280,7 +277,7 @@ class ToneSandhi():
         return new_seg
 
-    def _is_reduplication(self, word):
+    def _is_reduplication(self, word: str) -> bool:
         return len(word) == 2 and word[0] == word[1]
 
     # the last char of first word and the first char of second word is tone_three
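
Two of the one-liners touched above, restated as a runnable snippet: the neutral-tone rewrite replaces a final's tone digit with "5", and `_is_reduplication` flags two-character reduplicated words like 妈妈:

```python
# Neutral-tone rewrite from the first hunk: the trailing tone digit of a
# final such as "a1" becomes "5" (the neutral tone).
final = "a1"
print(final[:-1] + "5")  # a5

# _is_reduplication, as annotated above, lifted out of the class:
def is_reduplication(word: str) -> bool:
    return len(word) == 2 and word[0] == word[1]

print(is_reduplication("妈妈"), is_reduplication("你好"))  # True False
```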