diff --git a/parakeet/frontend/phonectic.py b/parakeet/frontend/phonectic.py index 6f0de1d..2b1e2ae 100644 --- a/parakeet/frontend/phonectic.py +++ b/parakeet/frontend/phonectic.py @@ -39,6 +39,9 @@ class Phonetics(ABC): class English(Phonetics): + """ Normalize the input text sequence and convert into pronunciation id sequence. + """ + def __init__(self): self.backend = G2p() self.phonemes = list(self.backend.phonemes) @@ -46,6 +49,18 @@ class English(Phonetics): self.vocab = Vocab(self.phonemes + self.punctuations) def phoneticize(self, sentence): + """ Normalize the input text sequence and convert it into pronunciation sequence. + + Parameters + ----------- + sentence: str + The input text sequence. + + Returns + ---------- + List[str] + The list of pronunciation sequence. + """ start = self.vocab.start_symbol end = self.vocab.end_symbol phonemes = ([] if start is None else [start]) \ @@ -54,6 +69,18 @@ class English(Phonetics): return phonemes def numericalize(self, phonemes): + """ Convert pronunciation sequence into pronunciation id sequence. + + Parameters + ----------- + phonemes: List[str] + The list of pronunciation sequence. + + Returns + ---------- + List[int] + The list of pronunciation id sequence. + """ ids = [ self.vocab.lookup(item) for item in phonemes if item in self.vocab.stoi @@ -61,17 +88,46 @@ class English(Phonetics): return ids def reverse(self, ids): + """ Reverse the list of pronunciation id sequence to a list of pronunciation sequence. + + Parameters + ----------- + ids: List[int] + The list of pronunciation id sequence. + + Returns + ---------- + List[str] + The list of pronunciation sequence. + """ return [self.vocab.reverse(i) for i in ids] def __call__(self, sentence): + """ Convert the input text sequence into pronunciation id sequence. + + Parameters + ----------- + sentence: str + The input text sequence. + + Returns + ---------- + List[str] + The list of pronunciation id sequence. + """ return self.numericalize(self.phoneticize(sentence)) @property def vocab_size(self): + """ Vocab size. + """ return len(self.vocab) class EnglishCharacter(Phonetics): + """ Normalize the input text sequence and convert it into character id sequence. + """ + def __init__(self): self.backend = G2p() self.graphemes = list(self.backend.graphemes) @@ -79,10 +135,34 @@ class EnglishCharacter(Phonetics): self.vocab = Vocab(self.graphemes + self.punctuations) def phoneticize(self, sentence): + """ Normalize the input text sequence. + + Parameters + ----------- + sentence: str + The input text sequence. + + Returns + ---------- + str + A text sequence after normalize. + """ words = normalize(sentence) return words def numericalize(self, sentence): + """ Convert a text sequence into ids. + + Parameters + ----------- + sentence: str + The input text sequence. + + Returns + ---------- + List[int] + List of a character id sequence. + """ ids = [ self.vocab.lookup(item) for item in sentence if item in self.vocab.stoi @@ -90,17 +170,46 @@ class EnglishCharacter(Phonetics): return ids def reverse(self, ids): + """ Convert a character id sequence into text. + + Parameters + ----------- + ids: List[int] + List of a character id sequence. + + Returns + ---------- + str + The input text sequence. + + """ return [self.vocab.reverse(i) for i in ids] def __call__(self, sentence): + """ Normalize the input text sequence and convert it into character id sequence. + Parameters + ----------- + sentence: str + The input text sequence. + + Returns + ---------- + List[int] + List of a character id sequence. + """ return self.numericalize(self.phoneticize(sentence)) @property def vocab_size(self): + """ Vocab size. + """ return len(self.vocab) class Chinese(Phonetics): + """Normalize Chinese text sequence and convert it into ids. + """ + def __init__(self): self.opencc_backend = OpenCC('t2s.json') self.backend = G2pM() @@ -115,6 +224,18 @@ class Chinese(Phonetics): return list(all_syllables) def phoneticize(self, sentence): + """ Normalize the input text sequence and convert it into pronunciation sequence. + + Parameters + ----------- + sentence: str + The input text sequence. + + Returns + ---------- + List[str] + The list of pronunciation sequence. + """ simplified = self.opencc_backend.convert(sentence) phonemes = self.backend(simplified) start = self.vocab.start_symbol @@ -136,15 +257,53 @@ class Chinese(Phonetics): return cleaned_phonemes def numericalize(self, phonemes): + """ Convert pronunciation sequence into pronunciation id sequence. + + Parameters + ----------- + phonemes: List[str] + The list of pronunciation sequence. + + Returns + ---------- + List[int] + The list of pronunciation id sequence. + """ ids = [self.vocab.lookup(item) for item in phonemes] return ids def __call__(self, sentence): + """ Convert the input text sequence into pronunciation id sequence. + + Parameters + ----------- + sentence: str + The input text sequence. + + Returns + ---------- + List[str] + The list of pronunciation id sequence. + """ return self.numericalize(self.phoneticize(sentence)) @property def vocab_size(self): + """ Vocab size. + """ return len(self.vocab) def reverse(self, ids): + """ Reverse the list of pronunciation id sequence to a list of pronunciation sequence. + + Parameters + ----------- + ids: List[int] + The list of pronunciation id sequence. + + Returns + ---------- + List[str] + The list of pronunciation sequence. + """ return [self.vocab.reverse(i) for i in ids] diff --git a/parakeet/frontend/vocab.py b/parakeet/frontend/vocab.py index ba961eb..a56cfb8 100644 --- a/parakeet/frontend/vocab.py +++ b/parakeet/frontend/vocab.py @@ -1,32 +1,64 @@ -from typing import Dict, Iterable, List -from ruamel import yaml -from collections import OrderedDict +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Dict, Iterable, List +from collections import OrderedDict __all__ = ["Vocab"] class Vocab(object): - def __init__(self, symbols: Iterable[str], - padding_symbol="", - unk_symbol="", - start_symbol="", - end_symbol=""): + """ Vocabulary. + + Parameters + ----------- + symbols: Iterable[str] + Common symbols. + + padding_symbol: str, optional + Symbol for pad. Defaults to "". + + unk_symbol: str, optional + Symbol for unknow. Defaults to "" + + start_symbol: str, optional + Symbol for start. Defaults to "" + + end_symbol: str, optional + Symbol for end. Defaults to "" + """ + + def __init__(self, + symbols: Iterable[str], + padding_symbol="", + unk_symbol="", + start_symbol="", + end_symbol=""): self.special_symbols = OrderedDict() for i, item in enumerate( [padding_symbol, unk_symbol, start_symbol, end_symbol]): if item: self.special_symbols[item] = len(self.special_symbols) - + self.padding_symbol = padding_symbol self.unk_symbol = unk_symbol self.start_symbol = start_symbol self.end_symbol = end_symbol - - + self.stoi = OrderedDict() self.stoi.update(self.special_symbols) - + for i, s in enumerate(symbols): if s not in self.stoi: self.stoi[s] = len(self.stoi) @@ -34,49 +66,66 @@ class Vocab(object): def __len__(self): return len(self.stoi) - + @property def num_specials(self): + """ The number of special symbols. + """ return len(self.special_symbols) # special tokens @property def padding_index(self): + """ The index of padding symbol + """ return self.stoi.get(self.padding_symbol, -1) @property def unk_index(self): + """The index of unknow symbol. + """ return self.stoi.get(self.unk_symbol, -1) @property def start_index(self): + """The index of start symbol. + """ return self.stoi.get(self.start_symbol, -1) @property def end_index(self): + """ The index of end symbol. + """ return self.stoi.get(self.end_symbol, -1) - + def __repr__(self): fmt = "Vocab(size: {},\nstoi:\n{})" return fmt.format(len(self), self.stoi) - + def __str__(self): return self.__repr__() - + def lookup(self, symbol): + """ The index that symbol correspond. + """ return self.stoi[symbol] - + def reverse(self, index): + """ The symbol thar index cottespond. + """ return self.itos[index] - + def add_symbol(self, symbol): + """ Add a new symbol in vocab. + """ if symbol in self.stoi: - return + return N = len(self.stoi) self.stoi[symbol] = N self.itos[N] = symbol - + def add_symbols(self, symbols): + """ Add multiple symbols in vocab. + """ for symbol in symbols: self.add_symbol(symbol) -