diff --git a/examples/fastspeech2/baker/README.md b/examples/fastspeech2/baker/README.md index a21bcd5..339484a 100644 --- a/examples/fastspeech2/baker/README.md +++ b/examples/fastspeech2/baker/README.md @@ -9,7 +9,7 @@ Download BZNSYP from it's [Official Website](https://test.data-baker.com/data/in ### Get MFA result of BZNSYP and Extract it. We use [MFA](https://github.com/MontrealCorpusTools/Montreal-Forced-Aligner) to get durations for fastspeech2. -You can download from here [baker_alignmenti_tone.tar.gz](https://paddlespeech.bj.bcebos.com/MFA/BZNSYP/with_tone/baker_alignmenti_tone.tar.gz), or train your own MFA model reference to [use_mfa example](https://github.com/PaddlePaddle/Parakeet/tree/develop/examples/use_mfa) of our repo. +You can download from here [baker_alignment_tone.tar.gz](https://paddlespeech.bj.bcebos.com/MFA/BZNSYP/with_tone/baker_alignment_tone.tar.gz), or train your own MFA model reference to [use_mfa example](https://github.com/PaddlePaddle/Parakeet/tree/develop/examples/use_mfa) of our repo. ### Preprocess the dataset. @@ -26,9 +26,9 @@ Run the command below to preprocess the dataset. ``` ## Synthesize We use [parallel wavegan](https://github.com/PaddlePaddle/Parakeet/tree/develop/examples/parallelwave_gan/baker) as the neural vocoder. -Download pretrained parallel wavegan model from [parallel_wavegan_baker_ckpt_1.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/parallel_wavegan_baker_ckpt_1.0.zip) and unzip it. +Download pretrained parallel wavegan model from [parallel_wavegan_baker_ckpt_0.4.zip](https://paddlespeech.bj.bcebos.com/Parakeet/parallel_wavegan_baker_ckpt_0.4.zip) and unzip it. ```bash -unzip parallel_wavegan_baker_ckpt_1.0.zip +unzip parallel_wavegan_baker_ckpt_0.4.zip ``` `synthesize.sh` can synthesize waveform from `metadata.jsonl`. `synthesize_e2e.sh` can synthesize waveform from text list. @@ -44,19 +44,19 @@ or You can see the bash files for more datails of input parameters. ## Pretrained Model -Pretrained Model with no sil in the edge of audios can be downloaded here. [fastspeech2_nosil_baker_ckpt_1.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/fastspeech2_nosil_baker_ckpt_1.0.zip) +Pretrained Model with no sil in the edge of audios can be downloaded here. [fastspeech2_nosil_baker_ckpt_0.4.zip](https://paddlespeech.bj.bcebos.com/Parakeet/fastspeech2_nosil_baker_ckpt_0.4.zip) Then, you can use the following scripts to synthesize for `sentences.txt` using pretrained fastspeech2 model. ```bash -python3 synthesize_e2e.py \↩ - --fastspeech2-config=fastspeech2_nosil_baker_ckpt_1.0/default.yaml \↩ - --fastspeech2-checkpoint=fastspeech2_nosil_baker_ckpt_1.0/snapshot_iter_76000.pdz \↩ - --fastspeech2-stat=fastspeech2_nosil_baker_ckpt_1.0/speech_stats.npy \↩ - --pwg-config=parallel_wavegan_baker_ckpt_1.0/pwg_default.yaml \↩ - --pwg-params=parallel_wavegan_baker_ckpt_1.0/pwg_generator.pdparams \↩ - --pwg-stat=parallel_wavegan_baker_ckpt_1.0/pwg_stats.npy \↩ - --text=sentences.txt \↩ - --output-dir=exp/debug/test_e2e \↩ - --device="gpu" \↩ - --phones=fastspeech2_nosil_baker_ckpt_1.0/phone_id_map.txt↩ +python3 synthesize_e2e.py \ + --fastspeech2-config=fastspeech2_nosil_baker_ckpt_0.4/default.yaml \ + --fastspeech2-checkpoint=fastspeech2_nosil_baker_ckpt_0.4/snapshot_iter_76000.pdz \ + --fastspeech2-stat=fastspeech2_nosil_baker_ckpt_0.4/speech_stats.npy \ + --pwg-config=parallel_wavegan_baker_ckpt_0.4/pwg_default.yaml \ + --pwg-params=parallel_wavegan_baker_ckpt_0.4/pwg_generator.pdparams \ + --pwg-stat=parallel_wavegan_baker_ckpt_0.4/pwg_stats.npy \ + --text=sentences.txt \ + --output-dir=exp/debug/test_e2e \ + --device="gpu" \ + --phones-dict=fastspeech2_nosil_baker_ckpt_0.4/phone_id_map.txt ``` diff --git a/examples/fastspeech2/baker/batch_fn.py b/examples/fastspeech2/baker/batch_fn.py index 58cb6b6..1bbab84 100644 --- a/examples/fastspeech2/baker/batch_fn.py +++ b/examples/fastspeech2/baker/batch_fn.py @@ -13,6 +13,7 @@ # limitations under the License. import numpy as np +import paddle from parakeet.data.batch import batch_sequences @@ -35,6 +36,15 @@ def collate_baker_examples(examples): durations = batch_sequences(durations) energy = batch_sequences(energy) + # convert each batch to paddle.Tensor + text = paddle.to_tensor(text) + pitch = paddle.to_tensor(pitch) + speech = paddle.to_tensor(speech) + durations = paddle.to_tensor(durations) + energy = paddle.to_tensor(energy) + text_lengths = paddle.to_tensor(text_lengths) + speech_lengths = paddle.to_tensor(speech_lengths) + batch = { "text": text, "text_lengths": text_lengths, diff --git a/examples/fastspeech2/baker/frontend.py b/examples/fastspeech2/baker/frontend.py index d0cebd4..8d2c1f1 100644 --- a/examples/fastspeech2/baker/frontend.py +++ b/examples/fastspeech2/baker/frontend.py @@ -44,9 +44,7 @@ class Frontend(): def _t2id(self, tones): # replace unk phone with sp - tones = [ - tone if tone in self.vocab_tones else "0" for tone in tones - ] + tones = [tone if tone in self.vocab_tones else "0" for tone in tones] tone_ids = [self.vocab_tones[item] for item in tones] return np.array(tone_ids, np.int64) diff --git a/examples/fastspeech2/baker/gen_duration_from_textgrid.py b/examples/fastspeech2/baker/gen_duration_from_textgrid.py index a2179df..b3a39d3 100644 --- a/examples/fastspeech2/baker/gen_duration_from_textgrid.py +++ b/examples/fastspeech2/baker/gen_duration_from_textgrid.py @@ -41,10 +41,10 @@ def readtg(config, tg_path): durations[-2] += durations[-1] durations = durations[:-1] # replace the last sp with sil - phones[-1] = "sil" if phones[-1]=="sp" else phones[-1] + phones[-1] = "sil" if phones[-1] == "sp" else phones[-1] results = "" - + for (p, d) in zip(phones, durations): results += p + " " + str(d) + " " return results.strip() diff --git a/examples/fastspeech2/baker/get_feats.py b/examples/fastspeech2/baker/get_feats.py index bcd29cf..4e500e5 100644 --- a/examples/fastspeech2/baker/get_feats.py +++ b/examples/fastspeech2/baker/get_feats.py @@ -88,12 +88,7 @@ class LogMelFBank(): class Pitch(): - def __init__(self, - sr=24000, - hop_length=300, - f0min=80, - f0max=7600 - ): + def __init__(self, sr=24000, hop_length=300, f0min=80, f0max=7600): self.sr = sr self.hop_length = hop_length @@ -199,7 +194,7 @@ class Energy(): def _calculate_energy(self, input): input = input.astype(np.float32) input_stft = self._stft(input) - input_power = np.abs(input_stft) ** 2 + input_power = np.abs(input_stft)**2 energy = np.sqrt( np.clip( np.sum(input_power, axis=0), a_min=1.0e-10, a_max=float( @@ -241,21 +236,20 @@ if __name__ == "__main__": print(mel) print(mel.shape) - pitch_extractor = Pitch(sr=C.fs, - hop_length=C.n_shift, - f0min=C.f0min, - f0max=C.f0max) + pitch_extractor = Pitch( + sr=C.fs, hop_length=C.n_shift, f0min=C.f0min, f0max=C.f0max) duration = "2 8 8 8 12 11 10 13 11 10 18 9 12 10 12 11 5" duration = np.array([int(x) for x in duration.split(" ")]) avg_f0 = pitch_extractor.get_pitch(wav, duration=duration) print(avg_f0) print(avg_f0.shape) - energy_extractor = Energy(sr=C.fs, - n_fft=C.n_fft, - hop_length=C.n_shift, - win_length=C.win_length, - window=C.window) + energy_extractor = Energy( + sr=C.fs, + n_fft=C.n_fft, + hop_length=C.n_shift, + win_length=C.win_length, + window=C.window) duration = "2 8 8 8 12 11 10 13 11 10 18 9 12 10 12 11 5" duration = np.array([int(x) for x in duration.split(" ")]) avg_energy = energy_extractor.get_energy(wav, duration=duration) diff --git a/examples/fastspeech2/baker/preprocess.py b/examples/fastspeech2/baker/preprocess.py index 0c4afff..c079715 100644 --- a/examples/fastspeech2/baker/preprocess.py +++ b/examples/fastspeech2/baker/preprocess.py @@ -139,15 +139,14 @@ def compare_duration_and_mel_length(sentences, utt, mel): sentences.pop(utt) -def process_sentence( - config: Dict[str, Any], - fp: Path, - sentences: Dict, - output_dir: Path, - mel_extractor=None, - pitch_extractor=None, - energy_extractor=None, - cut_sil: bool = True): +def process_sentence(config: Dict[str, Any], + fp: Path, + sentences: Dict, + output_dir: Path, + mel_extractor=None, + pitch_extractor=None, + energy_extractor=None, + cut_sil: bool=True): utt_id = fp.stem record = None if utt_id in sentences: @@ -160,7 +159,8 @@ def process_sentence( durations = sentences[utt_id][1] d_cumsum = np.pad(np.array(durations).cumsum(0), (1, 0), 'constant') # little imprecise than use *.TextGrid directly - times = librosa.frames_to_time(d_cumsum, sr=config.fs, hop_length=config.n_shift) + times = librosa.frames_to_time( + d_cumsum, sr=config.fs, hop_length=config.n_shift) if cut_sil: start = 0 end = d_cumsum[-1] @@ -222,8 +222,8 @@ def process_sentences(config, mel_extractor=None, pitch_extractor=None, energy_extractor=None, - nprocs: int = 1, - cut_sil: bool = True): + nprocs: int=1, + cut_sil: bool=True): if nprocs == 1: results = [] for fp in tqdm.tqdm(fps, total=len(fps)): @@ -239,7 +239,8 @@ def process_sentences(config, for fp in fps: future = pool.submit(process_sentence, config, fp, sentences, output_dir, mel_extractor, - pitch_extractor, energy_extractor, cut_sil) + pitch_extractor, energy_extractor, + cut_sil) future.add_done_callback(lambda p: progress.update()) futures.append(future) @@ -289,7 +290,10 @@ def main(): return True if str.lower() == 'true' else False parser.add_argument( - "--cut-sil", type=str2bool, default=True, help="whether cut sil in the edge of audio") + "--cut-sil", + type=str2bool, + default=True, + help="whether cut sil in the edge of audio") args = parser.parse_args() C = get_cfg_default() @@ -336,15 +340,14 @@ def main(): n_mels=C.n_mels, fmin=C.fmin, fmax=C.fmax) - pitch_extractor = Pitch(sr=C.fs, - hop_length=C.n_shift, - f0min=C.f0min, - f0max=C.f0max) - energy_extractor = Energy(sr=C.fs, - n_fft=C.n_fft, - hop_length=C.n_shift, - win_length=C.win_length, - window=C.window) + pitch_extractor = Pitch( + sr=C.fs, hop_length=C.n_shift, f0min=C.f0min, f0max=C.f0max) + energy_extractor = Energy( + sr=C.fs, + n_fft=C.n_fft, + hop_length=C.n_shift, + win_length=C.win_length, + window=C.window) # process for the 3 sections diff --git a/examples/fastspeech2/baker/synthesize.sh b/examples/fastspeech2/baker/synthesize.sh index 5b2df5c..a912bb6 100755 --- a/examples/fastspeech2/baker/synthesize.sh +++ b/examples/fastspeech2/baker/synthesize.sh @@ -3,11 +3,11 @@ python3 synthesize.py \ --fastspeech2-config=conf/default.yaml \ - --fastspeech2-checkpoint=exp/default/checkpoints/snapshot_iter_62577.pdz \ + --fastspeech2-checkpoint=exp/default/checkpoints/snapshot_iter_153.pdz \ --fastspeech2-stat=dump/train/speech_stats.npy \ - --pwg-config=parallel_wavegan_baker_ckpt_1.0/pwg_default.yaml \ - --pwg-params=parallel_wavegan_baker_ckpt_1.0/pwg_generator.pdparams \ - --pwg-stat=parallel_wavegan_baker_ckpt_1.0/pwg_stats.npy \ + --pwg-config=parallel_wavegan_baker_ckpt_0.4/pwg_default.yaml \ + --pwg-params=parallel_wavegan_baker_ckpt_0.4/pwg_generator.pdparams \ + --pwg-stat=parallel_wavegan_baker_ckpt_0.4/pwg_stats.npy \ --test-metadata=dump/test/norm/metadata.jsonl \ --output-dir=exp/debug/test \ --device="gpu" \ diff --git a/examples/fastspeech2/baker/synthesize_e2e.sh b/examples/fastspeech2/baker/synthesize_e2e.sh index 9ef3a26..7e0fa7c 100755 --- a/examples/fastspeech2/baker/synthesize_e2e.sh +++ b/examples/fastspeech2/baker/synthesize_e2e.sh @@ -3,11 +3,11 @@ python3 synthesize_e2e.py \ --fastspeech2-config=conf/default.yaml \ - --fastspeech2-checkpoint=exp/default/checkpoints/snapshot_iter_136017.pdz \ + --fastspeech2-checkpoint=exp/default/checkpoints/snapshot_iter_153.pdz \ --fastspeech2-stat=dump/train/speech_stats.npy \ - --pwg-config=parallel_wavegan_baker_ckpt_1.0/pwg_default.yaml \ - --pwg-params=parallel_wavegan_baker_ckpt_1.0/pwg_generator.pdparams \ - --pwg-stat=parallel_wavegan_baker_ckpt_1.0/pwg_stats.npy \ + --pwg-config=parallel_wavegan_baker_ckpt_0.4/pwg_default.yaml \ + --pwg-params=parallel_wavegan_baker_ckpt_0.4/pwg_generator.pdparams \ + --pwg-stat=parallel_wavegan_baker_ckpt_0.4/pwg_stats.npy \ --text=sentences.txt \ --output-dir=exp/debug/test_e2e \ --device="gpu" \ diff --git a/examples/parallelwave_gan/baker/preprocess.py b/examples/parallelwave_gan/baker/preprocess.py index 81c69fe..ff1540b 100644 --- a/examples/parallelwave_gan/baker/preprocess.py +++ b/examples/parallelwave_gan/baker/preprocess.py @@ -97,7 +97,7 @@ def process_sentence(config: Dict[str, Any], utt_id = fp.stem # reading - y, sr = librosa.load(fp, sr=config.sr) # resampling may occur + y, sr = librosa.load(str(fp), sr=config.sr) # resampling may occur assert len(y.shape) == 1, f"{utt_id} is not a mono-channel audio." assert np.abs(y).max( ) <= 1.0, f"{utt_id} is seems to be different that 16 bit PCM." diff --git a/examples/parallelwave_gan/baker/run.sh b/examples/parallelwave_gan/baker/run.sh index 8cd1455..0128dfd 100644 --- a/examples/parallelwave_gan/baker/run.sh +++ b/examples/parallelwave_gan/baker/run.sh @@ -1,6 +1,5 @@ FLAGS_cudnn_exhaustive_search=true \ FLAGS_conv_workspace_size_limit=4000 \ - python train.py \ --train-metadata=dump/train/norm/metadata.jsonl \ --dev-metadata=dump/dev/norm/metadata.jsonl \ diff --git a/parakeet/frontend/__init__.py b/parakeet/frontend/__init__.py index 3ab187a..b7b5874 100644 --- a/parakeet/frontend/__init__.py +++ b/parakeet/frontend/__init__.py @@ -19,4 +19,3 @@ from parakeet.frontend.normalizer import * from parakeet.frontend.cn_normalization import * from parakeet.frontend.tone_sandhi import * from parakeet.frontend.generate_lexicon import * - diff --git a/parakeet/frontend/cn_frontend.py b/parakeet/frontend/cn_frontend.py index a82b407..52624e0 100644 --- a/parakeet/frontend/cn_frontend.py +++ b/parakeet/frontend/cn_frontend.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. - import jieba.posseg as psg import numpy as np import paddle @@ -34,7 +33,8 @@ class Frontend(): self.g2p_model = g2p_model if self.g2p_model == "g2pM": self.g2pM_model = G2pM() - self.pinyin2phone = generate_lexicon(with_tone=True, with_erhua=False) + self.pinyin2phone = generate_lexicon( + with_tone=True, with_erhua=False) def _get_initials_finals(self, word): initials = [] @@ -55,7 +55,7 @@ class Frontend(): elif self.g2p_model == "g2pM": pinyins = self.g2pM_model(word, tone=True, char_split=False) for pinyin in pinyins: - pinyin = pinyin.replace("u:","v") + pinyin = pinyin.replace("u:", "v") if pinyin in self.pinyin2phone: initial_final_list = self.pinyin2phone[pinyin].split(" ") if len(initial_final_list) == 2: @@ -84,7 +84,8 @@ class Frontend(): if pos == 'eng': continue sub_initials, sub_finals = self._get_initials_finals(word) - sub_finals = self.tone_modifier.modified_tone(word, pos, sub_finals) + sub_finals = self.tone_modifier.modified_tone(word, pos, + sub_finals) initials.append(sub_initials) finals.append(sub_finals) # assert len(sub_initials) == len(sub_finals) == len(word) diff --git a/parakeet/frontend/cn_normalization/README.md b/parakeet/frontend/cn_normalization/README.md index 5ba5c29..fb6b87a 100644 --- a/parakeet/frontend/cn_normalization/README.md +++ b/parakeet/frontend/cn_normalization/README.md @@ -1,4 +1,4 @@ -supported NSW (Non-Standard-Word) Normalization +## Supported NSW (Non-Standard-Word) Normalization |NSW type|raw|normalized| |-|-|-| @@ -9,3 +9,5 @@ supported NSW (Non-Standard-Word) Normalization |money|随便来几个价格12块5,34.5元,20.1万|随便来几个价格十二块五 三十四点五元 二十点一万| |percentage|明天有62%的概率降雨|明天有百分之六十二的概率降雨| |telephone|这是固话0421-33441122
这是手机+86 18544139121|这是固话零四二一三三四四一一二二
这是手机八六一八五四四一三九一二一| +## References +[Pull requests #658 of DeepSpeech](https://github.com/PaddlePaddle/DeepSpeech/pull/658/files) diff --git a/parakeet/frontend/cn_normalization/chronology.py b/parakeet/frontend/cn_normalization/chronology.py index fb44bfb..157d4ca 100644 --- a/parakeet/frontend/cn_normalization/chronology.py +++ b/parakeet/frontend/cn_normalization/chronology.py @@ -24,17 +24,18 @@ def _time_num2str(num_string: str) -> str: result = DIGITS['0'] + result return result + # 时刻表达式 -RE_TIME = re.compile( - r'([0-1]?[0-9]|2[0-3])' - r':([0-5][0-9])' - r'(:([0-5][0-9]))?' -) +RE_TIME = re.compile(r'([0-1]?[0-9]|2[0-3])' + r':([0-5][0-9])' + r'(:([0-5][0-9]))?') + + def replace_time(match: re.Match) -> str: hour = match.group(1) minute = match.group(2) second = match.group(4) - + result = f"{num2str(hour)}点" if minute.lstrip('0'): result += f"{_time_num2str(minute)}分" @@ -43,11 +44,11 @@ def replace_time(match: re.Match) -> str: return result -RE_DATE = re.compile( - r'(\d{4}|\d{2})年' - r'((0?[1-9]|1[0-2])月)?' - r'(((0?[1-9])|((1|2)[0-9])|30|31)([日号]))?' -) +RE_DATE = re.compile(r'(\d{4}|\d{2})年' + r'((0?[1-9]|1[0-2])月)?' + r'(((0?[1-9])|((1|2)[0-9])|30|31)([日号]))?') + + def replace_date(match: re.Match) -> str: year = match.group(1) month = match.group(3) @@ -61,10 +62,12 @@ def replace_date(match: re.Match) -> str: result += f"{verbalize_cardinal(day)}{match.group(9)}" return result + # 用 / 或者 - 分隔的 YY/MM/DD 或者 YY-MM-DD 日期 RE_DATE2 = re.compile( - r'(\d{4})([- /.])(0[1-9]|1[012])\2(0[1-9]|[12][0-9]|3[01])' -) + r'(\d{4})([- /.])(0[1-9]|1[012])\2(0[1-9]|[12][0-9]|3[01])') + + def replace_date2(match: re.Match) -> str: year = match.group(1) month = match.group(3) @@ -76,4 +79,4 @@ def replace_date2(match: re.Match) -> str: result += f"{verbalize_cardinal(month)}月" if day: result += f"{verbalize_cardinal(day)}日" - return result \ No newline at end of file + return result diff --git a/parakeet/frontend/cn_normalization/constants.py b/parakeet/frontend/cn_normalization/constants.py index 6758891..d1ae42b 100644 --- a/parakeet/frontend/cn_normalization/constants.py +++ b/parakeet/frontend/cn_normalization/constants.py @@ -16,7 +16,6 @@ import re import string from pypinyin.constants import SUPPORT_UCS4 - # 全角半角转换 # 英文字符全角 -> 半角映射表 (num: 52) F2H_ASCII_LETTERS = { @@ -28,10 +27,7 @@ F2H_ASCII_LETTERS = { H2F_ASCII_LETTERS = {value: key for key, value in F2H_ASCII_LETTERS.items()} # 数字字符全角 -> 半角映射表 (num: 10) -F2H_DIGITS = { - chr(ord(char) + 65248): char - for char in string.digits -} +F2H_DIGITS = {chr(ord(char) + 65248): char for char in string.digits} # 数字字符半角 -> 全角映射表 H2F_DIGITS = {value: key for key, value in F2H_DIGITS.items()} @@ -49,24 +45,21 @@ H2F_SPACE = {' ': '\u3000'} # 非"有拼音的汉字"的字符串,可用于NSW提取 if SUPPORT_UCS4: - RE_NSW = re.compile( - r'(?:[^' - r'\u3007' # 〇 - r'\u3400-\u4dbf' # CJK扩展A:[3400-4DBF] - r'\u4e00-\u9fff' # CJK基本:[4E00-9FFF] - r'\uf900-\ufaff' # CJK兼容:[F900-FAFF] - r'\U00020000-\U0002A6DF' # CJK扩展B:[20000-2A6DF] - r'\U0002A703-\U0002B73F' # CJK扩展C:[2A700-2B73F] - r'\U0002B740-\U0002B81D' # CJK扩展D:[2B740-2B81D] - r'\U0002F80A-\U0002FA1F' # CJK兼容扩展:[2F800-2FA1F] - r'])+' - ) + RE_NSW = re.compile(r'(?:[^' + r'\u3007' # 〇 + r'\u3400-\u4dbf' # CJK扩展A:[3400-4DBF] + r'\u4e00-\u9fff' # CJK基本:[4E00-9FFF] + r'\uf900-\ufaff' # CJK兼容:[F900-FAFF] + r'\U00020000-\U0002A6DF' # CJK扩展B:[20000-2A6DF] + r'\U0002A703-\U0002B73F' # CJK扩展C:[2A700-2B73F] + r'\U0002B740-\U0002B81D' # CJK扩展D:[2B740-2B81D] + r'\U0002F80A-\U0002FA1F' # CJK兼容扩展:[2F800-2FA1F] + r'])+') else: RE_NSW = re.compile( # pragma: no cover r'(?:[^' - r'\u3007' # 〇 - r'\u3400-\u4dbf' # CJK扩展A:[3400-4DBF] - r'\u4e00-\u9fff' # CJK基本:[4E00-9FFF] - r'\uf900-\ufaff' # CJK兼容:[F900-FAFF] - r'])+' - ) \ No newline at end of file + r'\u3007' # 〇 + r'\u3400-\u4dbf' # CJK扩展A:[3400-4DBF] + r'\u4e00-\u9fff' # CJK基本:[4E00-9FFF] + r'\uf900-\ufaff' # CJK兼容:[F900-FAFF] + r'])+') diff --git a/parakeet/frontend/cn_normalization/num.py b/parakeet/frontend/cn_normalization/num.py index f8f4205..459d871 100644 --- a/parakeet/frontend/cn_normalization/num.py +++ b/parakeet/frontend/cn_normalization/num.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """ Rules to verbalize numbers into Chinese characters. https://zh.wikipedia.org/wiki/中文数字#現代中文 @@ -21,7 +20,6 @@ import re from collections import OrderedDict from typing import List - DIGITS = {str(i): tran for i, tran in enumerate('零一二三四五六七八九')} UNITS = OrderedDict({ 1: '十', @@ -33,6 +31,8 @@ UNITS = OrderedDict({ # 分数表达式 RE_FRAC = re.compile(r'(-?)(\d+)/(\d+)') + + def replace_frac(match: re.Match) -> str: sign = match.group(1) nominator = match.group(2) @@ -42,10 +42,12 @@ def replace_frac(match: re.Match) -> str: denominator: str = num2str(denominator) result = f"{sign}{denominator}分之{nominator}" return result - + # 百分数表达式 RE_PERCENTAGE = re.compile(r'(-?)(\d+(\.\d+)?)%') + + def replace_percentage(match: re.Match) -> str: sign = match.group(1) percent = match.group(2) @@ -54,28 +56,28 @@ def replace_percentage(match: re.Match) -> str: result = f"{sign}百分之{percent}" return result + # 整数表达式 # 带负号或者不带负号的整数 12, -10 -RE_INTEGER = re.compile( - r'(-?)' - r'(\d+)' -) +RE_INTEGER = re.compile(r'(-?)' r'(\d+)') # 编号-无符号整形 # 00078 RE_DEFAULT_NUM = re.compile(r'\d{4}\d*') + + def replace_default_num(match: re.Match): number = match.group(0) return verbalize_digit(number) + # 数字表达式 # 1. 整数: -10, 10; # 2. 浮点数: 10.2, -0.3 # 3. 不带符号和整数部分的纯浮点数: .22, .38 -RE_NUMBER = re.compile( - r'(-?)((\d+)(\.\d+)?)' - r'|(\.(\d+))' -) +RE_NUMBER = re.compile(r'(-?)((\d+)(\.\d+)?)' r'|(\.(\d+))') + + def replace_number(match: re.Match) -> str: sign = match.group(1) number = match.group(2) @@ -88,11 +90,12 @@ def replace_number(match: re.Match) -> str: result = f"{sign}{number}" return result + # 范围表达式 # 12-23, 12~23 -RE_RANGE = re.compile( - r'(\d+)[-~](\d+)' -) +RE_RANGE = re.compile(r'(\d+)[-~](\d+)') + + def replace_range(match: re.Match) -> str: first, second = match.group(1), match.group(2) first: str = num2str(first) @@ -111,26 +114,31 @@ def _get_value(value_string: str, use_zero: bool=True) -> List[str]: else: return [DIGITS[stripped]] else: - largest_unit = next(power for power in reversed(UNITS.keys()) if power < len(stripped)) + largest_unit = next( + power for power in reversed(UNITS.keys()) if power < len(stripped)) first_part = value_string[:-largest_unit] second_part = value_string[-largest_unit:] - return _get_value(first_part) + [UNITS[largest_unit]] + _get_value(second_part) + return _get_value(first_part) + [UNITS[largest_unit]] + _get_value( + second_part) + def verbalize_cardinal(value_string: str) -> str: if not value_string: return '' - + # 000 -> '零' , 0 -> '零' value_string = value_string.lstrip('0') if len(value_string) == 0: return DIGITS['0'] - + result_symbols = _get_value(value_string) # verbalized number starting with '一十*' is abbreviated as `十*` - if len(result_symbols) >= 2 and result_symbols[0] == DIGITS['1'] and result_symbols[1] == UNITS[1]: + if len(result_symbols) >= 2 and result_symbols[0] == DIGITS[ + '1'] and result_symbols[1] == UNITS[1]: result_symbols = result_symbols[1:] return ''.join(result_symbols) + def verbalize_digit(value_string: str, alt_one=False) -> str: result_symbols = [DIGITS[digit] for digit in value_string] result = ''.join(result_symbols) @@ -138,6 +146,7 @@ def verbalize_digit(value_string: str, alt_one=False) -> str: result.replace("一", "幺") return result + def num2str(value_string: str) -> str: integer_decimal = value_string.split('.') if len(integer_decimal) == 1: @@ -146,8 +155,10 @@ def num2str(value_string: str) -> str: elif len(integer_decimal) == 2: integer, decimal = integer_decimal else: - raise ValueError(f"The value string: '${value_string}' has more than one point in it.") - + raise ValueError( + f"The value string: '${value_string}' has more than one point in it." + ) + result = verbalize_cardinal(integer) decimal = decimal.rstrip('0') diff --git a/parakeet/frontend/cn_normalization/phonecode.py b/parakeet/frontend/cn_normalization/phonecode.py index 072a69a..7539555 100644 --- a/parakeet/frontend/cn_normalization/phonecode.py +++ b/parakeet/frontend/cn_normalization/phonecode.py @@ -16,14 +16,13 @@ import re from .num import verbalize_digit - # 规范化固话/手机号码 # 手机 # http://www.jihaoba.com/news/show/13680 # 移动:139、138、137、136、135、134、159、158、157、150、151、152、188、187、182、183、184、178、198 # 联通:130、131、132、156、155、186、185、176 # 电信:133、153、189、180、181、177 -RE_MOBILE_PHONE= re.compile( +RE_MOBILE_PHONE = re.compile( r"(? str: if mobile: sp_parts = phone_string.strip('+').split() result = ''.join( - [verbalize_digit(part, alt_one=True) for part in sp_parts]) + [verbalize_digit( + part, alt_one=True) for part in sp_parts]) return result else: sil_parts = phone_string.split('-') result = ''.join( - [verbalize_digit(part, alt_one=True) for part in sil_parts]) + [verbalize_digit( + part, alt_one=True) for part in sil_parts]) return result def replace_phone(match: re.Match) -> str: - return phone2str(match.group(0)) \ No newline at end of file + return phone2str(match.group(0)) diff --git a/parakeet/frontend/cn_normalization/quantifier.py b/parakeet/frontend/cn_normalization/quantifier.py index 2adfdc9..0a4bcaf 100644 --- a/parakeet/frontend/cn_normalization/quantifier.py +++ b/parakeet/frontend/cn_normalization/quantifier.py @@ -16,12 +16,11 @@ import re from .num import num2str - # 温度表达式,温度会影响负号的读法 # -3°C 零下三度 -RE_TEMPERATURE = re.compile( - r'(-?)(\d+(\.\d+)?)(°C|℃|度|摄氏度)' -) +RE_TEMPERATURE = re.compile(r'(-?)(\d+(\.\d+)?)(°C|℃|度|摄氏度)') + + def replace_temperature(match: re.Match) -> str: sign = match.group(1) temperature = match.group(2) @@ -30,4 +29,4 @@ def replace_temperature(match: re.Match) -> str: temperature: str = num2str(temperature) unit: str = "摄氏度" if unit == "摄氏度" else "度" result = f"{sign}{temperature}{unit}" - return result \ No newline at end of file + return result diff --git a/parakeet/frontend/tone_sandhi.py b/parakeet/frontend/tone_sandhi.py index c7dc41f..a03989c 100644 --- a/parakeet/frontend/tone_sandhi.py +++ b/parakeet/frontend/tone_sandhi.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import List, Tuple import jieba from pypinyin import lazy_pinyin @@ -20,43 +21,51 @@ from pypinyin import Style class ToneSandhi(): def __init__(self): - self.must_neural_tone_words = {'麻烦', '麻利', '鸳鸯', '高粱', '骨头', '骆驼', '马虎', '首饰', '馒头', '馄饨', '风筝', '难为', '队伍', - '阔气', '闺女', '门道', '锄头', '铺盖', '铃铛', '铁匠', '钥匙', '里脊', '里头', '部分', '那么', '道士', - '造化', '迷糊', '连累', '这么', '这个', '运气', '过去', '软和', '转悠', '踏实', '跳蚤', '跟头', '趔趄', - '财主', '豆腐', '讲究', '记性', '记号', '认识', '规矩', '见识', '裁缝', '补丁', '衣裳', '衣服', '衙门', - '街坊', '行李', '行当', '蛤蟆', '蘑菇', '薄荷', '葫芦', '葡萄', '萝卜', '荸荠', '苗条', '苗头', '苍蝇', - '芝麻', '舒服', '舒坦', '舌头', '自在', '膏药', '脾气', '脑袋', '脊梁', '能耐', '胳膊', '胭脂', '胡萝', - '胡琴', '胡同', '聪明', '耽误', '耽搁', '耷拉', '耳朵', '老爷', '老实', '老婆', '老头', '老太', '翻腾', - '罗嗦', '罐头', '编辑', '结实', '红火', '累赘', '糨糊', '糊涂', '精神', '粮食', '簸箕', '篱笆', '算计', - '算盘', '答应', '笤帚', '笑语', '笑话', '窟窿', '窝囊', '窗户', '稳当', '稀罕', '称呼', '秧歌', '秀气', - '秀才', '福气', '祖宗', '砚台', '码头', '石榴', '石头', '石匠', '知识', '眼睛', '眯缝', '眨巴', '眉毛', - '相声', '盘算', '白净', '痢疾', '痛快', '疟疾', '疙瘩', '疏忽', '畜生', '生意', '甘蔗', '琵琶', '琢磨', - '琉璃', '玻璃', '玫瑰', '玄乎', '狐狸', '状元', '特务', '牲口', '牙碜', '牌楼', '爽快', '爱人', '热闹', - '烧饼', '烟筒', '烂糊', '点心', '炊帚', '灯笼', '火候', '漂亮', '滑溜', '溜达', '温和', '清楚', '消息', - '浪头', '活泼', '比方', '正经', '欺负', '模糊', '槟榔', '棺材', '棒槌', '棉花', '核桃', '栅栏', '柴火', - '架势', '枕头', '枇杷', '机灵', '本事', '木头', '木匠', '朋友', '月饼', '月亮', '暖和', '明白', '时候', - '新鲜', '故事', '收拾', '收成', '提防', '挖苦', '挑剔', '指甲', '指头', '拾掇', '拳头', '拨弄', '招牌', - '招呼', '抬举', '护士', '折腾', '扫帚', '打量', '打算', '打点', '打扮', '打听', '打发', '扎实', '扁担', - '戒指', '懒得', '意识', '意思', '情形', '悟性', '怪物', '思量', '怎么', '念头', '念叨', '快活', '忙活', - '志气', '心思', '得罪', '张罗', '弟兄', '开通', '应酬', '庄稼', '干事', '帮手', '帐篷', '希罕', '师父', - '师傅', '巴结', '巴掌', '差事', '工夫', '岁数', '屁股', '尾巴', '少爷', '小气', '小伙', '将就', '对头', - '对付', '寡妇', '家伙', '客气', '实在', '官司', '学问', '学生', '字号', '嫁妆', '媳妇', '媒人', '婆家', - '娘家', '委屈', '姑娘', '姐夫', '妯娌', '妥当', '妖精', '奴才', '女婿', '头发', '太阳', '大爷', '大方', - '大意', '大夫', '多少', '多么', '外甥', '壮实', '地道', '地方', '在乎', '困难', '嘴巴', '嘱咐', '嘟囔', - '嘀咕', '喜欢', '喇嘛', '喇叭', '商量', '唾沫', '哑巴', '哈欠', '哆嗦', '咳嗽', '和尚', '告诉', '告示', - '含糊', '吓唬', '后头', '名字', '名堂', '合同', '吆喝', '叫唤', '口袋', '厚道', '厉害', '千斤', '包袱', - '包涵', '匀称', '勤快', '动静', '动弹', '功夫', '力气', '前头', '刺猬', '刺激', '别扭', '利落', '利索', - '利害', '分析', '出息', '凑合', '凉快', '冷战', '冤枉', '冒失', '养活', '关系', '先生', '兄弟', '便宜', - '使唤', '佩服', '作坊', '体面', '位置', '似的', '伙计', '休息', '什么', '人家', '亲戚', '亲家', '交情', - '云彩', '事情', '买卖', '主意', '丫头', '丧气', '两口', '东西', '东家', '世故', '不由', '不在', '下水', - '下巴', '上头', '上司', '丈夫', '丈人', '一辈', '那个'} + self.must_neural_tone_words = { + '麻烦', '麻利', '鸳鸯', '高粱', '骨头', '骆驼', '马虎', '首饰', '馒头', '馄饨', '风筝', + '难为', '队伍', '阔气', '闺女', '门道', '锄头', '铺盖', '铃铛', '铁匠', '钥匙', '里脊', + '里头', '部分', '那么', '道士', '造化', '迷糊', '连累', '这么', '这个', '运气', '过去', + '软和', '转悠', '踏实', '跳蚤', '跟头', '趔趄', '财主', '豆腐', '讲究', '记性', '记号', + '认识', '规矩', '见识', '裁缝', '补丁', '衣裳', '衣服', '衙门', '街坊', '行李', '行当', + '蛤蟆', '蘑菇', '薄荷', '葫芦', '葡萄', '萝卜', '荸荠', '苗条', '苗头', '苍蝇', '芝麻', + '舒服', '舒坦', '舌头', '自在', '膏药', '脾气', '脑袋', '脊梁', '能耐', '胳膊', '胭脂', + '胡萝', '胡琴', '胡同', '聪明', '耽误', '耽搁', '耷拉', '耳朵', '老爷', '老实', '老婆', + '老头', '老太', '翻腾', '罗嗦', '罐头', '编辑', '结实', '红火', '累赘', '糨糊', '糊涂', + '精神', '粮食', '簸箕', '篱笆', '算计', '算盘', '答应', '笤帚', '笑语', '笑话', '窟窿', + '窝囊', '窗户', '稳当', '稀罕', '称呼', '秧歌', '秀气', '秀才', '福气', '祖宗', '砚台', + '码头', '石榴', '石头', '石匠', '知识', '眼睛', '眯缝', '眨巴', '眉毛', '相声', '盘算', + '白净', '痢疾', '痛快', '疟疾', '疙瘩', '疏忽', '畜生', '生意', '甘蔗', '琵琶', '琢磨', + '琉璃', '玻璃', '玫瑰', '玄乎', '狐狸', '状元', '特务', '牲口', '牙碜', '牌楼', '爽快', + '爱人', '热闹', '烧饼', '烟筒', '烂糊', '点心', '炊帚', '灯笼', '火候', '漂亮', '滑溜', + '溜达', '温和', '清楚', '消息', '浪头', '活泼', '比方', '正经', '欺负', '模糊', '槟榔', + '棺材', '棒槌', '棉花', '核桃', '栅栏', '柴火', '架势', '枕头', '枇杷', '机灵', '本事', + '木头', '木匠', '朋友', '月饼', '月亮', '暖和', '明白', '时候', '新鲜', '故事', '收拾', + '收成', '提防', '挖苦', '挑剔', '指甲', '指头', '拾掇', '拳头', '拨弄', '招牌', '招呼', + '抬举', '护士', '折腾', '扫帚', '打量', '打算', '打点', '打扮', '打听', '打发', '扎实', + '扁担', '戒指', '懒得', '意识', '意思', '情形', '悟性', '怪物', '思量', '怎么', '念头', + '念叨', '快活', '忙活', '志气', '心思', '得罪', '张罗', '弟兄', '开通', '应酬', '庄稼', + '干事', '帮手', '帐篷', '希罕', '师父', '师傅', '巴结', '巴掌', '差事', '工夫', '岁数', + '屁股', '尾巴', '少爷', '小气', '小伙', '将就', '对头', '对付', '寡妇', '家伙', '客气', + '实在', '官司', '学问', '学生', '字号', '嫁妆', '媳妇', '媒人', '婆家', '娘家', '委屈', + '姑娘', '姐夫', '妯娌', '妥当', '妖精', '奴才', '女婿', '头发', '太阳', '大爷', '大方', + '大意', '大夫', '多少', '多么', '外甥', '壮实', '地道', '地方', '在乎', '困难', '嘴巴', + '嘱咐', '嘟囔', '嘀咕', '喜欢', '喇嘛', '喇叭', '商量', '唾沫', '哑巴', '哈欠', '哆嗦', + '咳嗽', '和尚', '告诉', '告示', '含糊', '吓唬', '后头', '名字', '名堂', '合同', '吆喝', + '叫唤', '口袋', '厚道', '厉害', '千斤', '包袱', '包涵', '匀称', '勤快', '动静', '动弹', + '功夫', '力气', '前头', '刺猬', '刺激', '别扭', '利落', '利索', '利害', '分析', '出息', + '凑合', '凉快', '冷战', '冤枉', '冒失', '养活', '关系', '先生', '兄弟', '便宜', '使唤', + '佩服', '作坊', '体面', '位置', '似的', '伙计', '休息', '什么', '人家', '亲戚', '亲家', + '交情', '云彩', '事情', '买卖', '主意', '丫头', '丧气', '两口', '东西', '东家', '世故', + '不由', '不在', '下水', '下巴', '上头', '上司', '丈夫', '丈人', '一辈', '那个' + } # the meaning of jieba pos tag: https://blog.csdn.net/weixin_44174352/article/details/113731041 # e.g. - # word: "家里" - # pos: "s" - # finals: ['ia1', 'i3'] - def _neural_sandhi(self, word, pos, finals): + # word: "家里" + # pos: "s" + # finals: ['ia1', 'i3'] + def _neural_sandhi(self, word: str, pos: str, + finals: List[str]) -> List[str]: ge_idx = word.find("个") if len(word) == 1 and word in "吧呢啊嘛" and pos == 'y': finals[-1] = finals[-1][:-1] + "5" @@ -80,12 +89,13 @@ class ToneSandhi(): elif len(word) >= 2 and word[-1] == word[-2] and pos[0] in {"n", "v"}: finals[-1] = finals[-1][:-1] + "5" # conventional tone5 in Chinese - elif word in self.must_neural_tone_words or word[-2:] in self.must_neural_tone_words: + elif word in self.must_neural_tone_words or word[ + -2:] in self.must_neural_tone_words: finals[-1] = finals[-1][:-1] + "5" return finals - def _bu_sandhi(self, word, finals): + def _bu_sandhi(self, word: str, finals: List[str]) -> List[str]: # "不" before tone4 should be bu2, e.g. 不怕 if len(word) > 1 and word[0] == "不" and finals[1][-1] == "4": finals[0] = finals[0][:-1] + "2" @@ -95,15 +105,16 @@ class ToneSandhi(): return finals - def _yi_sandhi(self, word, finals): + def _yi_sandhi(self, word: str, finals: List[str]) -> List[str]: # "一" in number sequences, e.g. 一零零 - if len(word) > 1 and word[0] == "一" and all([item.isnumeric() for item in word]): + if len(word) > 1 and word[0] == "一" and all( + [item.isnumeric() for item in word]): return finals # "一" before tone4 should be yi2, e.g. 一段 elif len(word) > 1 and word[0] == "一" and finals[1][-1] == "4": finals[0] = finals[0][:-1] + "2" # "一" before non-tone4 should be yi4, e.g. 一天 - elif len(word) > 1 and word[0] == "一" and finals[1][-1]!= "4": + elif len(word) > 1 and word[0] == "一" and finals[1][-1] != "4": finals[0] = finals[0][:-1] + "4" # "一" between reduplication words shold be yi5, e.g. 看一看 elif len(word) == 3 and word[1] == "一" and word[0] == word[-1]: @@ -113,7 +124,7 @@ class ToneSandhi(): finals[1] = finals[1][:-1] + "1" return finals - def _three_sandhi(self, word, finals): + def _three_sandhi(self, word: str, finals: List[str]) -> List[str]: if len(word) == 2 and self._all_tone_three(finals): finals[0] = finals[0][:-1] + "2" elif len(word) == 3: @@ -138,7 +149,10 @@ class ToneSandhi(): elif len(new_word_list[0]) == 1: finals[1] = finals[1][:-1] + "2" else: - finals_list = [finals[:len(new_word_list[0])], finals[len(new_word_list[0]):]] + finals_list = [ + finals[:len(new_word_list[0])], + finals[len(new_word_list[0]):] + ] if len(finals_list) == 2: for i, sub in enumerate(finals_list): # e.g. 所有/人 @@ -161,12 +175,12 @@ class ToneSandhi(): return finals - def _all_tone_three(self, finals): + def _all_tone_three(self, finals: List[str]) -> bool: return all(x[-1] == "3" for x in finals) # merge "不" and the word behind it # if don't merge, "不" sometimes appears alone according to jieba, which may occur sandhi error - def _merge_bu(self, seg): + def _merge_bu(self, seg: List[Tuple[str, str]]) -> List[Tuple[str, str]]: new_seg = [] last_word = "" for word, pos in seg: @@ -185,17 +199,18 @@ class ToneSandhi(): # function 2: merge single "一" and the word behind it # if don't merge, "一" sometimes appears alone according to jieba, which may occur sandhi error # e.g. - # input seg: [('听', 'v'), ('一', 'm'), ('听', 'v')] - # output seg: [['听一听', 'v']] - def _merge_yi(self, seg): + # input seg: [('听', 'v'), ('一', 'm'), ('听', 'v')] + # output seg: [['听一听', 'v']] + def _merge_yi(self, seg: List[Tuple[str, str]]) -> List[Tuple[str, str]]: new_seg = [] # function 1 for i, (word, pos) in enumerate(seg): - if i - 1 >= 0 and word == "一" and i + 1 < len(seg) and seg[i - 1][0] == seg[i + 1][0] and seg[i - 1][ - 1] == "v": + if i - 1 >= 0 and word == "一" and i + 1 < len(seg) and seg[i - 1][ + 0] == seg[i + 1][0] and seg[i - 1][1] == "v": new_seg[i - 1][0] = new_seg[i - 1][0] + "一" + new_seg[i - 1][0] else: - if i - 2 >= 0 and seg[i - 1][0] == "一" and seg[i - 2][0] == word and pos == "v": + if i - 2 >= 0 and seg[i - 1][0] == "一" and seg[i - 2][ + 0] == word and pos == "v": continue else: new_seg.append([word, pos]) @@ -210,15 +225,20 @@ class ToneSandhi(): seg = new_seg return seg - def _merge_continuous_three_tones(self, seg): + def _merge_continuous_three_tones( + self, seg: List[Tuple[str, str]]) -> List[Tuple[str, str]]: new_seg = [] - sub_finals_list = [lazy_pinyin(word, neutral_tone_with_five=True, style=Style.FINALS_TONE3) for (word, pos) - in seg] + sub_finals_list = [ + lazy_pinyin( + word, neutral_tone_with_five=True, style=Style.FINALS_TONE3) + for (word, pos) in seg + ] assert len(sub_finals_list) == len(seg) merge_last = [False] * len(seg) for i, (word, pos) in enumerate(seg): - if i - 1 >= 0 and self._all_tone_three(sub_finals_list[i - 1]) and self._all_tone_three( - sub_finals_list[i]) and not merge_last[i - 1]: + if i - 1 >= 0 and self._all_tone_three(sub_finals_list[ + i - 1]) and self._all_tone_three(sub_finals_list[ + i]) and not merge_last[i - 1]: if len(seg[i - 1][0]) + len(seg[i][0]) <= 3: new_seg[-1][0] = new_seg[-1][0] + seg[i][0] merge_last[i] = True @@ -229,13 +249,15 @@ class ToneSandhi(): seg = new_seg return seg - def pre_merge_for_modify(self, seg): + def pre_merge_for_modify( + self, seg: List[Tuple[str, str]]) -> List[Tuple[str, str]]: seg = self._merge_bu(seg) seg = self._merge_yi(seg) seg = self._merge_continuous_three_tones(seg) return seg - def modified_tone(self, word, pos, finals): + def modified_tone(self, word: str, pos: str, + finals: List[str]) -> List[str]: finals = self._bu_sandhi(word, finals) finals = self._yi_sandhi(word, finals) finals = self._neural_sandhi(word, pos, finals) diff --git a/parakeet/models/fastspeech2.py b/parakeet/models/fastspeech2.py index d13e64c..bff0b39 100644 --- a/parakeet/models/fastspeech2.py +++ b/parakeet/models/fastspeech2.py @@ -247,21 +247,20 @@ class FastSpeech2(nn.Layer): speech_lengths: paddle.Tensor, durations: paddle.Tensor, pitch: paddle.Tensor, - energy: paddle.Tensor, ) -> Tuple[paddle.Tensor, Dict[ - str, paddle.Tensor], paddle.Tensor]: + energy: paddle.Tensor, ) -> Sequence[paddle.Tensor]: """Calculate forward propagation. Parameters ---------- - text : LongTensor + text : Tensor Batch of padded token ids (B, Tmax). - text_lengths : LongTensor) + text_lengths : Tensor) Batch of lengths of each input (B,). speech : Tensor Batch of padded target features (B, Lmax, odim). - speech_lengths : LongTensor + speech_lengths : Tensor Batch of the lengths of each target (B,). - durations : LongTensor + durations : Tensor Batch of padded durations (B, Tmax). pitch : Tensor Batch of padded token-averaged pitch (B, Tmax, 1). @@ -281,8 +280,6 @@ class FastSpeech2(nn.Layer): energy predictor's output Tensor speech - Tensor - real text_lengths Tensor speech_lengths, modified if reduction_factor >1 """ @@ -387,17 +384,16 @@ class FastSpeech2(nn.Layer): pitch: paddle.Tensor=None, energy: paddle.Tensor=None, alpha: float=1.0, - use_teacher_forcing: bool=False, ) -> Tuple[ - paddle.Tensor, paddle.Tensor, paddle.Tensor]: + use_teacher_forcing: bool=False, ) -> paddle.Tensor: """Generate the sequence of features given the sequences of characters. Parameters ---------- - text : LongTensor + text : Tensor Input sequence of characters (T,). speech : Tensor, optional Feature sequence to extract style (N, idim). - durations : LongTensor, optional + durations : Tensor, optional Groundtruth of duration (T,). pitch : Tensor, optional Groundtruth of token-averaged pitch (T, 1). @@ -452,7 +448,7 @@ class FastSpeech2(nn.Layer): Parameters ---------- - ilens : LongTensor + ilens : Tensor Batch of lengths (B,). Returns @@ -553,7 +549,7 @@ class FastSpeech2Loss(nn.Layer): Batch of outputs after postnets (B, Lmax, odim). before_outs : Tensor Batch of outputs before postnets (B, Lmax, odim). - d_outs : LongTensor + d_outs : Tensor Batch of outputs of duration predictor (B, Tmax). p_outs : Tensor Batch of outputs of pitch predictor (B, Tmax, 1). @@ -561,15 +557,15 @@ class FastSpeech2Loss(nn.Layer): Batch of outputs of energy predictor (B, Tmax, 1). ys : Tensor Batch of target features (B, Lmax, odim). - ds : LongTensor + ds : Tensor Batch of durations (B, Tmax). ps : Tensor Batch of target token-averaged pitch (B, Tmax, 1). es : Tensor Batch of target token-averaged energy (B, Tmax, 1). - ilens : LongTensor + ilens : Tensor Batch of the lengths of each input (B,). - olens : LongTensor + olens : Tensor Batch of the lengths of each target (B,). Returns diff --git a/parakeet/models/waveflow.py b/parakeet/models/waveflow.py index 6744f28..e274cef 100644 --- a/parakeet/models/waveflow.py +++ b/parakeet/models/waveflow.py @@ -44,7 +44,7 @@ def fold(x, n_group): Tensor : [shape=(\*, time_steps // n_group, group)] Folded tensor. """ - spatial_shape = list(x.shape[:-1]) + spatial_shape = list(x.shape[:-1]) time_steps = paddle.shape(x)[-1] new_shape = spatial_shape + [time_steps // n_group, n_group] return paddle.reshape(x, new_shape) @@ -549,9 +549,9 @@ class Flow(nn.Layer): z_row = z[:, :, i:i + 1, :] condition_row = condition[:, :, i:i + 1, :] x_next_row, (logs, b) = self._inverse_row(z_row, x_row, - condition_row) - x[:, :, i:i+1, :] = x_next_row - + condition_row) + x[:, :, i:i + 1, :] = x_next_row + return x @@ -615,7 +615,7 @@ class WaveFlow(nn.LayerList): def _trim(self, x, condition): assert condition.shape[-1] >= x.shape[-1] - pruned_len = int(paddle.shape(x)[-1] // self.n_group * self.n_group) + pruned_len = int(paddle.shape(x)[-1] // self.n_group * self.n_group) if x.shape[-1] > pruned_len: x = x[:, :pruned_len] diff --git a/parakeet/utils/timeline.py b/parakeet/utils/timeline.py index 2a399b7..119a2e9 100644 --- a/parakeet/utils/timeline.py +++ b/parakeet/utils/timeline.py @@ -167,16 +167,16 @@ class Timeline(object): if (k, mevent.device_id, "GPU") not in self._mem_devices: pid = self._allocate_pid() self._mem_devices[(k, mevent.device_id, "GPU")] = pid - self._chrome_trace.emit_pid( - "memory usage on %s:gpu:%d" % (k, mevent.device_id), - pid) + self._chrome_trace.emit_pid("memory usage on %s:gpu:%d" + % (k, mevent.device_id), + pid) elif mevent.place == profiler_pb2.MemEvent.CPUPlace: if (k, mevent.device_id, "CPU") not in self._mem_devices: pid = self._allocate_pid() self._mem_devices[(k, mevent.device_id, "CPU")] = pid - self._chrome_trace.emit_pid( - "memory usage on %s:cpu:%d" % (k, mevent.device_id), - pid) + self._chrome_trace.emit_pid("memory usage on %s:cpu:%d" + % (k, mevent.device_id), + pid) elif mevent.place == profiler_pb2.MemEvent.CUDAPinnedPlace: if (k, mevent.device_id, "CUDAPinnedPlace" ) not in self._mem_devices: @@ -190,9 +190,9 @@ class Timeline(object): if (k, mevent.device_id, "NPU") not in self._mem_devices: pid = self._allocate_pid() self._mem_devices[(k, mevent.device_id, "NPU")] = pid - self._chrome_trace.emit_pid( - "memory usage on %s:npu:%d" % (k, mevent.device_id), - pid) + self._chrome_trace.emit_pid("memory usage on %s:npu:%d" + % (k, mevent.device_id), + pid) if (k, 0, "CPU") not in self._mem_devices: pid = self._allocate_pid() self._mem_devices[(k, 0, "CPU")] = pid @@ -273,14 +273,14 @@ class Timeline(object): total_size = 0 while i < len(mem_list): total_size += mem_list[i]['size'] - while i < len(mem_list) - 1 and mem_list[i]['time'] == mem_list[ - i + 1]['time']: + while i < len(mem_list) - 1 and mem_list[i][ + 'time'] == mem_list[i + 1]['time']: total_size += mem_list[i + 1]['size'] i += 1 self._chrome_trace.emit_counter( - "Memory", "Memory", mem_list[i]['pid'], mem_list[i]['time'], - 0, total_size) + "Memory", "Memory", mem_list[i]['pid'], + mem_list[i]['time'], 0, total_size) i += 1 def generate_chrome_trace(self):