format code and add typehint for tone_sandhi
parent 5e35a696e4
commit a22b4dd171
@@ -9,7 +9,7 @@ Download BZNSYP from it's [Official Website](https://test.data-baker.com/data/in
 ### Get MFA result of BZNSYP and Extract it.

 We use [MFA](https://github.com/MontrealCorpusTools/Montreal-Forced-Aligner) to get durations for fastspeech2.
-You can download from here [baker_alignmenti_tone.tar.gz](https://paddlespeech.bj.bcebos.com/MFA/BZNSYP/with_tone/baker_alignmenti_tone.tar.gz), or train your own MFA model reference to [use_mfa example](https://github.com/PaddlePaddle/Parakeet/tree/develop/examples/use_mfa) of our repo.
+You can download from here [baker_alignment_tone.tar.gz](https://paddlespeech.bj.bcebos.com/MFA/BZNSYP/with_tone/baker_alignment_tone.tar.gz), or train your own MFA model reference to [use_mfa example](https://github.com/PaddlePaddle/Parakeet/tree/develop/examples/use_mfa) of our repo.

 ### Preprocess the dataset.
@@ -26,9 +26,9 @@ Run the command below to preprocess the dataset.
 ```
 ## Synthesize
 We use [parallel wavegan](https://github.com/PaddlePaddle/Parakeet/tree/develop/examples/parallelwave_gan/baker) as the neural vocoder.
-Download pretrained parallel wavegan model from [parallel_wavegan_baker_ckpt_1.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/parallel_wavegan_baker_ckpt_1.0.zip) and unzip it.
+Download pretrained parallel wavegan model from [parallel_wavegan_baker_ckpt_0.4.zip](https://paddlespeech.bj.bcebos.com/Parakeet/parallel_wavegan_baker_ckpt_0.4.zip) and unzip it.
 ```bash
-unzip parallel_wavegan_baker_ckpt_1.0.zip
+unzip parallel_wavegan_baker_ckpt_0.4.zip
 ```
 `synthesize.sh` can synthesize waveform from `metadata.jsonl`.
 `synthesize_e2e.sh` can synthesize waveform from text list.
@@ -44,19 +44,19 @@ or
 You can see the bash files for more datails of input parameters.

 ## Pretrained Model
-Pretrained Model with no sil in the edge of audios can be downloaded here. [fastspeech2_nosil_baker_ckpt_1.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/fastspeech2_nosil_baker_ckpt_1.0.zip)
+Pretrained Model with no sil in the edge of audios can be downloaded here. [fastspeech2_nosil_baker_ckpt_0.4.zip](https://paddlespeech.bj.bcebos.com/Parakeet/fastspeech2_nosil_baker_ckpt_0.4.zip)

 Then, you can use the following scripts to synthesize for `sentences.txt` using pretrained fastspeech2 model.
 ```bash
-python3 synthesize_e2e.py \
-  --fastspeech2-config=fastspeech2_nosil_baker_ckpt_1.0/default.yaml \
-  --fastspeech2-checkpoint=fastspeech2_nosil_baker_ckpt_1.0/snapshot_iter_76000.pdz \
-  --fastspeech2-stat=fastspeech2_nosil_baker_ckpt_1.0/speech_stats.npy \
-  --pwg-config=parallel_wavegan_baker_ckpt_1.0/pwg_default.yaml \
-  --pwg-params=parallel_wavegan_baker_ckpt_1.0/pwg_generator.pdparams \
-  --pwg-stat=parallel_wavegan_baker_ckpt_1.0/pwg_stats.npy \
-  --text=sentences.txt \
-  --output-dir=exp/debug/test_e2e \
-  --device="gpu" \
-  --phones=fastspeech2_nosil_baker_ckpt_1.0/phone_id_map.txt
+python3 synthesize_e2e.py \
+  --fastspeech2-config=fastspeech2_nosil_baker_ckpt_0.4/default.yaml \
+  --fastspeech2-checkpoint=fastspeech2_nosil_baker_ckpt_0.4/snapshot_iter_76000.pdz \
+  --fastspeech2-stat=fastspeech2_nosil_baker_ckpt_0.4/speech_stats.npy \
+  --pwg-config=parallel_wavegan_baker_ckpt_0.4/pwg_default.yaml \
+  --pwg-params=parallel_wavegan_baker_ckpt_0.4/pwg_generator.pdparams \
+  --pwg-stat=parallel_wavegan_baker_ckpt_0.4/pwg_stats.npy \
+  --text=sentences.txt \
+  --output-dir=exp/debug/test_e2e \
+  --device="gpu" \
+  --phones-dict=fastspeech2_nosil_baker_ckpt_0.4/phone_id_map.txt
 ```
@@ -13,6 +13,7 @@
 # limitations under the License.

 import numpy as np
+import paddle
 from parakeet.data.batch import batch_sequences


@@ -35,6 +36,15 @@ def collate_baker_examples(examples):
     durations = batch_sequences(durations)
     energy = batch_sequences(energy)

+    # convert each batch to paddle.Tensor
+    text = paddle.to_tensor(text)
+    pitch = paddle.to_tensor(pitch)
+    speech = paddle.to_tensor(speech)
+    durations = paddle.to_tensor(durations)
+    energy = paddle.to_tensor(energy)
+    text_lengths = paddle.to_tensor(text_lengths)
+    speech_lengths = paddle.to_tensor(speech_lengths)
+
     batch = {
         "text": text,
         "text_lengths": text_lengths,
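A minimal, self-contained sketch of what the conversion step added above relies on: `batch_sequences` pads the variable-length numpy arrays in a minibatch to a common length, so a single `paddle.to_tensor` call per field yields a dense batch. The toy data below is invented; the imports are the ones this file already uses.

```python
import numpy as np
import paddle
from parakeet.data.batch import batch_sequences

# two utterances of different length -> pad first, then convert per field
text = [np.array([3, 1, 4], dtype=np.int64), np.array([1, 5], dtype=np.int64)]
padded = batch_sequences(text)         # shape (2, 3); the short row is padded
text_batch = paddle.to_tensor(padded)  # dense int64 Tensor for the model
print(text_batch.shape)                # [2, 3]
```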
@@ -44,9 +44,7 @@ class Frontend():

     def _t2id(self, tones):
         # replace unk phone with sp
-        tones = [
-            tone if tone in self.vocab_tones else "0" for tone in tones
-        ]
+        tones = [tone if tone in self.vocab_tones else "0" for tone in tones]
         tone_ids = [self.vocab_tones[item] for item in tones]
         return np.array(tone_ids, np.int64)

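The `_t2id` fallback reformatted above, shown in isolation: tones missing from the vocabulary are mapped to `"0"` before lookup. `vocab_tones` here is a toy stand-in for the real tone-to-id map the `Frontend` loads.

```python
import numpy as np

vocab_tones = {"0": 0, "1": 1, "2": 2, "3": 3, "4": 4, "5": 5}  # toy map
tones = ["3", "7", "5"]  # "7" is out of vocabulary
tones = [tone if tone in vocab_tones else "0" for tone in tones]
tone_ids = np.array([vocab_tones[item] for item in tones], np.int64)
print(tone_ids)  # [3 0 5]
```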
@@ -41,10 +41,10 @@ def readtg(config, tg_path):
     durations[-2] += durations[-1]
     durations = durations[:-1]
     # replace the last sp with sil
-    phones[-1] = "sil" if phones[-1]=="sp" else phones[-1]
+    phones[-1] = "sil" if phones[-1] == "sp" else phones[-1]

     results = ""

     for (p, d) in zip(phones, durations):
         results += p + " " + str(d) + " "
     return results.strip()
@@ -88,12 +88,7 @@ class LogMelFBank():


 class Pitch():
-    def __init__(self,
-                 sr=24000,
-                 hop_length=300,
-                 f0min=80,
-                 f0max=7600
-                 ):
+    def __init__(self, sr=24000, hop_length=300, f0min=80, f0max=7600):

         self.sr = sr
         self.hop_length = hop_length
@@ -199,7 +194,7 @@ class Energy():
     def _calculate_energy(self, input):
         input = input.astype(np.float32)
         input_stft = self._stft(input)
-        input_power = np.abs(input_stft) ** 2
+        input_power = np.abs(input_stft)**2
         energy = np.sqrt(
             np.clip(
                 np.sum(input_power, axis=0), a_min=1.0e-10, a_max=float(
@@ -241,21 +236,20 @@ if __name__ == "__main__":
     print(mel)
     print(mel.shape)

-    pitch_extractor = Pitch(sr=C.fs,
-                            hop_length=C.n_shift,
-                            f0min=C.f0min,
-                            f0max=C.f0max)
+    pitch_extractor = Pitch(
+        sr=C.fs, hop_length=C.n_shift, f0min=C.f0min, f0max=C.f0max)
     duration = "2 8 8 8 12 11 10 13 11 10 18 9 12 10 12 11 5"
     duration = np.array([int(x) for x in duration.split(" ")])
     avg_f0 = pitch_extractor.get_pitch(wav, duration=duration)
     print(avg_f0)
     print(avg_f0.shape)

-    energy_extractor = Energy(sr=C.fs,
-                              n_fft=C.n_fft,
-                              hop_length=C.n_shift,
-                              win_length=C.win_length,
-                              window=C.window)
+    energy_extractor = Energy(
+        sr=C.fs,
+        n_fft=C.n_fft,
+        hop_length=C.n_shift,
+        win_length=C.win_length,
+        window=C.window)
     duration = "2 8 8 8 12 11 10 13 11 10 18 9 12 10 12 11 5"
     duration = np.array([int(x) for x in duration.split(" ")])
     avg_energy = energy_extractor.get_energy(wav, duration=duration)
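The `duration=` argument passed to `get_pitch`/`get_energy` above reflects the token-averaging idea FastSpeech2 needs: a frame-level contour is reduced to one value per phone by averaging over each phone's frame span, with spans given by the MFA durations. A pure-numpy sketch of that reduction (invented values, independent of the `Pitch`/`Energy` classes):

```python
import numpy as np

f0 = np.array([100.0, 102.0, 0.0, 98.0, 97.0, 99.0])  # per-frame f0, 0 = unvoiced
durations = np.array([2, 4])                           # frames per phone
offsets = np.concatenate([[0], np.cumsum(durations)])
# average only over voiced frames inside each phone's span
avg_f0 = np.array([f0[s:e][f0[s:e] > 0].mean() if (f0[s:e] > 0).any() else 0.0
                   for s, e in zip(offsets[:-1], offsets[1:])])
print(avg_f0)  # [101.  98.]
```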
@@ -139,15 +139,14 @@ def compare_duration_and_mel_length(sentences, utt, mel):
         sentences.pop(utt)


-def process_sentence(
-        config: Dict[str, Any],
-        fp: Path,
-        sentences: Dict,
-        output_dir: Path,
-        mel_extractor=None,
-        pitch_extractor=None,
-        energy_extractor=None,
-        cut_sil: bool = True):
+def process_sentence(config: Dict[str, Any],
+                     fp: Path,
+                     sentences: Dict,
+                     output_dir: Path,
+                     mel_extractor=None,
+                     pitch_extractor=None,
+                     energy_extractor=None,
+                     cut_sil: bool=True):
     utt_id = fp.stem
     record = None
     if utt_id in sentences:
@@ -160,7 +159,8 @@ def process_sentence(
         durations = sentences[utt_id][1]
         d_cumsum = np.pad(np.array(durations).cumsum(0), (1, 0), 'constant')
         # little imprecise than use *.TextGrid directly
-        times = librosa.frames_to_time(d_cumsum, sr=config.fs, hop_length=config.n_shift)
+        times = librosa.frames_to_time(
+            d_cumsum, sr=config.fs, hop_length=config.n_shift)
         if cut_sil:
             start = 0
             end = d_cumsum[-1]
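A self-contained sketch of the frame/time bookkeeping in this hunk: the cumulative phone durations are turned into second offsets with `librosa.frames_to_time`, which is what lets `cut_sil` trim edge silences from the raw audio. `fs` and `n_shift` mirror the Baker config; the durations are invented.

```python
import numpy as np
import librosa

fs, n_shift = 24000, 300
durations = [10, 50, 8]  # sil, speech, sil (in frames)
d_cumsum = np.pad(np.array(durations).cumsum(0), (1, 0), 'constant')
times = librosa.frames_to_time(d_cumsum, sr=fs, hop_length=n_shift)
start, end = times[1], times[-2]  # drop the edge sil spans
print(round(start, 3), round(end, 3))  # 0.125 0.75
```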
@@ -222,8 +222,8 @@ def process_sentences(config,
                       mel_extractor=None,
                       pitch_extractor=None,
                       energy_extractor=None,
-                      nprocs: int = 1,
-                      cut_sil: bool = True):
+                      nprocs: int=1,
+                      cut_sil: bool=True):
     if nprocs == 1:
         results = []
         for fp in tqdm.tqdm(fps, total=len(fps)):
@@ -239,7 +239,8 @@ def process_sentences(config,
             for fp in fps:
                 future = pool.submit(process_sentence, config, fp,
                                      sentences, output_dir, mel_extractor,
-                                     pitch_extractor, energy_extractor, cut_sil)
+                                     pitch_extractor, energy_extractor,
+                                     cut_sil)
                 future.add_done_callback(lambda p: progress.update())
                 futures.append(future)

@@ -289,7 +290,10 @@ def main():
         return True if str.lower() == 'true' else False

     parser.add_argument(
-        "--cut-sil", type=str2bool, default=True, help="whether cut sil in the edge of audio")
+        "--cut-sil",
+        type=str2bool,
+        default=True,
+        help="whether cut sil in the edge of audio")
     args = parser.parse_args()

     C = get_cfg_default()
@@ -336,15 +340,14 @@ def main():
         n_mels=C.n_mels,
         fmin=C.fmin,
         fmax=C.fmax)
-    pitch_extractor = Pitch(sr=C.fs,
-                            hop_length=C.n_shift,
-                            f0min=C.f0min,
-                            f0max=C.f0max)
-    energy_extractor = Energy(sr=C.fs,
-                              n_fft=C.n_fft,
-                              hop_length=C.n_shift,
-                              win_length=C.win_length,
-                              window=C.window)
+    pitch_extractor = Pitch(
+        sr=C.fs, hop_length=C.n_shift, f0min=C.f0min, f0max=C.f0max)
+    energy_extractor = Energy(
+        sr=C.fs,
+        n_fft=C.n_fft,
+        hop_length=C.n_shift,
+        win_length=C.win_length,
+        window=C.window)

     # process for the 3 sections

@@ -3,11 +3,11 @@

 python3 synthesize.py \
   --fastspeech2-config=conf/default.yaml \
-  --fastspeech2-checkpoint=exp/default/checkpoints/snapshot_iter_62577.pdz \
+  --fastspeech2-checkpoint=exp/default/checkpoints/snapshot_iter_153.pdz \
   --fastspeech2-stat=dump/train/speech_stats.npy \
-  --pwg-config=parallel_wavegan_baker_ckpt_1.0/pwg_default.yaml \
-  --pwg-params=parallel_wavegan_baker_ckpt_1.0/pwg_generator.pdparams \
-  --pwg-stat=parallel_wavegan_baker_ckpt_1.0/pwg_stats.npy \
+  --pwg-config=parallel_wavegan_baker_ckpt_0.4/pwg_default.yaml \
+  --pwg-params=parallel_wavegan_baker_ckpt_0.4/pwg_generator.pdparams \
+  --pwg-stat=parallel_wavegan_baker_ckpt_0.4/pwg_stats.npy \
   --test-metadata=dump/test/norm/metadata.jsonl \
   --output-dir=exp/debug/test \
   --device="gpu" \
@@ -3,11 +3,11 @@

 python3 synthesize_e2e.py \
   --fastspeech2-config=conf/default.yaml \
-  --fastspeech2-checkpoint=exp/default/checkpoints/snapshot_iter_136017.pdz \
+  --fastspeech2-checkpoint=exp/default/checkpoints/snapshot_iter_153.pdz \
   --fastspeech2-stat=dump/train/speech_stats.npy \
-  --pwg-config=parallel_wavegan_baker_ckpt_1.0/pwg_default.yaml \
-  --pwg-params=parallel_wavegan_baker_ckpt_1.0/pwg_generator.pdparams \
-  --pwg-stat=parallel_wavegan_baker_ckpt_1.0/pwg_stats.npy \
+  --pwg-config=parallel_wavegan_baker_ckpt_0.4/pwg_default.yaml \
+  --pwg-params=parallel_wavegan_baker_ckpt_0.4/pwg_generator.pdparams \
+  --pwg-stat=parallel_wavegan_baker_ckpt_0.4/pwg_stats.npy \
   --text=sentences.txt \
   --output-dir=exp/debug/test_e2e \
   --device="gpu" \
@@ -97,7 +97,7 @@ def process_sentence(config: Dict[str, Any],
     utt_id = fp.stem

     # reading
-    y, sr = librosa.load(fp, sr=config.sr)  # resampling may occur
+    y, sr = librosa.load(str(fp), sr=config.sr)  # resampling may occur
     assert len(y.shape) == 1, f"{utt_id} is not a mono-channel audio."
     assert np.abs(y).max(
     ) <= 1.0, f"{utt_id} is seems to be different that 16 bit PCM."
@@ -1,6 +1,5 @@
 FLAGS_cudnn_exhaustive_search=true \
 FLAGS_conv_workspace_size_limit=4000 \
-
 python train.py \
   --train-metadata=dump/train/norm/metadata.jsonl \
   --dev-metadata=dump/dev/norm/metadata.jsonl \
@@ -19,4 +19,3 @@ from parakeet.frontend.normalizer import *
 from parakeet.frontend.cn_normalization import *
 from parakeet.frontend.tone_sandhi import *
 from parakeet.frontend.generate_lexicon import *
-
@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-

 import jieba.posseg as psg
 import numpy as np
 import paddle
@@ -34,7 +33,8 @@ class Frontend():
         self.g2p_model = g2p_model
         if self.g2p_model == "g2pM":
             self.g2pM_model = G2pM()
-            self.pinyin2phone = generate_lexicon(with_tone=True, with_erhua=False)
+            self.pinyin2phone = generate_lexicon(
+                with_tone=True, with_erhua=False)

     def _get_initials_finals(self, word):
         initials = []
@@ -55,7 +55,7 @@ class Frontend():
         elif self.g2p_model == "g2pM":
             pinyins = self.g2pM_model(word, tone=True, char_split=False)
             for pinyin in pinyins:
-                pinyin = pinyin.replace("u:","v")
+                pinyin = pinyin.replace("u:", "v")
                 if pinyin in self.pinyin2phone:
                     initial_final_list = self.pinyin2phone[pinyin].split(" ")
                     if len(initial_final_list) == 2:
@@ -84,7 +84,8 @@ class Frontend():
             if pos == 'eng':
                 continue
             sub_initials, sub_finals = self._get_initials_finals(word)
-            sub_finals = self.tone_modifier.modified_tone(word, pos, sub_finals)
+            sub_finals = self.tone_modifier.modified_tone(word, pos,
+                                                          sub_finals)
             initials.append(sub_initials)
             finals.append(sub_finals)
             # assert len(sub_initials) == len(sub_finals) == len(word)
@@ -1,4 +1,4 @@
-supported NSW (Non-Standard-Word) Normalization
+## Supported NSW (Non-Standard-Word) Normalization

 |NSW type|raw|normalized|
 |-|-|-|
@@ -9,3 +9,5 @@ supported NSW (Non-Standard-Word) Normalization
 |money|随便来几个价格12块5,34.5元,20.1万|随便来几个价格十二块五 三十四点五元 二十点一万|
 |percentage|明天有62%的概率降雨|明天有百分之六十二的概率降雨|
 |telephone|这是固话0421-33441122<br>这是手机+86 18544139121|这是固话零四二一三三四四一一二二<br>这是手机八六一八五四四一三九一二一|
+## References
+[Pull requests #658 of DeepSpeech](https://github.com/PaddlePaddle/DeepSpeech/pull/658/files)
@@ -24,17 +24,18 @@ def _time_num2str(num_string: str) -> str:
         result = DIGITS['0'] + result
     return result


 # 时刻表达式
-RE_TIME = re.compile(
-    r'([0-1]?[0-9]|2[0-3])'
-    r':([0-5][0-9])'
-    r'(:([0-5][0-9]))?'
-)
+RE_TIME = re.compile(r'([0-1]?[0-9]|2[0-3])'
+                     r':([0-5][0-9])'
+                     r'(:([0-5][0-9]))?')


 def replace_time(match: re.Match) -> str:
     hour = match.group(1)
     minute = match.group(2)
     second = match.group(4)

     result = f"{num2str(hour)}点"
     if minute.lstrip('0'):
         result += f"{_time_num2str(minute)}分"
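A quick spot check of the rewrapped pattern (regex copied verbatim from the hunk); `replace_time` itself needs `num2str`, so only the groups are exercised here:

```python
import re

RE_TIME = re.compile(r'([0-1]?[0-9]|2[0-3])'
                     r':([0-5][0-9])'
                     r'(:([0-5][0-9]))?')
m = RE_TIME.search("会议定在14:30:05开始")
print(m.group(1), m.group(2), m.group(4))  # 14 30 05
```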
@@ -43,11 +44,11 @@ def replace_time(match: re.Match) -> str:
     return result


-RE_DATE = re.compile(
-    r'(\d{4}|\d{2})年'
-    r'((0?[1-9]|1[0-2])月)?'
-    r'(((0?[1-9])|((1|2)[0-9])|30|31)([日号]))?'
-)
+RE_DATE = re.compile(r'(\d{4}|\d{2})年'
+                     r'((0?[1-9]|1[0-2])月)?'
+                     r'(((0?[1-9])|((1|2)[0-9])|30|31)([日号]))?')


 def replace_date(match: re.Match) -> str:
     year = match.group(1)
     month = match.group(3)
@@ -61,10 +62,12 @@ def replace_date(match: re.Match) -> str:
         result += f"{verbalize_cardinal(day)}{match.group(9)}"
     return result


 # 用 / 或者 - 分隔的 YY/MM/DD 或者 YY-MM-DD 日期
 RE_DATE2 = re.compile(
-    r'(\d{4})([- /.])(0[1-9]|1[012])\2(0[1-9]|[12][0-9]|3[01])'
-)
+    r'(\d{4})([- /.])(0[1-9]|1[012])\2(0[1-9]|[12][0-9]|3[01])')


 def replace_date2(match: re.Match) -> str:
     year = match.group(1)
     month = match.group(3)
@@ -76,4 +79,4 @@ def replace_date2(match: re.Match) -> str:
         result += f"{verbalize_cardinal(month)}月"
     if day:
         result += f"{verbalize_cardinal(day)}日"
-    return result
+    return result
@@ -16,7 +16,6 @@ import re
 import string
 from pypinyin.constants import SUPPORT_UCS4
-

 # 全角半角转换
 # 英文字符全角 -> 半角映射表 (num: 52)
 F2H_ASCII_LETTERS = {
@@ -28,10 +27,7 @@ F2H_ASCII_LETTERS = {
 H2F_ASCII_LETTERS = {value: key for key, value in F2H_ASCII_LETTERS.items()}

 # 数字字符全角 -> 半角映射表 (num: 10)
-F2H_DIGITS = {
-    chr(ord(char) + 65248): char
-    for char in string.digits
-}
+F2H_DIGITS = {chr(ord(char) + 65248): char for char in string.digits}
 # 数字字符半角 -> 全角映射表
 H2F_DIGITS = {value: key for key, value in F2H_DIGITS.items()}

@@ -49,24 +45,21 @@ H2F_SPACE = {' ': '\u3000'}

 # 非"有拼音的汉字"的字符串,可用于NSW提取
 if SUPPORT_UCS4:
-    RE_NSW = re.compile(
-        r'(?:[^'
-        r'\u3007'  # 〇
-        r'\u3400-\u4dbf'  # CJK扩展A:[3400-4DBF]
-        r'\u4e00-\u9fff'  # CJK基本:[4E00-9FFF]
-        r'\uf900-\ufaff'  # CJK兼容:[F900-FAFF]
-        r'\U00020000-\U0002A6DF'  # CJK扩展B:[20000-2A6DF]
-        r'\U0002A703-\U0002B73F'  # CJK扩展C:[2A700-2B73F]
-        r'\U0002B740-\U0002B81D'  # CJK扩展D:[2B740-2B81D]
-        r'\U0002F80A-\U0002FA1F'  # CJK兼容扩展:[2F800-2FA1F]
-        r'])+'
-    )
+    RE_NSW = re.compile(r'(?:[^'
+                        r'\u3007'  # 〇
+                        r'\u3400-\u4dbf'  # CJK扩展A:[3400-4DBF]
+                        r'\u4e00-\u9fff'  # CJK基本:[4E00-9FFF]
+                        r'\uf900-\ufaff'  # CJK兼容:[F900-FAFF]
+                        r'\U00020000-\U0002A6DF'  # CJK扩展B:[20000-2A6DF]
+                        r'\U0002A703-\U0002B73F'  # CJK扩展C:[2A700-2B73F]
+                        r'\U0002B740-\U0002B81D'  # CJK扩展D:[2B740-2B81D]
+                        r'\U0002F80A-\U0002FA1F'  # CJK兼容扩展:[2F800-2FA1F]
+                        r'])+')
 else:
     RE_NSW = re.compile(  # pragma: no cover
         r'(?:[^'
-        r'\u3007'  # 〇
-        r'\u3400-\u4dbf'  # CJK扩展A:[3400-4DBF]
-        r'\u4e00-\u9fff'  # CJK基本:[4E00-9FFF]
-        r'\uf900-\ufaff'  # CJK兼容:[F900-FAFF]
-        r'])+'
-    )
+        r'\u3007'  # 〇
+        r'\u3400-\u4dbf'  # CJK扩展A:[3400-4DBF]
+        r'\u4e00-\u9fff'  # CJK基本:[4E00-9FFF]
+        r'\uf900-\ufaff'  # CJK兼容:[F900-FAFF]
+        r'])+')
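For context, `RE_NSW` exists to pull maximal non-Chinese-character spans (digits, punctuation, latin text) out of a sentence so they can be normalized separately. The sketch below keeps only the BMP ranges from the pattern above for brevity:

```python
import re

RE_NSW = re.compile(r'(?:[^'
                    r'\u3007'
                    r'\u3400-\u4dbf'
                    r'\u4e00-\u9fff'
                    r'\uf900-\ufaff'
                    r'])+')
print(RE_NSW.findall("随便来几个价格12块5,34.5元"))  # ['12', '5,34.5']
```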
@@ -11,7 +11,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
 """
 Rules to verbalize numbers into Chinese characters.
 https://zh.wikipedia.org/wiki/中文数字#現代中文
@@ -21,7 +20,6 @@ import re
 from collections import OrderedDict
 from typing import List
-

 DIGITS = {str(i): tran for i, tran in enumerate('零一二三四五六七八九')}
 UNITS = OrderedDict({
     1: '十',
@@ -33,6 +31,8 @@ UNITS = OrderedDict({

 # 分数表达式
 RE_FRAC = re.compile(r'(-?)(\d+)/(\d+)')
+
+
 def replace_frac(match: re.Match) -> str:
     sign = match.group(1)
     nominator = match.group(2)
@@ -42,10 +42,12 @@ def replace_frac(match: re.Match) -> str:
     denominator: str = num2str(denominator)
     result = f"{sign}{denominator}分之{nominator}"
     return result
+
+
 # 百分数表达式
 RE_PERCENTAGE = re.compile(r'(-?)(\d+(\.\d+)?)%')


 def replace_percentage(match: re.Match) -> str:
     sign = match.group(1)
     percent = match.group(2)
@@ -54,28 +56,28 @@ def replace_percentage(match: re.Match) -> str:
     result = f"{sign}百分之{percent}"
     return result

 # 整数表达式
 # 带负号或者不带负号的整数 12, -10
-RE_INTEGER = re.compile(
-    r'(-?)'
-    r'(\d+)'
-)
+RE_INTEGER = re.compile(r'(-?)' r'(\d+)')

 # 编号-无符号整形
 # 00078
 RE_DEFAULT_NUM = re.compile(r'\d{4}\d*')


 def replace_default_num(match: re.Match):
     number = match.group(0)
     return verbalize_digit(number)


 # 数字表达式
 # 1. 整数: -10, 10;
 # 2. 浮点数: 10.2, -0.3
 # 3. 不带符号和整数部分的纯浮点数: .22, .38
-RE_NUMBER = re.compile(
-    r'(-?)((\d+)(\.\d+)?)'
-    r'|(\.(\d+))'
-)
+RE_NUMBER = re.compile(r'(-?)((\d+)(\.\d+)?)' r'|(\.(\d+))')


 def replace_number(match: re.Match) -> str:
     sign = match.group(1)
     number = match.group(2)
|
@ -88,11 +90,12 @@ def replace_number(match: re.Match) -> str:
|
|||
result = f"{sign}{number}"
|
||||
return result
|
||||
|
||||
|
||||
# 范围表达式
|
||||
# 12-23, 12~23
|
||||
RE_RANGE = re.compile(
|
||||
r'(\d+)[-~](\d+)'
|
||||
)
|
||||
RE_RANGE = re.compile(r'(\d+)[-~](\d+)')
|
||||
|
||||
|
||||
def replace_range(match: re.Match) -> str:
|
||||
first, second = match.group(1), match.group(2)
|
||||
first: str = num2str(first)
|
||||
|
@@ -111,26 +114,31 @@ def _get_value(value_string: str, use_zero: bool=True) -> List[str]:
         else:
             return [DIGITS[stripped]]
     else:
-        largest_unit = next(power for power in reversed(UNITS.keys()) if power < len(stripped))
+        largest_unit = next(
+            power for power in reversed(UNITS.keys()) if power < len(stripped))
         first_part = value_string[:-largest_unit]
         second_part = value_string[-largest_unit:]
-        return _get_value(first_part) + [UNITS[largest_unit]] + _get_value(second_part)
+        return _get_value(first_part) + [UNITS[largest_unit]] + _get_value(
+            second_part)


 def verbalize_cardinal(value_string: str) -> str:
     if not value_string:
         return ''

     # 000 -> '零' , 0 -> '零'
     value_string = value_string.lstrip('0')
     if len(value_string) == 0:
         return DIGITS['0']

     result_symbols = _get_value(value_string)
     # verbalized number starting with '一十*' is abbreviated as `十*`
-    if len(result_symbols) >= 2 and result_symbols[0] == DIGITS['1'] and result_symbols[1] == UNITS[1]:
+    if len(result_symbols) >= 2 and result_symbols[0] == DIGITS[
+            '1'] and result_symbols[1] == UNITS[1]:
         result_symbols = result_symbols[1:]
     return ''.join(result_symbols)


 def verbalize_digit(value_string: str, alt_one=False) -> str:
     result_symbols = [DIGITS[digit] for digit in value_string]
     result = ''.join(result_symbols)
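Usage sketch for the helpers above, including the `一十` → `十` abbreviation the comment describes. The import path is an assumption based on the `cn_normalization` import elsewhere in this commit; adjust to wherever `num.py` actually lives.

```python
# hypothetical module path
from parakeet.frontend.cn_normalization.num import num2str, verbalize_cardinal

print(verbalize_cardinal('15'))  # 十五 -- leading 一十 abbreviated to 十
print(num2str('34.5'))           # 三十四点五
```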
@@ -138,6 +146,7 @@ def verbalize_digit(value_string: str, alt_one=False) -> str:
         result.replace("一", "幺")
     return result
+

 def num2str(value_string: str) -> str:
     integer_decimal = value_string.split('.')
     if len(integer_decimal) == 1:
@@ -146,8 +155,10 @@ def num2str(value_string: str) -> str:
     elif len(integer_decimal) == 2:
         integer, decimal = integer_decimal
     else:
-        raise ValueError(f"The value string: '${value_string}' has more than one point in it.")
+        raise ValueError(
+            f"The value string: '${value_string}' has more than one point in it."
+        )

     result = verbalize_cardinal(integer)

     decimal = decimal.rstrip('0')
@@ -16,14 +16,13 @@ import re

 from .num import verbalize_digit
-

 # 规范化固话/手机号码
 # 手机
 # http://www.jihaoba.com/news/show/13680
 # 移动:139、138、137、136、135、134、159、158、157、150、151、152、188、187、182、183、184、178、198
 # 联通:130、131、132、156、155、186、185、176
 # 电信:133、153、189、180、181、177
-RE_MOBILE_PHONE= re.compile(
+RE_MOBILE_PHONE = re.compile(
     r"(?<!\d)((\+?86 ?)?1([38]\d|5[0-35-9]|7[678]|9[89])\d{8})(?!\d)")
 RE_TELEPHONE = re.compile(
     r"(?<!\d)((0(10|2[1-3]|[3-9]\d{2})-?)?[1-9]\d{6,7})(?!\d)")
@@ -33,14 +32,16 @@ def phone2str(phone_string: str, mobile=True) -> str:
     if mobile:
         sp_parts = phone_string.strip('+').split()
         result = ''.join(
-            [verbalize_digit(part, alt_one=True) for part in sp_parts])
+            [verbalize_digit(
+                part, alt_one=True) for part in sp_parts])
         return result
     else:
         sil_parts = phone_string.split('-')
         result = ''.join(
-            [verbalize_digit(part, alt_one=True) for part in sil_parts])
+            [verbalize_digit(
+                part, alt_one=True) for part in sil_parts])
         return result


 def replace_phone(match: re.Match) -> str:
-    return phone2str(match.group(0))
+    return phone2str(match.group(0))
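A minimal check of the mobile pattern above (regex copied from the hunk): the lookarounds demand exactly 11 digits after the optional `+86` prefix and reject longer digit runs.

```python
import re

RE_MOBILE_PHONE = re.compile(
    r"(?<!\d)((\+?86 ?)?1([38]\d|5[0-35-9]|7[678]|9[89])\d{8})(?!\d)")
print(bool(RE_MOBILE_PHONE.search("+86 18544139121")))  # True
print(bool(RE_MOBILE_PHONE.search("185441391210000")))  # False: too long
```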
@@ -16,12 +16,11 @@ import re

 from .num import num2str
-

 # 温度表达式,温度会影响负号的读法
 # -3°C 零下三度
-RE_TEMPERATURE = re.compile(
-    r'(-?)(\d+(\.\d+)?)(°C|℃|度|摄氏度)'
-)
+RE_TEMPERATURE = re.compile(r'(-?)(\d+(\.\d+)?)(°C|℃|度|摄氏度)')


 def replace_temperature(match: re.Match) -> str:
     sign = match.group(1)
     temperature = match.group(2)
@@ -30,4 +29,4 @@ def replace_temperature(match: re.Match) -> str:
     temperature: str = num2str(temperature)
     unit: str = "摄氏度" if unit == "摄氏度" else "度"
     result = f"{sign}{temperature}{unit}"
-    return result
+    return result
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from typing import List, Tuple

 import jieba
 from pypinyin import lazy_pinyin
@@ -20,43 +21,51 @@ from pypinyin import Style

 class ToneSandhi():
     def __init__(self):
-        self.must_neural_tone_words = {'麻烦', '麻利', '鸳鸯', '高粱', '骨头', '骆驼', '马虎', '首饰', '馒头', '馄饨', '风筝', '难为', '队伍',
-                                       '阔气', '闺女', '门道', '锄头', '铺盖', '铃铛', '铁匠', '钥匙', '里脊', '里头', '部分', '那么', '道士',
-                                       '造化', '迷糊', '连累', '这么', '这个', '运气', '过去', '软和', '转悠', '踏实', '跳蚤', '跟头', '趔趄',
-                                       '财主', '豆腐', '讲究', '记性', '记号', '认识', '规矩', '见识', '裁缝', '补丁', '衣裳', '衣服', '衙门',
-                                       '街坊', '行李', '行当', '蛤蟆', '蘑菇', '薄荷', '葫芦', '葡萄', '萝卜', '荸荠', '苗条', '苗头', '苍蝇',
-                                       '芝麻', '舒服', '舒坦', '舌头', '自在', '膏药', '脾气', '脑袋', '脊梁', '能耐', '胳膊', '胭脂', '胡萝',
-                                       '胡琴', '胡同', '聪明', '耽误', '耽搁', '耷拉', '耳朵', '老爷', '老实', '老婆', '老头', '老太', '翻腾',
-                                       '罗嗦', '罐头', '编辑', '结实', '红火', '累赘', '糨糊', '糊涂', '精神', '粮食', '簸箕', '篱笆', '算计',
-                                       '算盘', '答应', '笤帚', '笑语', '笑话', '窟窿', '窝囊', '窗户', '稳当', '稀罕', '称呼', '秧歌', '秀气',
-                                       '秀才', '福气', '祖宗', '砚台', '码头', '石榴', '石头', '石匠', '知识', '眼睛', '眯缝', '眨巴', '眉毛',
-                                       '相声', '盘算', '白净', '痢疾', '痛快', '疟疾', '疙瘩', '疏忽', '畜生', '生意', '甘蔗', '琵琶', '琢磨',
-                                       '琉璃', '玻璃', '玫瑰', '玄乎', '狐狸', '状元', '特务', '牲口', '牙碜', '牌楼', '爽快', '爱人', '热闹',
-                                       '烧饼', '烟筒', '烂糊', '点心', '炊帚', '灯笼', '火候', '漂亮', '滑溜', '溜达', '温和', '清楚', '消息',
-                                       '浪头', '活泼', '比方', '正经', '欺负', '模糊', '槟榔', '棺材', '棒槌', '棉花', '核桃', '栅栏', '柴火',
-                                       '架势', '枕头', '枇杷', '机灵', '本事', '木头', '木匠', '朋友', '月饼', '月亮', '暖和', '明白', '时候',
-                                       '新鲜', '故事', '收拾', '收成', '提防', '挖苦', '挑剔', '指甲', '指头', '拾掇', '拳头', '拨弄', '招牌',
-                                       '招呼', '抬举', '护士', '折腾', '扫帚', '打量', '打算', '打点', '打扮', '打听', '打发', '扎实', '扁担',
-                                       '戒指', '懒得', '意识', '意思', '情形', '悟性', '怪物', '思量', '怎么', '念头', '念叨', '快活', '忙活',
-                                       '志气', '心思', '得罪', '张罗', '弟兄', '开通', '应酬', '庄稼', '干事', '帮手', '帐篷', '希罕', '师父',
-                                       '师傅', '巴结', '巴掌', '差事', '工夫', '岁数', '屁股', '尾巴', '少爷', '小气', '小伙', '将就', '对头',
-                                       '对付', '寡妇', '家伙', '客气', '实在', '官司', '学问', '学生', '字号', '嫁妆', '媳妇', '媒人', '婆家',
-                                       '娘家', '委屈', '姑娘', '姐夫', '妯娌', '妥当', '妖精', '奴才', '女婿', '头发', '太阳', '大爷', '大方',
-                                       '大意', '大夫', '多少', '多么', '外甥', '壮实', '地道', '地方', '在乎', '困难', '嘴巴', '嘱咐', '嘟囔',
-                                       '嘀咕', '喜欢', '喇嘛', '喇叭', '商量', '唾沫', '哑巴', '哈欠', '哆嗦', '咳嗽', '和尚', '告诉', '告示',
-                                       '含糊', '吓唬', '后头', '名字', '名堂', '合同', '吆喝', '叫唤', '口袋', '厚道', '厉害', '千斤', '包袱',
-                                       '包涵', '匀称', '勤快', '动静', '动弹', '功夫', '力气', '前头', '刺猬', '刺激', '别扭', '利落', '利索',
-                                       '利害', '分析', '出息', '凑合', '凉快', '冷战', '冤枉', '冒失', '养活', '关系', '先生', '兄弟', '便宜',
-                                       '使唤', '佩服', '作坊', '体面', '位置', '似的', '伙计', '休息', '什么', '人家', '亲戚', '亲家', '交情',
-                                       '云彩', '事情', '买卖', '主意', '丫头', '丧气', '两口', '东西', '东家', '世故', '不由', '不在', '下水',
-                                       '下巴', '上头', '上司', '丈夫', '丈人', '一辈', '那个'}
+        self.must_neural_tone_words = {
+            '麻烦', '麻利', '鸳鸯', '高粱', '骨头', '骆驼', '马虎', '首饰', '馒头', '馄饨', '风筝',
+            '难为', '队伍', '阔气', '闺女', '门道', '锄头', '铺盖', '铃铛', '铁匠', '钥匙', '里脊',
+            '里头', '部分', '那么', '道士', '造化', '迷糊', '连累', '这么', '这个', '运气', '过去',
+            '软和', '转悠', '踏实', '跳蚤', '跟头', '趔趄', '财主', '豆腐', '讲究', '记性', '记号',
+            '认识', '规矩', '见识', '裁缝', '补丁', '衣裳', '衣服', '衙门', '街坊', '行李', '行当',
+            '蛤蟆', '蘑菇', '薄荷', '葫芦', '葡萄', '萝卜', '荸荠', '苗条', '苗头', '苍蝇', '芝麻',
+            '舒服', '舒坦', '舌头', '自在', '膏药', '脾气', '脑袋', '脊梁', '能耐', '胳膊', '胭脂',
+            '胡萝', '胡琴', '胡同', '聪明', '耽误', '耽搁', '耷拉', '耳朵', '老爷', '老实', '老婆',
+            '老头', '老太', '翻腾', '罗嗦', '罐头', '编辑', '结实', '红火', '累赘', '糨糊', '糊涂',
+            '精神', '粮食', '簸箕', '篱笆', '算计', '算盘', '答应', '笤帚', '笑语', '笑话', '窟窿',
+            '窝囊', '窗户', '稳当', '稀罕', '称呼', '秧歌', '秀气', '秀才', '福气', '祖宗', '砚台',
+            '码头', '石榴', '石头', '石匠', '知识', '眼睛', '眯缝', '眨巴', '眉毛', '相声', '盘算',
+            '白净', '痢疾', '痛快', '疟疾', '疙瘩', '疏忽', '畜生', '生意', '甘蔗', '琵琶', '琢磨',
+            '琉璃', '玻璃', '玫瑰', '玄乎', '狐狸', '状元', '特务', '牲口', '牙碜', '牌楼', '爽快',
+            '爱人', '热闹', '烧饼', '烟筒', '烂糊', '点心', '炊帚', '灯笼', '火候', '漂亮', '滑溜',
+            '溜达', '温和', '清楚', '消息', '浪头', '活泼', '比方', '正经', '欺负', '模糊', '槟榔',
+            '棺材', '棒槌', '棉花', '核桃', '栅栏', '柴火', '架势', '枕头', '枇杷', '机灵', '本事',
+            '木头', '木匠', '朋友', '月饼', '月亮', '暖和', '明白', '时候', '新鲜', '故事', '收拾',
+            '收成', '提防', '挖苦', '挑剔', '指甲', '指头', '拾掇', '拳头', '拨弄', '招牌', '招呼',
+            '抬举', '护士', '折腾', '扫帚', '打量', '打算', '打点', '打扮', '打听', '打发', '扎实',
+            '扁担', '戒指', '懒得', '意识', '意思', '情形', '悟性', '怪物', '思量', '怎么', '念头',
+            '念叨', '快活', '忙活', '志气', '心思', '得罪', '张罗', '弟兄', '开通', '应酬', '庄稼',
+            '干事', '帮手', '帐篷', '希罕', '师父', '师傅', '巴结', '巴掌', '差事', '工夫', '岁数',
+            '屁股', '尾巴', '少爷', '小气', '小伙', '将就', '对头', '对付', '寡妇', '家伙', '客气',
+            '实在', '官司', '学问', '学生', '字号', '嫁妆', '媳妇', '媒人', '婆家', '娘家', '委屈',
+            '姑娘', '姐夫', '妯娌', '妥当', '妖精', '奴才', '女婿', '头发', '太阳', '大爷', '大方',
+            '大意', '大夫', '多少', '多么', '外甥', '壮实', '地道', '地方', '在乎', '困难', '嘴巴',
+            '嘱咐', '嘟囔', '嘀咕', '喜欢', '喇嘛', '喇叭', '商量', '唾沫', '哑巴', '哈欠', '哆嗦',
+            '咳嗽', '和尚', '告诉', '告示', '含糊', '吓唬', '后头', '名字', '名堂', '合同', '吆喝',
+            '叫唤', '口袋', '厚道', '厉害', '千斤', '包袱', '包涵', '匀称', '勤快', '动静', '动弹',
+            '功夫', '力气', '前头', '刺猬', '刺激', '别扭', '利落', '利索', '利害', '分析', '出息',
+            '凑合', '凉快', '冷战', '冤枉', '冒失', '养活', '关系', '先生', '兄弟', '便宜', '使唤',
+            '佩服', '作坊', '体面', '位置', '似的', '伙计', '休息', '什么', '人家', '亲戚', '亲家',
+            '交情', '云彩', '事情', '买卖', '主意', '丫头', '丧气', '两口', '东西', '东家', '世故',
+            '不由', '不在', '下水', '下巴', '上头', '上司', '丈夫', '丈人', '一辈', '那个'
+        }

     # the meaning of jieba pos tag: https://blog.csdn.net/weixin_44174352/article/details/113731041
     # e.g.
-        # word: "家里"
-        # pos: "s"
-        # finals: ['ia1', 'i3']
-    def _neural_sandhi(self, word, pos, finals):
+    # word: "家里"
+    # pos: "s"
+    # finals: ['ia1', 'i3']
+    def _neural_sandhi(self, word: str, pos: str,
+                       finals: List[str]) -> List[str]:
         ge_idx = word.find("个")
         if len(word) == 1 and word in "吧呢啊嘛" and pos == 'y':
             finals[-1] = finals[-1][:-1] + "5"
|
|||
elif len(word) >= 2 and word[-1] == word[-2] and pos[0] in {"n", "v"}:
|
||||
finals[-1] = finals[-1][:-1] + "5"
|
||||
# conventional tone5 in Chinese
|
||||
elif word in self.must_neural_tone_words or word[-2:] in self.must_neural_tone_words:
|
||||
elif word in self.must_neural_tone_words or word[
|
||||
-2:] in self.must_neural_tone_words:
|
||||
finals[-1] = finals[-1][:-1] + "5"
|
||||
|
||||
return finals
|
||||
|
||||
def _bu_sandhi(self, word, finals):
|
||||
def _bu_sandhi(self, word: str, finals: List[str]) -> List[str]:
|
||||
# "不" before tone4 should be bu2, e.g. 不怕
|
||||
if len(word) > 1 and word[0] == "不" and finals[1][-1] == "4":
|
||||
finals[0] = finals[0][:-1] + "2"
|
||||
|
@ -95,15 +105,16 @@ class ToneSandhi():
|
|||
|
||||
return finals
|
||||
|
||||
def _yi_sandhi(self, word, finals):
|
||||
def _yi_sandhi(self, word: str, finals: List[str]) -> List[str]:
|
||||
# "一" in number sequences, e.g. 一零零
|
||||
if len(word) > 1 and word[0] == "一" and all([item.isnumeric() for item in word]):
|
||||
if len(word) > 1 and word[0] == "一" and all(
|
||||
[item.isnumeric() for item in word]):
|
||||
return finals
|
||||
# "一" before tone4 should be yi2, e.g. 一段
|
||||
elif len(word) > 1 and word[0] == "一" and finals[1][-1] == "4":
|
||||
finals[0] = finals[0][:-1] + "2"
|
||||
# "一" before non-tone4 should be yi4, e.g. 一天
|
||||
elif len(word) > 1 and word[0] == "一" and finals[1][-1]!= "4":
|
||||
elif len(word) > 1 and word[0] == "一" and finals[1][-1] != "4":
|
||||
finals[0] = finals[0][:-1] + "4"
|
||||
# "一" between reduplication words shold be yi5, e.g. 看一看
|
||||
elif len(word) == 3 and word[1] == "一" and word[0] == word[-1]:
|
||||
|
@ -113,7 +124,7 @@ class ToneSandhi():
|
|||
finals[1] = finals[1][:-1] + "1"
|
||||
return finals
|
||||
|
||||
def _three_sandhi(self, word, finals):
|
||||
def _three_sandhi(self, word: str, finals: List[str]) -> List[str]:
|
||||
if len(word) == 2 and self._all_tone_three(finals):
|
||||
finals[0] = finals[0][:-1] + "2"
|
||||
elif len(word) == 3:
|
||||
|
@ -138,7 +149,10 @@ class ToneSandhi():
|
|||
elif len(new_word_list[0]) == 1:
|
||||
finals[1] = finals[1][:-1] + "2"
|
||||
else:
|
||||
finals_list = [finals[:len(new_word_list[0])], finals[len(new_word_list[0]):]]
|
||||
finals_list = [
|
||||
finals[:len(new_word_list[0])],
|
||||
finals[len(new_word_list[0]):]
|
||||
]
|
||||
if len(finals_list) == 2:
|
||||
for i, sub in enumerate(finals_list):
|
||||
# e.g. 所有/人
|
||||
|
@ -161,12 +175,12 @@ class ToneSandhi():
|
|||
|
||||
return finals
|
||||
|
||||
def _all_tone_three(self, finals):
|
||||
def _all_tone_three(self, finals: List[str]) -> bool:
|
||||
return all(x[-1] == "3" for x in finals)
|
||||
|
||||
# merge "不" and the word behind it
|
||||
# if don't merge, "不" sometimes appears alone according to jieba, which may occur sandhi error
|
||||
def _merge_bu(self, seg):
|
||||
def _merge_bu(self, seg: List[Tuple[str, str]]) -> List[Tuple[str, str]]:
|
||||
new_seg = []
|
||||
last_word = ""
|
||||
for word, pos in seg:
|
||||
|
@ -185,17 +199,18 @@ class ToneSandhi():
|
|||
# function 2: merge single "一" and the word behind it
|
||||
# if don't merge, "一" sometimes appears alone according to jieba, which may occur sandhi error
|
||||
# e.g.
|
||||
# input seg: [('听', 'v'), ('一', 'm'), ('听', 'v')]
|
||||
# output seg: [['听一听', 'v']]
|
||||
def _merge_yi(self, seg):
|
||||
# input seg: [('听', 'v'), ('一', 'm'), ('听', 'v')]
|
||||
# output seg: [['听一听', 'v']]
|
||||
def _merge_yi(self, seg: List[Tuple[str, str]]) -> List[Tuple[str, str]]:
|
||||
new_seg = []
|
||||
# function 1
|
||||
for i, (word, pos) in enumerate(seg):
|
||||
if i - 1 >= 0 and word == "一" and i + 1 < len(seg) and seg[i - 1][0] == seg[i + 1][0] and seg[i - 1][
|
||||
1] == "v":
|
||||
if i - 1 >= 0 and word == "一" and i + 1 < len(seg) and seg[i - 1][
|
||||
0] == seg[i + 1][0] and seg[i - 1][1] == "v":
|
||||
new_seg[i - 1][0] = new_seg[i - 1][0] + "一" + new_seg[i - 1][0]
|
||||
else:
|
||||
if i - 2 >= 0 and seg[i - 1][0] == "一" and seg[i - 2][0] == word and pos == "v":
|
||||
if i - 2 >= 0 and seg[i - 1][0] == "一" and seg[i - 2][
|
||||
0] == word and pos == "v":
|
||||
continue
|
||||
else:
|
||||
new_seg.append([word, pos])
|
||||
|
@ -210,15 +225,20 @@ class ToneSandhi():
|
|||
seg = new_seg
|
||||
return seg
|
||||
|
||||
def _merge_continuous_three_tones(self, seg):
|
||||
def _merge_continuous_three_tones(
|
||||
self, seg: List[Tuple[str, str]]) -> List[Tuple[str, str]]:
|
||||
new_seg = []
|
||||
sub_finals_list = [lazy_pinyin(word, neutral_tone_with_five=True, style=Style.FINALS_TONE3) for (word, pos)
|
||||
in seg]
|
||||
sub_finals_list = [
|
||||
lazy_pinyin(
|
||||
word, neutral_tone_with_five=True, style=Style.FINALS_TONE3)
|
||||
for (word, pos) in seg
|
||||
]
|
||||
assert len(sub_finals_list) == len(seg)
|
||||
merge_last = [False] * len(seg)
|
||||
for i, (word, pos) in enumerate(seg):
|
||||
if i - 1 >= 0 and self._all_tone_three(sub_finals_list[i - 1]) and self._all_tone_three(
|
||||
sub_finals_list[i]) and not merge_last[i - 1]:
|
||||
if i - 1 >= 0 and self._all_tone_three(sub_finals_list[
|
||||
i - 1]) and self._all_tone_three(sub_finals_list[
|
||||
i]) and not merge_last[i - 1]:
|
||||
if len(seg[i - 1][0]) + len(seg[i][0]) <= 3:
|
||||
new_seg[-1][0] = new_seg[-1][0] + seg[i][0]
|
||||
merge_last[i] = True
|
||||
|
@ -229,13 +249,15 @@ class ToneSandhi():
|
|||
seg = new_seg
|
||||
return seg
|
||||
|
||||
def pre_merge_for_modify(self, seg):
|
||||
def pre_merge_for_modify(
|
||||
self, seg: List[Tuple[str, str]]) -> List[Tuple[str, str]]:
|
||||
seg = self._merge_bu(seg)
|
||||
seg = self._merge_yi(seg)
|
||||
seg = self._merge_continuous_three_tones(seg)
|
||||
return seg
|
||||
|
||||
def modified_tone(self, word, pos, finals):
|
||||
def modified_tone(self, word: str, pos: str,
|
||||
finals: List[str]) -> List[str]:
|
||||
finals = self._bu_sandhi(word, finals)
|
||||
finals = self._yi_sandhi(word, finals)
|
||||
finals = self._neural_sandhi(word, pos, finals)
|
||||
|
|
|
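With the annotations in place, a hedged end-to-end sketch of how the `Frontend` in this commit drives `ToneSandhi`: finals come from pypinyin in `FINALS_TONE3` style, then `modified_tone` applies the 不/一/neutral-tone rules (and, assuming the method finishes with `_three_sandhi`, which this hunk truncates, the third-tone rule that turns 3-3 into 2-3):

```python
from pypinyin import lazy_pinyin, Style
from parakeet.frontend.tone_sandhi import ToneSandhi

sandhi = ToneSandhi()
word, pos = "你好", "l"  # two third tones: ni3 hao3
finals = lazy_pinyin(
    word, neutral_tone_with_five=True, style=Style.FINALS_TONE3)
print(finals)                                   # ['i3', 'ao3']
print(sandhi.modified_tone(word, pos, finals))  # expected ['i2', 'ao3']
```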
@@ -247,21 +247,20 @@ class FastSpeech2(nn.Layer):
                 speech_lengths: paddle.Tensor,
                 durations: paddle.Tensor,
                 pitch: paddle.Tensor,
-                energy: paddle.Tensor, ) -> Tuple[paddle.Tensor, Dict[
-                    str, paddle.Tensor], paddle.Tensor]:
+                energy: paddle.Tensor, ) -> Sequence[paddle.Tensor]:
         """Calculate forward propagation.

         Parameters
         ----------
-        text : LongTensor
+        text : Tensor
             Batch of padded token ids (B, Tmax).
-        text_lengths : LongTensor)
+        text_lengths : Tensor)
             Batch of lengths of each input (B,).
         speech : Tensor
             Batch of padded target features (B, Lmax, odim).
-        speech_lengths : LongTensor
+        speech_lengths : Tensor
             Batch of the lengths of each target (B,).
-        durations : LongTensor
+        durations : Tensor
             Batch of padded durations (B, Tmax).
         pitch : Tensor
             Batch of padded token-averaged pitch (B, Tmax, 1).
@@ -281,8 +280,6 @@ class FastSpeech2(nn.Layer):
             energy predictor's output
         Tensor
             speech
-        Tensor
-            real text_lengths
         Tensor
             speech_lengths, modified if reduction_factor >1
         """
|
|||
pitch: paddle.Tensor=None,
|
||||
energy: paddle.Tensor=None,
|
||||
alpha: float=1.0,
|
||||
use_teacher_forcing: bool=False, ) -> Tuple[
|
||||
paddle.Tensor, paddle.Tensor, paddle.Tensor]:
|
||||
use_teacher_forcing: bool=False, ) -> paddle.Tensor:
|
||||
"""Generate the sequence of features given the sequences of characters.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
text : LongTensor
|
||||
text : Tensor
|
||||
Input sequence of characters (T,).
|
||||
speech : Tensor, optional
|
||||
Feature sequence to extract style (N, idim).
|
||||
durations : LongTensor, optional
|
||||
durations : Tensor, optional
|
||||
Groundtruth of duration (T,).
|
||||
pitch : Tensor, optional
|
||||
Groundtruth of token-averaged pitch (T, 1).
|
||||
|
@@ -452,7 +448,7 @@ class FastSpeech2(nn.Layer):

         Parameters
         ----------
-        ilens : LongTensor
+        ilens : Tensor
             Batch of lengths (B,).

         Returns
@@ -553,7 +549,7 @@ class FastSpeech2Loss(nn.Layer):
             Batch of outputs after postnets (B, Lmax, odim).
         before_outs : Tensor
             Batch of outputs before postnets (B, Lmax, odim).
-        d_outs : LongTensor
+        d_outs : Tensor
             Batch of outputs of duration predictor (B, Tmax).
         p_outs : Tensor
             Batch of outputs of pitch predictor (B, Tmax, 1).
@@ -561,15 +557,15 @@ class FastSpeech2Loss(nn.Layer):
             Batch of outputs of energy predictor (B, Tmax, 1).
         ys : Tensor
             Batch of target features (B, Lmax, odim).
-        ds : LongTensor
+        ds : Tensor
             Batch of durations (B, Tmax).
         ps : Tensor
             Batch of target token-averaged pitch (B, Tmax, 1).
         es : Tensor
             Batch of target token-averaged energy (B, Tmax, 1).
-        ilens : LongTensor
+        ilens : Tensor
             Batch of the lengths of each input (B,).
-        olens : LongTensor
+        olens : Tensor
             Batch of the lengths of each target (B,).

         Returns
@@ -44,7 +44,7 @@ def fold(x, n_group):
     Tensor : [shape=(\*, time_steps // n_group, group)]
         Folded tensor.
     """
-    spatial_shape = list(x.shape[:-1])
+    spatial_shape = list(x.shape[:-1])
     time_steps = paddle.shape(x)[-1]
     new_shape = spatial_shape + [time_steps // n_group, n_group]
     return paddle.reshape(x, new_shape)
@@ -549,9 +549,9 @@ class Flow(nn.Layer):
             z_row = z[:, :, i:i + 1, :]
             condition_row = condition[:, :, i:i + 1, :]
             x_next_row, (logs, b) = self._inverse_row(z_row, x_row,
-                                                      condition_row)
-            x[:, :, i:i+1, :] = x_next_row
-
+                                                      condition_row)
+            x[:, :, i:i + 1, :] = x_next_row
+
         return x

@@ -615,7 +615,7 @@ class WaveFlow(nn.LayerList):

     def _trim(self, x, condition):
         assert condition.shape[-1] >= x.shape[-1]
-        pruned_len = int(paddle.shape(x)[-1] // self.n_group * self.n_group)
+        pruned_len = int(paddle.shape(x)[-1] // self.n_group * self.n_group)

         if x.shape[-1] > pruned_len:
             x = x[:, :pruned_len]
@@ -167,16 +167,16 @@ class Timeline(object):
                 if (k, mevent.device_id, "GPU") not in self._mem_devices:
                     pid = self._allocate_pid()
                     self._mem_devices[(k, mevent.device_id, "GPU")] = pid
-                    self._chrome_trace.emit_pid(
-                        "memory usage on %s:gpu:%d" % (k, mevent.device_id),
-                        pid)
+                    self._chrome_trace.emit_pid("memory usage on %s:gpu:%d"
+                                                % (k, mevent.device_id),
+                                                pid)
             elif mevent.place == profiler_pb2.MemEvent.CPUPlace:
                 if (k, mevent.device_id, "CPU") not in self._mem_devices:
                     pid = self._allocate_pid()
                     self._mem_devices[(k, mevent.device_id, "CPU")] = pid
-                    self._chrome_trace.emit_pid(
-                        "memory usage on %s:cpu:%d" % (k, mevent.device_id),
-                        pid)
+                    self._chrome_trace.emit_pid("memory usage on %s:cpu:%d"
+                                                % (k, mevent.device_id),
+                                                pid)
             elif mevent.place == profiler_pb2.MemEvent.CUDAPinnedPlace:
                 if (k, mevent.device_id, "CUDAPinnedPlace"
                     ) not in self._mem_devices:
@@ -190,9 +190,9 @@ class Timeline(object):
                 if (k, mevent.device_id, "NPU") not in self._mem_devices:
                     pid = self._allocate_pid()
                     self._mem_devices[(k, mevent.device_id, "NPU")] = pid
-                    self._chrome_trace.emit_pid(
-                        "memory usage on %s:npu:%d" % (k, mevent.device_id),
-                        pid)
+                    self._chrome_trace.emit_pid("memory usage on %s:npu:%d"
+                                                % (k, mevent.device_id),
+                                                pid)
             if (k, 0, "CPU") not in self._mem_devices:
                 pid = self._allocate_pid()
                 self._mem_devices[(k, 0, "CPU")] = pid
@@ -273,14 +273,14 @@ class Timeline(object):
             total_size = 0
             while i < len(mem_list):
                 total_size += mem_list[i]['size']
-                while i < len(mem_list) - 1 and mem_list[i]['time'] == mem_list[
-                        i + 1]['time']:
+                while i < len(mem_list) - 1 and mem_list[i][
+                        'time'] == mem_list[i + 1]['time']:
                     total_size += mem_list[i + 1]['size']
                     i += 1

                 self._chrome_trace.emit_counter(
-                    "Memory", "Memory", mem_list[i]['pid'], mem_list[i]['time'],
-                    0, total_size)
+                    "Memory", "Memory", mem_list[i]['pid'],
+                    mem_list[i]['time'], 0, total_size)
                 i += 1

     def generate_chrome_trace(self):