add phonetics & vocab & punctuation

This commit is contained in:
iclementine 2020-10-20 16:06:11 +08:00
parent c1e0aecdde
commit 580655f33f
20 changed files with 228 additions and 553 deletions

View File

@ -14,4 +14,4 @@
__version__ = "0.0.0" __version__ = "0.0.0"
from . import data, g2p, models, modules from parakeet import data, frontend, models, modules

View File

@ -0,0 +1,3 @@
from parakeet.frontend.vocab import *
from parakeet.frontend.phonectic import *
from parakeet.frontend.punctuation import *

View File

View File

@ -0,0 +1,3 @@
# number expansion is not that easy
import num2words
import inflect

View File

@ -0,0 +1,24 @@
def full2half_width(ustr):
half = []
for u in ustr:
num = ord(u)
if num == 0x3000: # 全角空格变半角
num = 32
elif 0xFF01 <= num <= 0xFF5E:
num -= 0xfee0
u = chr(num)
half.append(u)
return ''.join(half)
def half2full_width(ustr):
full = []
for u in ustr:
num = ord(u)
if num == 32: # 半角空格变全角
num = 0x3000
elif 0x21 <= num <= 0x7E:
num += 0xfee0
u = chr(num) # to unicode
full.append(u)
return ''.join(full)

View File

@ -0,0 +1,85 @@
from abc import ABC, abstractmethod
from typing import Union
from g2p_en import G2p
from g2pM import G2pM
from parakeet.frontend import Vocab
from opencc import OpenCC
from parakeet.frontend.punctuation import get_punctuations
class Phonetics(ABC):
@abstractmethod
def __call__(self, sentence):
pass
@abstractmethod
def phoneticize(self, sentence):
pass
@abstractmethod
def numericalize(self, phonemes):
pass
class English(Phonetics):
def __init__(self):
self.backend = G2p()
self.phonemes = list(self.backend.phonemes)
self.punctuations = get_punctuations("en")
self.vocab = Vocab(self.phonemes + self.punctuations)
def phoneticize(self, sentence):
return self.backend(sentence)
def numericalize(self, phonemes):
ids = [self.vocab.lookup(item) for item in phonemes if item in self.vocab.stoi]
return ids
def reverse(self, ids):
return [self.vocab.reverse(i) for i in ids]
def __call__(self, sentence):
return self.numericalize(self.phoneticize(sentence))
def vocab_size(self):
return len(self.vocab)
class Chinese(Phonetics):
def __init__(self):
self.opencc_backend = OpenCC('t2s.json')
self.backend = G2pM()
self.phonemes = self._get_all_syllables()
self.punctuations = get_punctuations("cn")
self.vocab = Vocab(self.phonemes + self.punctuations)
def _get_all_syllables(self):
all_syllables = set([syllable for k, v in self.backend.cedict.items() for syllable in v])
return list(all_syllables)
def phoneticize(self, sentence):
simplified = self.opencc_backend.convert(sentence)
phonemes = self.backend(simplified)
return self._filter_symbols(phonemes)
def _filter_symbols(self, phonemes):
cleaned_phonemes = []
for item in phonemes:
if item in self.vocab.stoi:
cleaned_phonemes.append(item)
else:
for char in item:
if char in self.vocab.stoi:
cleaned_phonemes.append(char)
return cleaned_phonemes
def numericalize(self, phonemes):
ids = [self.vocab.lookup(item) for item in phonemes]
return ids
def __call__(self, sentence):
return self.numericalize(self.phoneticize(sentence))
def vocab_size(self):
return len(self.vocab)
def reverse(self, ids):
return [self.vocab.reverse(i) for i in ids]

View File

@ -0,0 +1,33 @@
import abc
import string
__all__ = ["get_punctuations"]
EN_PUNCT = [
" ",
"-",
"...",
",",
".",
"?",
"!",
]
CN_PUNCT = [
"",
"",
"",
"",
"",
"",
""
]
def get_punctuations(lang):
if lang == "en":
return EN_PUNCT
elif lang == "cn":
return CN_PUNCT
else:
raise ValueError(f"language {lang} Not supported")

View File

@ -0,0 +1,79 @@
from typing import Dict, Iterable, List
from ruamel import yaml
from collections import OrderedDict
class Vocab(object):
def __init__(self, symbols: Iterable[str],
padding_symbol="<pad>",
unk_symbol="<unk>",
start_symbol="<s>",
end_symbol="</s>"):
self.special_symbols = OrderedDict()
for i, item in enumerate(
[padding_symbol, unk_symbol, start_symbol, end_symbol]):
if item:
self.special_symbols[item] = len(self.special_symbols)
self.padding_symbol = padding_symbol
self.unk_symbol = unk_symbol
self.start_symbol = start_symbol
self.end_symbol = end_symbol
self.stoi = OrderedDict()
self.stoi.update(self.special_symbols)
N = len(self.special_symbols)
for i, s in enumerate(symbols):
if s not in self.stoi:
self.stoi[s] = N +i
self.itos = {v: k for k, v in self.stoi.items()}
def __len__(self):
return len(self.stoi)
@property
def num_specials(self):
return len(self.special_symbols)
# special tokens
@property
def padding_index(self):
return self.stoi.get(self.padding_symbol, -1)
@property
def unk_index(self):
return self.stoi.get(self.unk_symbol, -1)
@property
def start_index(self):
return self.stoi.get(self.start_symbol, -1)
@property
def end_index(self):
return self.stoi.get(self.end_symbol, -1)
def __repr__(self):
fmt = "Vocab(size: {},\nstoi:\n{})"
return fmt.format(len(self), self.stoi)
def __str__(self):
return self.__repr__()
def lookup(self, symbol):
return self.stoi[symbol]
def reverse(self, index):
return self.itos[index]
def add_symbol(self, symbol):
if symbol in self.stoi:
return
N = len(self.stoi)
self.stoi[symbol] = N
self.itos[N] = symbol
def add_symbols(self, symbols):
for symbol in symbols:
self.add_symbol(symbol)

View File

@ -1,32 +0,0 @@
# coding: utf-8
"""Text processing frontend
All frontend module should have the following functions:
- text_to_sequence(text, p)
- sequence_to_text(sequence)
and the property:
- n_vocab
"""
from . import en
# optinoal Japanese frontend
try:
from . import jp
except ImportError:
jp = None
try:
from . import ko
except ImportError:
ko = None
# if you are going to use the frontend, you need to modify _characters in symbol.py:
# _characters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz!\'(),-.:;? ' + '¡¿ñáéíóúÁÉÍÓÚÑ'
try:
from . import es
except ImportError:
es = None

View File

@ -1,34 +0,0 @@
# coding: utf-8
from ..text.symbols import symbols
from ..text import sequence_to_text
import nltk
from random import random
n_vocab = len(symbols)
_arpabet = nltk.corpus.cmudict.dict()
def _maybe_get_arpabet(word, p):
try:
phonemes = _arpabet[word][0]
phonemes = " ".join(phonemes)
except KeyError:
return word
return '{%s}' % phonemes if random() < p else word
def mix_pronunciation(text, p):
text = ' '.join(_maybe_get_arpabet(word, p) for word in text.split(' '))
return text
def text_to_sequence(text, p=0.0):
if p >= 0:
text = mix_pronunciation(text, p)
from ..text import text_to_sequence
text = text_to_sequence(text, ["english_cleaners"])
return text

View File

@ -1,14 +0,0 @@
# coding: utf-8
from ..text.symbols import symbols
from ..text import sequence_to_text
import nltk
from random import random
n_vocab = len(symbols)
def text_to_sequence(text, p=0.0):
from ..text import text_to_sequence
text = text_to_sequence(text, ["basic_cleaners"])
return text

View File

@ -1,77 +0,0 @@
# coding: utf-8
import MeCab
import jaconv
from random import random
n_vocab = 0xffff
_eos = 1
_pad = 0
_tagger = None
def _yomi(mecab_result):
tokens = []
yomis = []
for line in mecab_result.split("\n")[:-1]:
s = line.split("\t")
if len(s) == 1:
break
token, rest = s
rest = rest.split(",")
tokens.append(token)
yomi = rest[7] if len(rest) > 7 else None
yomi = None if yomi == "*" else yomi
yomis.append(yomi)
return tokens, yomis
def _mix_pronunciation(tokens, yomis, p):
return "".join(yomis[idx]
if yomis[idx] is not None and random() < p else tokens[idx]
for idx in range(len(tokens)))
def mix_pronunciation(text, p):
global _tagger
if _tagger is None:
_tagger = MeCab.Tagger("")
tokens, yomis = _yomi(_tagger.parse(text))
return _mix_pronunciation(tokens, yomis, p)
def add_punctuation(text):
last = text[-1]
if last not in [".", ",", "", "", "", "", "!", "?"]:
text = text + ""
return text
def normalize_delimitor(text):
text = text.replace(",", "")
text = text.replace(".", "")
text = text.replace("", "")
text = text.replace("", "")
return text
def text_to_sequence(text, p=0.0):
for c in [" ", " ", "", "", "", "", "", "", "", "", "", "(", ")"]:
text = text.replace(c, "")
text = text.replace("!", "")
text = text.replace("?", "")
text = normalize_delimitor(text)
text = jaconv.normalize(text)
if p > 0:
text = mix_pronunciation(text, p)
text = jaconv.hira2kata(text)
text = add_punctuation(text)
return [ord(c) for c in text] + [_eos] # EOS
def sequence_to_text(seq):
return "".join(chr(n) for n in seq)

View File

@ -1,17 +0,0 @@
# coding: utf-8
from random import random
n_vocab = 0xffff
_eos = 1
_pad = 0
_tagger = None
def text_to_sequence(text, p=0.0):
return [ord(c) for c in text] + [_eos] # EOS
def sequence_to_text(seq):
return "".join(chr(n) for n in seq)

View File

@ -1,89 +0,0 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import re
from . import cleaners
from .symbols import symbols
# Mappings from symbol to numeric ID and vice versa:
_symbol_to_id = {s: i for i, s in enumerate(symbols)}
_id_to_symbol = {i: s for i, s in enumerate(symbols)}
# Regular expression matching text enclosed in curly braces:
_curly_re = re.compile(r'(.*?)\{(.+?)\}(.*)')
def text_to_sequence(text, cleaner_names):
'''Converts a string of text to a sequence of IDs corresponding to the symbols in the text.
The text can optionally have ARPAbet sequences enclosed in curly braces embedded
in it. For example, "Turn left on {HH AW1 S S T AH0 N} Street."
Args:
text: string to convert to a sequence
cleaner_names: names of the cleaner functions to run the text through
Returns:
List of integers corresponding to the symbols in the text
'''
sequence = []
# Check for curly braces and treat their contents as ARPAbet:
while len(text):
m = _curly_re.match(text)
if not m:
sequence += _symbols_to_sequence(_clean_text(text, cleaner_names))
break
sequence += _symbols_to_sequence(
_clean_text(m.group(1), cleaner_names))
sequence += _arpabet_to_sequence(m.group(2))
text = m.group(3)
# Append EOS token
sequence.append(_symbol_to_id['~'])
return sequence
def sequence_to_text(sequence):
'''Converts a sequence of IDs back to a string'''
result = ''
for symbol_id in sequence:
if symbol_id in _id_to_symbol:
s = _id_to_symbol[symbol_id]
# Enclose ARPAbet back in curly braces:
if len(s) > 1 and s[0] == '@':
s = '{%s}' % s[1:]
result += s
return result.replace('}{', ' ')
def _clean_text(text, cleaner_names):
for name in cleaner_names:
cleaner = getattr(cleaners, name)
if not cleaner:
raise Exception('Unknown cleaner: %s' % name)
text = cleaner(text)
return text
def _symbols_to_sequence(symbols):
return [_symbol_to_id[s] for s in symbols if _should_keep_symbol(s)]
def _arpabet_to_sequence(text):
return _symbols_to_sequence(['@' + s for s in text.split()])
def _should_keep_symbol(s):
return s in _symbol_to_id and s is not '_' and s is not '~'

View File

@ -1,110 +0,0 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
'''
Cleaners are transformations that run over the input text at both training and eval time.
Cleaners can be selected by passing a comma-delimited list of cleaner names as the "cleaners"
hyperparameter. Some cleaners are English-specific. You'll typically want to use:
1. "english_cleaners" for English text
2. "transliteration_cleaners" for non-English text that can be transliterated to ASCII using
the Unidecode library (https://pypi.python.org/pypi/Unidecode)
3. "basic_cleaners" if you do not want to transliterate (in this case, you should also update
the symbols in symbols.py to match your data).
'''
import re
from unidecode import unidecode
from .numbers import normalize_numbers
# Regular expression matching whitespace:
_whitespace_re = re.compile(r'\s+')
# List of (regular expression, replacement) pairs for abbreviations:
_abbreviations = [(re.compile('\\b%s\\.' % x[0], re.IGNORECASE), x[1])
for x in [
('mrs', 'misess'),
('mr', 'mister'),
('dr', 'doctor'),
('st', 'saint'),
('co', 'company'),
('jr', 'junior'),
('maj', 'major'),
('gen', 'general'),
('drs', 'doctors'),
('rev', 'reverend'),
('lt', 'lieutenant'),
('hon', 'honorable'),
('sgt', 'sergeant'),
('capt', 'captain'),
('esq', 'esquire'),
('ltd', 'limited'),
('col', 'colonel'),
('ft', 'fort'),
]]
def expand_abbreviations(text):
for regex, replacement in _abbreviations:
text = re.sub(regex, replacement, text)
return text
def expand_numbers(text):
return normalize_numbers(text)
def lowercase(text):
return text.lower()
def collapse_whitespace(text):
return re.sub(_whitespace_re, ' ', text)
def convert_to_ascii(text):
return unidecode(text)
def add_punctuation(text):
if len(text) == 0:
return text
if text[-1] not in '!,.:;?':
text = text + '.' # without this decoder is confused when to output EOS
return text
def basic_cleaners(text):
'''Basic pipeline that lowercases and collapses whitespace without transliteration.'''
text = lowercase(text)
text = collapse_whitespace(text)
return text
def transliteration_cleaners(text):
'''Pipeline for non-English text that transliterates to ASCII.'''
text = convert_to_ascii(text)
text = lowercase(text)
text = collapse_whitespace(text)
return text
def english_cleaners(text):
'''Pipeline for English text, including number and abbreviation expansion.'''
text = convert_to_ascii(text)
#text = add_punctuation(text)
text = lowercase(text)
text = expand_numbers(text)
text = expand_abbreviations(text)
text = collapse_whitespace(text)
return text

View File

@ -1,78 +0,0 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import re
valid_symbols = [
'AA', 'AA0', 'AA1', 'AA2', 'AE', 'AE0', 'AE1', 'AE2', 'AH', 'AH0', 'AH1',
'AH2', 'AO', 'AO0', 'AO1', 'AO2', 'AW', 'AW0', 'AW1', 'AW2', 'AY', 'AY0',
'AY1', 'AY2', 'B', 'CH', 'D', 'DH', 'EH', 'EH0', 'EH1', 'EH2', 'ER', 'ER0',
'ER1', 'ER2', 'EY', 'EY0', 'EY1', 'EY2', 'F', 'G', 'HH', 'IH', 'IH0',
'IH1', 'IH2', 'IY', 'IY0', 'IY1', 'IY2', 'JH', 'K', 'L', 'M', 'N', 'NG',
'OW', 'OW0', 'OW1', 'OW2', 'OY', 'OY0', 'OY1', 'OY2', 'P', 'R', 'S', 'SH',
'T', 'TH', 'UH', 'UH0', 'UH1', 'UH2', 'UW', 'UW0', 'UW1', 'UW2', 'V', 'W',
'Y', 'Z', 'ZH'
]
_valid_symbol_set = set(valid_symbols)
class CMUDict:
'''Thin wrapper around CMUDict data. http://www.speech.cs.cmu.edu/cgi-bin/cmudict'''
def __init__(self, file_or_path, keep_ambiguous=True):
if isinstance(file_or_path, str):
with open(file_or_path, encoding='latin-1') as f:
entries = _parse_cmudict(f)
else:
entries = _parse_cmudict(file_or_path)
if not keep_ambiguous:
entries = {
word: pron
for word, pron in entries.items() if len(pron) == 1
}
self._entries = entries
def __len__(self):
return len(self._entries)
def lookup(self, word):
'''Returns list of ARPAbet pronunciations of the given word.'''
return self._entries.get(word.upper())
_alt_re = re.compile(r'\([0-9]+\)')
def _parse_cmudict(file):
cmudict = {}
for line in file:
if len(line) and (line[0] >= 'A' and line[0] <= 'Z' or line[0] == "'"):
parts = line.split(' ')
word = re.sub(_alt_re, '', parts[0])
pronunciation = _get_pronunciation(parts[1])
if pronunciation:
if word in cmudict:
cmudict[word].append(pronunciation)
else:
cmudict[word] = [pronunciation]
return cmudict
def _get_pronunciation(s):
parts = s.strip().split(' ')
for part in parts:
if part not in _valid_symbol_set:
return None
return ' '.join(parts)

View File

@ -1,71 +0,0 @@
# -*- coding: utf-8 -*-
import inflect
import re
_inflect = inflect.engine()
_comma_number_re = re.compile(r'([0-9][0-9\,]+[0-9])')
_decimal_number_re = re.compile(r'([0-9]+\.[0-9]+)')
_pounds_re = re.compile(r'£([0-9\,]*[0-9]+)')
_dollars_re = re.compile(r'\$([0-9\.\,]*[0-9]+)')
_ordinal_re = re.compile(r'[0-9]+(st|nd|rd|th)')
_number_re = re.compile(r'[0-9]+')
def _remove_commas(m):
return m.group(1).replace(',', '')
def _expand_decimal_point(m):
return m.group(1).replace('.', ' point ')
def _expand_dollars(m):
match = m.group(1)
parts = match.split('.')
if len(parts) > 2:
return match + ' dollars' # Unexpected format
dollars = int(parts[0]) if parts[0] else 0
cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0
if dollars and cents:
dollar_unit = 'dollar' if dollars == 1 else 'dollars'
cent_unit = 'cent' if cents == 1 else 'cents'
return '%s %s, %s %s' % (dollars, dollar_unit, cents, cent_unit)
elif dollars:
dollar_unit = 'dollar' if dollars == 1 else 'dollars'
return '%s %s' % (dollars, dollar_unit)
elif cents:
cent_unit = 'cent' if cents == 1 else 'cents'
return '%s %s' % (cents, cent_unit)
else:
return 'zero dollars'
def _expand_ordinal(m):
return _inflect.number_to_words(m.group(0))
def _expand_number(m):
num = int(m.group(0))
if num > 1000 and num < 3000:
if num == 2000:
return 'two thousand'
elif num > 2000 and num < 2010:
return 'two thousand ' + _inflect.number_to_words(num % 100)
elif num % 100 == 0:
return _inflect.number_to_words(num // 100) + ' hundred'
else:
return _inflect.number_to_words(
num, andword='', zero='oh', group=2).replace(', ', ' ')
else:
return _inflect.number_to_words(num, andword='')
def normalize_numbers(text):
text = re.sub(_comma_number_re, _remove_commas, text)
text = re.sub(_pounds_re, r'\1 pounds', text)
text = re.sub(_dollars_re, _expand_dollars, text)
text = re.sub(_decimal_number_re, _expand_decimal_point, text)
text = re.sub(_ordinal_re, _expand_ordinal, text)
text = re.sub(_number_re, _expand_number, text)
return text

View File

@ -1,30 +0,0 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
'''
Defines the set of symbols used in text input to the model.
The default is a set of ASCII characters that works well for English or text that has been run
through Unidecode. For other data, you can modify _characters. See TRAINING_DATA.md for details.
'''
from .cmudict import valid_symbols
_pad = '_'
_eos = '~'
_characters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz!\'(),-.:;? '
# Prepend "@" to ARPAbet symbols to ensure uniqueness (some are the same as uppercase letters):
_arpabet = ['@' + s for s in valid_symbols]
# Export all symbols:
symbols = [_pad, _eos] + list(_characters) + _arpabet