modularize Chinese frontend

parent 3d39385d5e
commit 6aeb56301f

@@ -1,14 +1,15 @@
# FastSpeech2 with BZNSYP

## Dataset

### Download and extract the dataset.
Download BZNSYP from its [Official Website](https://test.data-baker.com/data/index/source).

### Get the MFA result of BZNSYP and extract it.
-we use [MFA](https://github.com/MontrealCorpusTools/Montreal-Forced-Aligner) to get durations for fastspeech2.
+We use [MFA](https://github.com/MontrealCorpusTools/Montreal-Forced-Aligner) to get durations for FastSpeech2.
-you can download from here, or train your own MFA model reference to [use_mfa example](https://github.com/PaddlePaddle/Parakeet/tree/develop/examples/use_mfa) of our repo.
+You can download it from here: [baker_alignmenti_tone.tar.gz](https://paddlespeech.bj.bcebos.com/MFA/BZNSYP/with_tone/baker_alignmenti_tone.tar.gz), or train your own MFA model by following the [use_mfa example](https://github.com/PaddlePaddle/Parakeet/tree/develop/examples/use_mfa) in our repo.

### Preprocess the dataset.
@@ -20,15 +21,18 @@ Run the command below to preprocess the dataset.
./preprocess.sh
```

## Train the model
```bash
./run.sh
```

## Synthesize
-we use [parallel wavegan](https://github.com/PaddlePaddle/Parakeet/tree/develop/examples/parallelwave_gan/baker) as the neural vocoder.
+We use [parallel wavegan](https://github.com/PaddlePaddle/Parakeet/tree/develop/examples/parallelwave_gan/baker) as the neural vocoder.
+Download the pretrained parallel wavegan model from [parallel_wavegan_baker_ckpt_1.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/parallel_wavegan_baker_ckpt_1.0.zip) and unzip it.
+```bash
+unzip parallel_wavegan_baker_ckpt_1.0.zip
+```
`synthesize.sh` can synthesize waveforms for `metadata.jsonl`.
`synthesize_e2e.sh` can synthesize waveforms for a text list.

```bash
./synthesize.sh
```
@@ -37,6 +41,22 @@ or
./synthesize_e2e.sh
```

-you can see the bash files for more datails of input parameter.
+See the bash files for more details of the input parameters.

## Pretrained Model
+A pretrained model trained with no silence at the edges of the audios can be downloaded from [fastspeech2_nosil_baker_ckpt_1.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/fastspeech2_nosil_baker_ckpt_1.0.zip).
+
+Then you can use the following script to synthesize `sentences.txt` with the pretrained FastSpeech2 model.
+```bash
+python3 synthesize_e2e.py \
+  --fastspeech2-config=fastspeech2_nosil_baker_ckpt_1.0/default.yaml \
+  --fastspeech2-checkpoint=fastspeech2_nosil_baker_ckpt_1.0/snapshot_iter_76000.pdz \
+  --fastspeech2-stat=fastspeech2_nosil_baker_ckpt_1.0/speech_stats.npy \
+  --pwg-config=parallel_wavegan_baker_ckpt_1.0/pwg_default.yaml \
+  --pwg-params=parallel_wavegan_baker_ckpt_1.0/pwg_generator.pdparams \
+  --pwg-stat=parallel_wavegan_baker_ckpt_1.0/pwg_stats.npy \
+  --text=sentences.txt \
+  --output-dir=exp/debug/test_e2e \
+  --device="gpu" \
+  --phones=fastspeech2_nosil_baker_ckpt_1.0/phone_id_map.txt
+```
@@ -15,53 +15,26 @@
import re
import numpy as np
import paddle
-from pypinyin import lazy_pinyin, Style
-import jieba
+from parakeet.frontend.cn_frontend import Frontend as cnFrontend


class Frontend():
-    def __init__(self, vocab_path):
+    def __init__(self, phone_vocab_path=None, tone_vocab_path=None):
+        self.frontend = cnFrontend()
        self.voc_phones = {}
-        with open(vocab_path, 'rt') as f:
-            phn_id = [line.strip().split() for line in f.readlines()]
-        for phn, id in phn_id:
-            self.voc_phones[phn] = int(id)
+        self.voc_tones = {}
+        if phone_vocab_path:
+            with open(phone_vocab_path, 'rt') as f:
+                phn_id = [line.strip().split() for line in f.readlines()]
+            for phn, id in phn_id:
+                self.voc_phones[phn] = int(id)
+        if tone_vocab_path:
+            with open(tone_vocab_path, 'rt') as f:
+                tone_id = [line.strip().split() for line in f.readlines()]
+            for tone, id in tone_id:
+                self.voc_tones[tone] = int(id)

-    def segment(self, sentence):
-        segments = re.split(r'[:,;。?!]', sentence)
-        segments = [seg for seg in segments if len(seg)]
-        return segments
-
-    def g2p(self, sentence):
-        segments = self.segment(sentence)
-        phones = []
-
-        for seg in segments:
-            seg = jieba.lcut(seg)
-            initials = lazy_pinyin(
-                seg, neutral_tone_with_five=True, style=Style.INITIALS)
-            finals = lazy_pinyin(
-                seg, neutral_tone_with_five=True, style=Style.FINALS_TONE3)
-            for c, v in zip(initials, finals):
-                # NOTE: post process for pypinyin outputs
-                # we discriminate i, ii and iii
-                if re.match(r'i\d', v):
-                    if c in ['z', 'c', 's']:
-                        v = re.sub('i', 'ii', v)
-                    elif c in ['zh', 'ch', 'sh', 'r']:
-                        v = re.sub('i', 'iii', v)
-                if c:
-                    phones.append(c)
-                if v:
-                    phones.append(v)
-            # add sp between sentences
-            phones.append('sp')
-        # replace the last sp with <eos>
-        phones[-1] = '<eos>'
-        return phones
-
-    def p2id(self, phonemes):
+    def _p2id(self, phonemes):
        # replace unk phones with sp
        phonemes = [
            phn if phn in self.voc_phones else "sp" for phn in phonemes
@@ -69,8 +42,35 @@ class Frontend():
        phone_ids = [self.voc_phones[item] for item in phonemes]
        return np.array(phone_ids, np.int64)

-    def text_analysis(self, sentence):
-        phonemes = self.g2p(sentence)
-        phone_ids = self.p2id(phonemes)
+    def _t2id(self, tones):
+        # replace unknown tones with '0'
+        tones = [tone if tone in self.voc_tones else "0" for tone in tones]
+        tone_ids = [self.voc_tones[item] for item in tones]
+        return np.array(tone_ids, np.int64)
+
+    def get_input_ids(self, sentence, get_tone_ids=False):
+        phonemes = self.frontend.get_phonemes(sentence)
+        result = {}
+        phones = []
+        tones = []
+        if get_tone_ids and self.voc_tones:
+            for full_phone in phonemes:
+                # split the tone off the final
+                match = re.match(r'^(\w+)([012345])$', full_phone)
+                if match:
+                    phones.append(match.group(1))
+                    tones.append(match.group(2))
+                else:
+                    phones.append(full_phone)
+                    tones.append('0')
+            tone_ids = self._t2id(tones)
+            tone_ids = paddle.to_tensor(tone_ids)
+            result["tone_ids"] = tone_ids
+        else:
+            phones = phonemes
+        phone_ids = self._p2id(phones)
        phone_ids = paddle.to_tensor(phone_ids)
-        return phone_ids
+        result["phone_ids"] = phone_ids
+        return result
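For context, a minimal usage sketch of the reworked `Frontend` (this assumes the example's `frontend.py` is on the import path and uses the `phone_id_map.txt` shipped with the pretrained checkpoint; the sample sentence is arbitrary):

```python
from frontend import Frontend

# Build the frontend from a phone vocabulary; tone_vocab_path is optional.
frontend = Frontend(
    phone_vocab_path="fastspeech2_nosil_baker_ckpt_1.0/phone_id_map.txt")
# get_input_ids returns a dict; with get_tone_ids=True (and a tone vocab)
# it would also contain "tone_ids".
result = frontend.get_input_ids("今天天气很好。")
phone_ids = result["phone_ids"]  # paddle int64 tensor, ready for FastSpeech2
```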
@@ -35,9 +35,17 @@ def readtg(config, tg_path):
        ends, sr=config.fs, hop_length=config.n_shift)
    durations = np.diff(frame_pos, prepend=0)
    assert len(durations) == len(phones)
+    # merge the trailing "" and sp
+    if phones[-1] == "":
+        phones = phones[:-1]
+        durations[-2] += durations[-1]
+        durations = durations[:-1]
+    # replace the last sp with sil
+    phones[-1] = "sil" if phones[-1] == "sp" else phones[-1]
+
    results = ""
    for (p, d) in zip(phones, durations):
-        p = "sil" if p == "" else p
        results += p + " " + str(d) + " "
    return results.strip()
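A hand-traced toy example of the new tail-merging logic (the phones and durations are made up, not taken from a real TextGrid):

```python
import numpy as np

phones = ["n", "i3", "sp", ""]       # MFA sometimes emits a trailing ""
durations = np.array([10, 12, 3, 2])
# merge the trailing "" and sp, as readtg now does
if phones[-1] == "":
    phones = phones[:-1]
    durations[-2] += durations[-1]   # fold the "" frames into the sp
    durations = durations[:-1]
# then the last sp becomes sil
phones[-1] = "sil" if phones[-1] == "sp" else phones[-1]
print(phones, durations)             # ['n', 'i3', 'sil'] [10 12  5]
```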
@@ -75,28 +75,15 @@ def deal_silence(sentence):
                new_phn.append(p)
                new_dur.append(cur_dur[i])

-        # merge little sil in the begin
-        if new_phn[0] == 'sil' and new_dur[0] <= 14:
-            new_phn = new_phn[1:]
-            new_dur[1] += new_dur[0]
-            new_dur = new_dur[1:]
-
-        # replace the last sil with <eos> if exist
-        if new_phn[-1] == 'sil':
-            new_phn[-1] = '<eos>'
-        else:
-            new_phn.append('<eos>')
-            new_dur.append(0)
-
        for i, (p, d) in enumerate(zip(new_phn, new_dur)):
-            if p in {"sil", "sp"}:
+            if p in {"sp"}:
                if d < 14:
                    new_phn[i] = 'sp'
                else:
-                    new_phn[i] = 'sp1'
+                    new_phn[i] = 'spl'

        assert len(new_phn) == len(new_dur)
-        sentence[utt] = (new_phn, new_dur)
+        sentence[utt] = [new_phn, new_dur]


def get_input_token(sentence, output_path):
@@ -148,7 +135,6 @@ def compare_duration_and_mel_length(sentences, utt, mel):
    elif sentences[utt][1][0] + len_diff > 0:
        sentences[utt][1][0] += len_diff
    else:
-        # this is generally not triggered
        print("the len_diff is unable to correct:", len_diff)
        sentences.pop(utt)
@@ -160,7 +146,8 @@ def process_sentence(
        output_dir: Path,
        mel_extractor=None,
        pitch_extractor=None,
-        energy_extractor=None, ):
+        energy_extractor=None,
+        cut_sil: bool = True):
    utt_id = fp.stem
    record = None
    if utt_id in sentences:
@@ -169,27 +156,47 @@ def process_sentence(
        assert len(wav.shape) == 1, f"{utt_id} is not a mono-channel audio."
        assert np.abs(wav).max(
        ) <= 1.0, f"{utt_id} seems to be different from 16 bit PCM."
+        phones = sentences[utt_id][0]
+        durations = sentences[utt_id][1]
+        d_cumsum = np.pad(np.array(durations).cumsum(0), (1, 0), 'constant')
+        # slightly less precise than using the *.TextGrid intervals directly
+        times = librosa.frames_to_time(
+            d_cumsum, sr=config.fs, hop_length=config.n_shift)
+        if cut_sil:
+            start = 0
+            end = times[-1]
+            if phones[0] == "sil" and len(durations) > 1:
+                start = times[1]
+                durations = durations[1:]
+                phones = phones[1:]
+            if phones[-1] == 'sil' and len(durations) > 1:
+                end = times[-2]
+                durations = durations[:-1]
+                phones = phones[:-1]
+            sentences[utt_id][0] = phones
+            sentences[utt_id][1] = durations
+            start, end = librosa.time_to_samples([start, end], sr=config.fs)
+            wav = wav[start:end]
        # extract mel feats
        logmel = mel_extractor.get_log_mel_fbank(wav)
        # change durations according to the mel length
        compare_duration_and_mel_length(sentences, utt_id, logmel)
        phones = sentences[utt_id][0]
-        duration = sentences[utt_id][1]
+        durations = sentences[utt_id][1]
        num_frames = logmel.shape[0]
-        assert sum(duration) == num_frames
+        assert sum(durations) == num_frames
        mel_dir = output_dir / "data_speech"
        mel_dir.mkdir(parents=True, exist_ok=True)
        mel_path = mel_dir / (utt_id + "_speech.npy")
        np.save(mel_path, logmel)
        # extract pitch and energy
-        f0 = pitch_extractor.get_pitch(wav, duration=np.array(duration))
+        f0 = pitch_extractor.get_pitch(wav, duration=np.array(durations))
-        assert f0.shape[0] == len(duration)
+        assert f0.shape[0] == len(durations)
        f0_dir = output_dir / "data_pitch"
        f0_dir.mkdir(parents=True, exist_ok=True)
        f0_path = f0_dir / (utt_id + "_pitch.npy")
        np.save(f0_path, f0)
-        energy = energy_extractor.get_energy(wav, duration=np.array(duration))
+        energy = energy_extractor.get_energy(wav, duration=np.array(durations))
-        assert energy.shape[0] == len(duration)
+        assert energy.shape[0] == len(durations)
        energy_dir = output_dir / "data_energy"
        energy_dir.mkdir(parents=True, exist_ok=True)
        energy_path = energy_dir / (utt_id + "_energy.npy")
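A quick sanity check of the frame/time bookkeeping behind `cut_sil` (fs=24000 and n_shift=300 are assumed to be the example's default config values; the durations are made up):

```python
import librosa
import numpy as np

durations = [10, 120, 8]             # made-up frame counts: sil, speech, sil
d_cumsum = np.pad(np.array(durations).cumsum(0), (1, 0), 'constant')
times = librosa.frames_to_time(d_cumsum, sr=24000, hop_length=300)
# frame k starts at k * 300 / 24000 s, so a leading 10-frame sil = 0.125 s
print(times)                         # [0.    0.125 1.625 1.725]
```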
@@ -199,7 +206,7 @@ def process_sentence(
            "phones": phones,
            "text_lengths": len(phones),
            "speech_lengths": num_frames,
-            "durations": duration,
+            "durations": durations,
            # use absolute paths
            "speech": str(mel_path.resolve()),
            "pitch": str(f0_path.resolve()),
@@ -215,13 +222,14 @@ def process_sentences(config,
                      mel_extractor=None,
                      pitch_extractor=None,
                      energy_extractor=None,
-                      nprocs: int=1):
+                      nprocs: int = 1,
+                      cut_sil: bool = True):
    if nprocs == 1:
        results = []
        for fp in tqdm.tqdm(fps, total=len(fps)):
            record = process_sentence(config, fp, sentences, output_dir,
                                      mel_extractor, pitch_extractor,
-                                      energy_extractor)
+                                      energy_extractor, cut_sil)
            if record:
                results.append(record)
    else:
@@ -231,7 +239,7 @@ def process_sentences(config,
            for fp in fps:
                future = pool.submit(process_sentence, config, fp,
                                     sentences, output_dir, mel_extractor,
-                                     pitch_extractor, energy_extractor)
+                                     pitch_extractor, energy_extractor, cut_sil)
                future.add_done_callback(lambda p: progress.update())
                futures.append(future)
@@ -276,6 +284,10 @@ def main():
        help="logging level. higher is more logging. (default=1)")
    parser.add_argument(
        "--num-cpu", type=int, default=1, help="number of processes.")
+    def str2bool(s):
+        return s.lower() == 'true'
+    parser.add_argument(
+        "--cut-sil", type=str2bool, default=True, help="whether to cut silence at the edges of the audio")
    args = parser.parse_args()

    C = get_cfg_default()
@@ -286,7 +298,7 @@ def main():
    if args.verbose > 1:
        print(vars(args))
        print(C)

    root_dir = Path(args.rootdir).expanduser()
    dumpdir = Path(args.dumpdir).expanduser()
    dumpdir.mkdir(parents=True, exist_ok=True)
@@ -318,6 +330,7 @@ def main():
    energy_extractor = Energy(C)

    # process for the 3 sections
    process_sentences(
        C,
        train_wav_files,
@@ -326,7 +339,8 @@ def main():
        mel_extractor,
        pitch_extractor,
        energy_extractor,
-        nprocs=args.num_cpu)
+        nprocs=args.num_cpu,
+        cut_sil=args.cut_sil)
    process_sentences(
        C,
        dev_wav_files,
@@ -335,7 +349,8 @@ def main():
        mel_extractor,
        pitch_extractor,
        energy_extractor,
-        nprocs=args.num_cpu)
+        nprocs=args.num_cpu,
+        cut_sil=args.cut_sil)

    process_sentences(
        C,
        test_wav_files,
@@ -344,7 +359,8 @@ def main():
        mel_extractor,
        pitch_extractor,
        energy_extractor,
-        nprocs=args.num_cpu)
+        nprocs=args.num_cpu,
+        cut_sil=args.cut_sil)


if __name__ == "__main__":
@@ -4,7 +4,7 @@
python3 gen_duration_from_textgrid.py --inputdir ./baker_alignment_tone --output durations.txt

# extract features
-python3 preprocess.py --rootdir=~/datasets/BZNSYP/ --dumpdir=dump --dur-path durations.txt --num-cpu 4
+python3 preprocess.py --rootdir=~/datasets/BZNSYP/ --dumpdir=dump --dur-path durations.txt --num-cpu 4 --cut-sil True

# get features' stats (mean and std)
python3 compute_statistics.py --metadata=dump/train/raw/metadata.jsonl --field-name="speech"
@@ -5,9 +5,9 @@ python3 synthesize.py \
  --fastspeech2-config=conf/default.yaml \
  --fastspeech2-checkpoint=exp/default/checkpoints/snapshot_iter_62577.pdz \
  --fastspeech2-stat=dump/train/speech_stats.npy \
-  --pwg-config=pwg_default.yaml \
-  --pwg-params=pwg_generator.pdparams \
-  --pwg-stat=pwg_stats.npy \
+  --pwg-config=parallel_wavegan_baker_ckpt_1.0/pwg_default.yaml \
+  --pwg-params=parallel_wavegan_baker_ckpt_1.0/pwg_generator.pdparams \
+  --pwg-stat=parallel_wavegan_baker_ckpt_1.0/pwg_stats.npy \
  --test-metadata=dump/test/norm/metadata.jsonl \
  --output-dir=exp/debug/test \
  --device="gpu" \
@@ -80,9 +80,11 @@ def evaluate(args, fastspeech2_config, pwg_config):
    output_dir.mkdir(parents=True, exist_ok=True)

    for utt_id, sentence in sentences:
-        phone_ids = frontend.text_analysis(sentence)
+        input_ids = frontend.get_input_ids(sentence)
+        phone_ids = input_ids["phone_ids"]
        with paddle.no_grad():
-            wav = pwg_inference(fastspeech2_inferencce(phone_ids))
+            mel = fastspeech2_inferencce(phone_ids)
+            wav = pwg_inference(mel)
        sf.write(
            str(output_dir / (utt_id + ".wav")),
            wav.numpy(),
@@ -5,9 +5,9 @@ python3 synthesize_e2e.py \
  --fastspeech2-config=conf/default.yaml \
  --fastspeech2-checkpoint=exp/default/checkpoints/snapshot_iter_136017.pdz \
  --fastspeech2-stat=dump/train/speech_stats.npy \
-  --pwg-config=pwg_default.yaml \
-  --pwg-params=pwg_generator.pdparams \
-  --pwg-stat=pwg_stats.npy \
+  --pwg-config=parallel_wavegan_baker_ckpt_1.0/pwg_default.yaml \
+  --pwg-params=parallel_wavegan_baker_ckpt_1.0/pwg_generator.pdparams \
+  --pwg-stat=parallel_wavegan_baker_ckpt_1.0/pwg_stats.npy \
  --text=sentences.txt \
  --output-dir=exp/debug/test_e2e \
  --device="gpu" \
@@ -16,3 +16,7 @@ from parakeet.frontend.vocab import *
from parakeet.frontend.phonectic import *
from parakeet.frontend.punctuation import *
from parakeet.frontend.normalizer import *
+from parakeet.frontend.cn_normalization import *
+from parakeet.frontend.modified_tone import *
+from parakeet.frontend.generate_lexicon import *
@@ -0,0 +1,110 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import re

import jieba.posseg as psg
import numpy as np
import paddle
from g2pM import G2pM
from pypinyin import lazy_pinyin, Style

from parakeet.frontend.cn_normalization.normalization import Normalizer
from parakeet.frontend.generate_lexicon import generate_lexicon
from parakeet.frontend.modified_tone import ModifiedTone


class Frontend():
    def __init__(self, g2p_model="pypinyin"):
        self.tone_modifier = ModifiedTone()
        self.normalizer = Normalizer()
        self.punc = ":,;。?!“”‘’':,;.?!"
        # g2p_model can be "pypinyin" or "g2pM"
        self.g2p_model = g2p_model
        if self.g2p_model == "g2pM":
            self.g2pM_model = G2pM()
            self.pinyin2phone = generate_lexicon(
                with_tone=True, with_erhua=False)

    def _get_initials_finals(self, word):
        initials = []
        finals = []
        if self.g2p_model == "pypinyin":
            orig_initials = lazy_pinyin(
                word, neutral_tone_with_five=True, style=Style.INITIALS)
            orig_finals = lazy_pinyin(
                word, neutral_tone_with_five=True, style=Style.FINALS_TONE3)
            for c, v in zip(orig_initials, orig_finals):
                # NOTE: post process for pypinyin outputs
                # we discriminate i, ii and iii
                if re.match(r'i\d', v):
                    if c in ['z', 'c', 's']:
                        v = re.sub('i', 'ii', v)
                    elif c in ['zh', 'ch', 'sh', 'r']:
                        v = re.sub('i', 'iii', v)
                initials.append(c)
                finals.append(v)
        elif self.g2p_model == "g2pM":
            pinyins = self.g2pM_model(word, tone=True, char_split=False)
            for pinyin in pinyins:
                pinyin = pinyin.replace("u:", "v")
                if pinyin in self.pinyin2phone:
                    initial_final_list = self.pinyin2phone[pinyin].split(" ")
                    if len(initial_final_list) == 2:
                        initials.append(initial_final_list[0])
                        finals.append(initial_final_list[1])
                    elif len(initial_final_list) == 1:
                        initials.append('')
                        finals.append(initial_final_list[0])
                else:
                    # not found in the lexicon (e.g. punctuation): keep as is
                    initials.append(pinyin)
                    finals.append(pinyin)
        return initials, finals

    # if merge_sentences is True, merge all sentences into one phone sequence
    def _g2p(self, sentences, merge_sentences=True):
        segments = sentences
        phones_list = []
        for seg in segments:
            phones = []
            seg = psg.lcut(seg)
            initials = []
            finals = []
            seg = self.tone_modifier.pre_merge_for_modify(seg)
            for word, pos in seg:
                if pos == 'eng':
                    continue
                sub_initials, sub_finals = self._get_initials_finals(word)
                sub_finals = self.tone_modifier.modified_tone(word, pos,
                                                              sub_finals)
                initials.append(sub_initials)
                finals.append(sub_finals)
                # assert len(sub_initials) == len(sub_finals) == len(word)
            initials = sum(initials, [])
            finals = sum(finals, [])
            for c, v in zip(initials, finals):
                # NOTE: post process for pypinyin outputs
                # we discriminate i, ii and iii
                if c and c not in self.punc:
                    phones.append(c)
                if v and v not in self.punc:
                    phones.append(v)
            # add sp at sentence boundaries
            if initials[-1] in self.punc:
                phones.append('sp')
            phones_list.append(phones)
        if merge_sentences:
            phones_list = sum(phones_list, [])
        return phones_list

    def get_phonemes(self, sentence):
        sentences = self.normalizer.normalize(sentence)
        phonemes = self._g2p(sentences)
        return phonemes
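A minimal sketch of how the new Chinese frontend is driven (this assumes jieba, pypinyin, g2pM and opencc are installed; the phoneme list shown is illustrative, not a recorded run):

```python
from parakeet.frontend.cn_frontend import Frontend as cnFrontend

frontend = cnFrontend(g2p_model="pypinyin")
# normalize() rewrites the date first, then _g2p emits tone-annotated
# initials/finals, roughly: ['j', 'in1', 't', 'ian1', 'sh', 'iii4', ...]
phonemes = frontend.get_phonemes("今天是2021年8月18日。")
```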
@@ -0,0 +1,15 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from parakeet.frontend.cn_normalization.normalization import *
@@ -0,0 +1,79 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import re

from .num import verbalize_cardinal, verbalize_digit, num2str, DIGITS


def _time_num2str(num_string: str) -> str:
    """A special case for verbalizing numbers in a time expression."""
    result = num2str(num_string.lstrip('0'))
    if num_string.startswith('0'):
        result = DIGITS['0'] + result
    return result


# time expressions
RE_TIME = re.compile(r'([0-1]?[0-9]|2[0-3])'
                     r':([0-5][0-9])'
                     r'(:([0-5][0-9]))?')


def replace_time(match: re.Match) -> str:
    hour = match.group(1)
    minute = match.group(2)
    second = match.group(4)

    result = f"{num2str(hour)}点"
    if minute.lstrip('0'):
        result += f"{_time_num2str(minute)}分"
    if second and second.lstrip('0'):
        result += f"{_time_num2str(second)}秒"
    return result


# date expressions
RE_DATE = re.compile(r'(\d{4}|\d{2})年'
                     r'((0?[1-9]|1[0-2])月)?'
                     r'(((0?[1-9])|((1|2)[0-9])|30|31)([日号]))?')


def replace_date(match: re.Match) -> str:
    year = match.group(1)
    month = match.group(3)
    day = match.group(5)
    result = ""
    if year:
        result += f"{verbalize_digit(year)}年"
    if month:
        result += f"{verbalize_cardinal(month)}月"
    if day:
        result += f"{verbalize_cardinal(day)}{match.group(9)}"
    return result


# YY/MM/DD or YY-MM-DD dates separated by / or -
RE_DATE2 = re.compile(
    r'(\d{4})([- /.])(0[1-9]|1[012])\2(0[1-9]|[12][0-9]|3[01])')


def replace_date2(match: re.Match) -> str:
    year = match.group(1)
    month = match.group(3)
    day = match.group(4)
    result = ""
    if year:
        result += f"{verbalize_digit(year)}年"
    if month:
        result += f"{verbalize_cardinal(month)}月"
    if day:
        result += f"{verbalize_cardinal(day)}日"
    return result
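Hand-derived examples of the two rule sets above (not recorded output):

```python
print(RE_DATE.sub(replace_date, "2021年8月18日"))  # -> 二零二一年八月十八日
print(RE_TIME.sub(replace_time, "08:30"))          # -> 八点三十分
```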
@@ -0,0 +1,72 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import re
import string

from pypinyin.constants import SUPPORT_UCS4


# full-width / half-width conversion
# NOTE: str.translate expects ordinals as keys, so the tables map int -> int.

# full-width -> half-width mapping for ASCII letters (52 entries)
F2H_ASCII_LETTERS = {
    ord(char) + 65248: ord(char)
    for char in string.ascii_letters
}
# half-width -> full-width mapping for ASCII letters
H2F_ASCII_LETTERS = {value: key for key, value in F2H_ASCII_LETTERS.items()}

# full-width -> half-width mapping for digits (10 entries)
F2H_DIGITS = {ord(char) + 65248: ord(char) for char in string.digits}
# half-width -> full-width mapping for digits
H2F_DIGITS = {value: key for key, value in F2H_DIGITS.items()}

# full-width -> half-width mapping for punctuation (32 entries)
F2H_PUNCTUATIONS = {
    ord(char) + 65248: ord(char)
    for char in string.punctuation
}
# half-width -> full-width mapping for punctuation
H2F_PUNCTUATIONS = {value: key for key, value in F2H_PUNCTUATIONS.items()}

# ideographic space (1 entry)
F2H_SPACE = {ord('\u3000'): ord(' ')}
H2F_SPACE = {ord(' '): ord('\u3000')}

# strings of characters that are not "Chinese characters with pinyin";
# can be used to extract non-standard words (NSW)
if SUPPORT_UCS4:
    RE_NSW = re.compile(r'(?:[^'
                        r'\u3007'                 # 〇
                        r'\u3400-\u4dbf'          # CJK Extension A: [3400-4DBF]
                        r'\u4e00-\u9fff'          # CJK Unified: [4E00-9FFF]
                        r'\uf900-\ufaff'          # CJK Compatibility: [F900-FAFF]
                        r'\U00020000-\U0002A6DF'  # CJK Extension B: [20000-2A6DF]
                        r'\U0002A703-\U0002B73F'  # CJK Extension C: [2A700-2B73F]
                        r'\U0002B740-\U0002B81D'  # CJK Extension D: [2B740-2B81D]
                        r'\U0002F80A-\U0002FA1F'  # CJK Compatibility Supplement: [2F800-2FA1F]
                        r'])+')
else:
    RE_NSW = re.compile(  # pragma: no cover
        r'(?:[^'
        r'\u3007'         # 〇
        r'\u3400-\u4dbf'  # CJK Extension A: [3400-4DBF]
        r'\u4e00-\u9fff'  # CJK Unified: [4E00-9FFF]
        r'\uf900-\ufaff'  # CJK Compatibility: [F900-FAFF]
        r'])+')
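With ordinal keys the tables plug straight into `str.translate`; a hand-derived check:

```python
text = "ABC123"  # full-width letters and digits
half = text.translate(F2H_ASCII_LETTERS).translate(F2H_DIGITS)
print(half)      # -> ABC123
```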
@@ -0,0 +1,81 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import re
from typing import List

import opencc

from .chronology import RE_TIME, RE_DATE, RE_DATE2
from .chronology import replace_time, replace_date, replace_date2
from .constants import F2H_ASCII_LETTERS, F2H_DIGITS, F2H_SPACE
from .num import RE_NUMBER, RE_FRAC, RE_PERCENTAGE, RE_RANGE, RE_INTEGER, RE_DEFAULT_NUM
from .num import replace_number, replace_frac, replace_percentage, replace_range, replace_default_num
from .phone import RE_MOBILE_PHONE, RE_TELEPHONE, replace_phone
from .quantifier import RE_TEMPERATURE
from .quantifier import replace_temperature


class Normalizer():
    def __init__(self):
        self.SENTENCE_SPLITTER = re.compile(r'([:,;。?!,;?!][”’]?)')
        self._t2s_converter = opencc.OpenCC("t2s.json")
        self._s2t_converter = opencc.OpenCC('s2t.json')

    def _split(self, text: str) -> List[str]:
        """Split long text into sentences at sentence-splitting punctuations.

        Parameters
        ----------
        text : str
            The input text.

        Returns
        -------
        List[str]
            Sentences.
        """
        text = self.SENTENCE_SPLITTER.sub(r'\1\n', text)
        text = text.strip()
        sentences = [sentence.strip() for sentence in re.split(r'\n+', text)]
        return sentences

    def _traditional_to_simplified(self, text: str) -> str:
        return self._t2s_converter.convert(text)

    def _simplified_to_traditional(self, text: str) -> str:
        return self._s2t_converter.convert(text)

    def normalize_sentence(self, sentence):
        # basic character conversions
        sentence = self._traditional_to_simplified(sentence)
        sentence = sentence.translate(F2H_ASCII_LETTERS).translate(
            F2H_DIGITS).translate(F2H_SPACE)

        # number-related NSW verbalization
        sentence = RE_DATE.sub(replace_date, sentence)
        sentence = RE_DATE2.sub(replace_date2, sentence)
        sentence = RE_TIME.sub(replace_time, sentence)
        sentence = RE_TEMPERATURE.sub(replace_temperature, sentence)
        sentence = RE_RANGE.sub(replace_range, sentence)
        sentence = RE_FRAC.sub(replace_frac, sentence)
        sentence = RE_PERCENTAGE.sub(replace_percentage, sentence)
        sentence = RE_MOBILE_PHONE.sub(replace_phone, sentence)
        sentence = RE_TELEPHONE.sub(replace_phone, sentence)
        sentence = RE_DEFAULT_NUM.sub(replace_default_num, sentence)
        sentence = RE_NUMBER.sub(replace_number, sentence)

        return sentence

    def normalize(self, text):
        sentences = self._split(text)
        sentences = [self.normalize_sentence(sent) for sent in sentences]
        return sentences
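An end-to-end sketch of the `Normalizer` (assumes opencc is installed with its bundled t2s.json/s2t.json configs; the output is hand-derived from the rules above, not a recorded run):

```python
normalizer = Normalizer()
print(normalizer.normalize("今天是2021年8月18日,气温-3°C。"))
# -> ['今天是二零二一年八月十八日,', '气温零下三度。']
```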
@@ -0,0 +1,158 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Rules to verbalize numbers into Chinese characters.
https://zh.wikipedia.org/wiki/中文数字#現代中文
"""

import re
from collections import OrderedDict
from typing import List

DIGITS = {str(i): tran for i, tran in enumerate('零一二三四五六七八九')}
UNITS = OrderedDict({
    1: '十',
    2: '百',
    3: '千',
    4: '万',
    8: '亿',
})

# fraction expressions
RE_FRAC = re.compile(r'(-?)(\d+)/(\d+)')


def replace_frac(match: re.Match) -> str:
    sign = match.group(1)
    nominator = match.group(2)
    denominator = match.group(3)
    sign: str = "负" if sign else ""
    nominator: str = num2str(nominator)
    denominator: str = num2str(denominator)
    result = f"{sign}{denominator}分之{nominator}"
    return result


# percentage expressions
RE_PERCENTAGE = re.compile(r'(-?)(\d+(\.\d+)?)%')


def replace_percentage(match: re.Match) -> str:
    sign = match.group(1)
    percent = match.group(2)
    sign: str = "负" if sign else ""
    percent: str = num2str(percent)
    result = f"{sign}百分之{percent}"
    return result


# integer expressions
# integers with or without a negative sign, e.g. 12, -10
RE_INTEGER = re.compile(r'(-?)'
                        r'(\d+)')

# serial numbers (unsigned integers), e.g. 00078
RE_DEFAULT_NUM = re.compile(r'\d{4}\d*')


def replace_default_num(match: re.Match):
    number = match.group(0)
    return verbalize_digit(number)


# number expressions
# 1. integers: -10, 10
# 2. floats: 10.2, -0.3
# 3. pure decimals without sign and integer part: .22, .38
RE_NUMBER = re.compile(r'(-?)((\d+)(\.\d+)?)'
                       r'|(\.(\d+))')


def replace_number(match: re.Match) -> str:
    sign = match.group(1)
    number = match.group(2)
    pure_decimal = match.group(5)
    if pure_decimal:
        result = num2str(pure_decimal)
    else:
        sign: str = "负" if sign else ""
        number: str = num2str(number)
        result = f"{sign}{number}"
    return result


# range expressions, e.g. 12-23, 12~23
RE_RANGE = re.compile(r'(\d+)[-~](\d+)')


def replace_range(match: re.Match) -> str:
    first, second = match.group(1), match.group(2)
    first: str = num2str(first)
    second: str = num2str(second)
    result = f"{first}到{second}"
    return result


def _get_value(value_string: str, use_zero: bool=True) -> List[str]:
    stripped = value_string.lstrip('0')
    if len(stripped) == 0:
        return []
    elif len(stripped) == 1:
        if use_zero and len(stripped) < len(value_string):
            return [DIGITS['0'], DIGITS[stripped]]
        else:
            return [DIGITS[stripped]]
    else:
        largest_unit = next(
            power for power in reversed(UNITS.keys()) if power < len(stripped))
        first_part = value_string[:-largest_unit]
        second_part = value_string[-largest_unit:]
        return _get_value(first_part) + [UNITS[largest_unit]] + _get_value(
            second_part)


def verbalize_cardinal(value_string: str) -> str:
    if not value_string:
        return ''

    # 000 -> '零' , 0 -> '零'
    value_string = value_string.lstrip('0')
    if len(value_string) == 0:
        return DIGITS['0']

    result_symbols = _get_value(value_string)
    # a verbalized number starting with '一十*' is abbreviated as '十*'
    if len(result_symbols) >= 2 and result_symbols[0] == DIGITS[
            '1'] and result_symbols[1] == UNITS[1]:
        result_symbols = result_symbols[1:]
    return ''.join(result_symbols)


def verbalize_digit(value_string: str, alt_one=False) -> str:
    result_symbols = [DIGITS[digit] for digit in value_string]
    result = ''.join(result_symbols)
    if alt_one:
        result = result.replace("一", "幺")
    return result


def num2str(value_string: str) -> str:
    integer_decimal = value_string.split('.')
    if len(integer_decimal) == 1:
        integer = integer_decimal[0]
        decimal = ''
    elif len(integer_decimal) == 2:
        integer, decimal = integer_decimal
    else:
        raise ValueError(
            f"The value string: '{value_string}' has more than one point in it."
        )

    result = verbalize_cardinal(integer)

    decimal = decimal.rstrip('0')
    if decimal:
        # '.22' is verbalized as '点二二'
        # '3.20' is verbalized as '三点二'
        result += '点' + verbalize_digit(decimal)
    return result
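Hand-derived spot checks of the verbalization rules (not recorded output):

```python
assert num2str("1375") == "一千三百七十五"
assert num2str("10.02") == "十点零二"
assert verbalize_digit("2021") == "二零二一"
assert RE_PERCENTAGE.sub(replace_percentage, "-3.5%") == "负百分之三点五"
```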
@@ -0,0 +1,46 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import re

from .num import verbalize_digit


# normalization of landline / mobile phone numbers
# mobile number prefixes (http://www.jihaoba.com/news/show/13680):
#   China Mobile: 139, 138, 137, 136, 135, 134, 159, 158, 157, 150, 151,
#                 152, 188, 187, 182, 183, 184, 178, 198
#   China Unicom: 130, 131, 132, 156, 155, 186, 185, 176
#   China Telecom: 133, 153, 189, 180, 181, 177
RE_MOBILE_PHONE = re.compile(
    r"(?<!\d)((\+?86 ?)?1([38]\d|5[0-35-9]|7[678]|9[89])\d{8})(?!\d)")
RE_TELEPHONE = re.compile(
    r"(?<!\d)((0(10|2[1-3]|[3-9]\d{2})-?)?[1-9]\d{6,7})(?!\d)")


def phone2str(phone_string: str, mobile=True) -> str:
    if mobile:
        sp_parts = phone_string.strip('+').split()
        result = ''.join(
            [verbalize_digit(part, alt_one=True) for part in sp_parts])
        return result
    else:
        sil_parts = phone_string.split('-')
        result = ''.join(
            [verbalize_digit(part, alt_one=True) for part in sil_parts])
        return result


def replace_phone(match: re.Match) -> str:
    return phone2str(match.group(0))
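A hand-derived example (it relies on the `alt_one` assignment fix in `verbalize_digit` above, so 1 is read as 幺):

```python
print(RE_MOBILE_PHONE.sub(replace_phone, "13812345678"))
# -> 幺三八幺二三四五六七八
```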
@@ -0,0 +1,33 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import re

from .num import num2str


# temperature expressions; temperature changes how the minus sign is read,
# e.g. -3°C is read as 零下三度
RE_TEMPERATURE = re.compile(r'(-?)(\d+(\.\d+)?)(°C|℃|度|摄氏度)')


def replace_temperature(match: re.Match) -> str:
    sign = match.group(1)
    temperature = match.group(2)
    unit = match.group(4)
    sign: str = "零下" if sign else ""
    temperature: str = num2str(temperature)
    unit: str = "摄氏度" if unit == "摄氏度" else "度"
    result = f"{sign}{temperature}{unit}"
    return result
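A hand-derived example showing why temperature gets its own rule: the minus sign is read 零下 rather than 负:

```python
print(RE_TEMPERATURE.sub(replace_temperature, "-3°C"))  # -> 零下三度
```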
@@ -0,0 +1,159 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Generate lexicon and symbols for Mandarin Chinese phonology.
The lexicon is used for the Montreal Forced Aligner.

Note that syllables are used as words in this lexicon, since syllables rather
than words are used in the transcriptions produced by `reorganize_baker.py`.
We make this choice to better leverage other software for Chinese
text-to-pinyin, such as pypinyin. This is the convention for G2P in Chinese.
"""

import argparse
import re
from collections import OrderedDict

INITIALS = [
    'b', 'p', 'm', 'f', 'd', 't', 'n', 'l', 'g', 'k', 'h', 'zh', 'ch', 'sh',
    'r', 'z', 'c', 's', 'j', 'q', 'x'
]

FINALS = [
    'a', 'ai', 'ao', 'an', 'ang', 'e', 'er', 'ei', 'en', 'eng', 'o', 'ou',
    'ong', 'ii', 'iii', 'i', 'ia', 'iao', 'ian', 'iang', 'ie', 'io', 'iou',
    'iong', 'in', 'ing', 'u', 'ua', 'uai', 'uan', 'uang', 'uei', 'uo', 'uen',
    'ueng', 'v', 've', 'van', 'vn'
]

SPECIALS = ['sil', 'sp']


def rule(C, V, R, T):
    """Generate a syllable given the initial, the final, an erhua indicator,
    and the tone. Orthographical rules for pinyin are applied (special cases
    for y, w, ui, un, iu).

    Note that in this system, 'ü' is always written as 'v' when it appears in
    a phoneme, but converted to 'u' in syllables when certain conditions are
    satisfied.

    'i' is distinguished when it appears in phonemes, and separated into 3
    categories: 'i', 'ii' and 'iii'.
    Erhua can be applied to every final, except for finals that already end
    with 'r'.
    When a syllable is impossible or does not have any characters with this
    pronunciation, return None to filter it out.
    """

    # impossible syllables: ii can only be combined with z, c, s
    if V in ["ii"] and (C not in ['z', 'c', 's']):
        return None
    # iii can only be combined with zh, ch, sh, r
    if V in ['iii'] and (C not in ['zh', 'ch', 'sh', 'r']):
        return None

    # 齐齿呼 and 撮口呼 finals cannot be combined with f, g, k, h, zh, ch, sh, r, z, c, s
    if (V not in ['ii', 'iii']) and V[0] in ['i', 'v'] and (
            C in ['f', 'g', 'k', 'h', 'zh', 'ch', 'sh', 'r', 'z', 'c', 's']):
        return None

    # 撮口呼 finals can only be combined with j, q, x, l, n
    if V.startswith("v"):
        # v, ve can only be combined with j, q, x, n, l
        if V in ['v', 've']:
            if C not in ['j', 'q', 'x', 'n', 'l', '']:
                return None
        # the others can only be combined with j, q, x
        else:
            if C not in ['j', 'q', 'x', '']:
                return None

    # j, q, x can only be combined with 齐齿呼 or 撮口呼 finals
    if (C in ['j', 'q', 'x']) and not (
            (V not in ['ii', 'iii']) and V[0] in ['i', 'v']):
        return None

    # b, p, m, f cannot be combined with 合口呼 finals except for u,
    # and cannot be combined with 撮口呼 finals
    if (C in ['b', 'p', 'm', 'f']) and ((V[0] in ['u', 'v'] and V != "u") or
                                        V == 'ong'):
        return None

    # ua, uai, uang cannot be combined with d, t, n, l, r, z, c, s
    if V in ['ua', 'uai', 'uang'
             ] and C in ['d', 't', 'n', 'l', 'r', 'z', 'c', 's']:
        return None

    # sh cannot be combined with ong
    if V == 'ong' and C in ['sh']:
        return None

    # o cannot be combined with d, t, n, g, k, h, zh, ch, sh, r, z, c, s
    if V == "o" and C in [
            'd', 't', 'n', 'g', 'k', 'h', 'zh', 'ch', 'sh', 'r', 'z', 'c', 's'
    ]:
        return None

    # ueng only occurs in weng; in the other cases it is ong
    if V == 'ueng' and C != '':
        return

    # a non-erhua er can only stand alone
    if V == 'er' and C != '':
        return None

    if C == '':
        if V in ["i", "in", "ing"]:
            C = 'y'
        elif V == 'u':
            C = 'w'
        elif V.startswith('i') and V not in ["ii", "iii"]:
            C = 'y'
            V = V[1:]
        elif V.startswith('u'):
            C = 'w'
            V = V[1:]
        elif V.startswith('v'):
            C = 'yu'
            V = V[1:]
    else:
        if C in ['j', 'q', 'x']:
            if V.startswith('v'):
                V = re.sub('v', 'u', V)
        if V == 'iou':
            V = 'iu'
        elif V == 'uei':
            V = 'ui'
        elif V == 'uen':
            V = 'un'
    result = C + V

    # filter out syllables that already end with r and cannot take erhua again
    if result.endswith('r') and R == 'r':
        return None

    # ii and iii: change back to i
    result = re.sub(r'i+', 'i', result)

    result = result + R + T
    return result


def generate_lexicon(with_tone=False, with_erhua=False):
    """Generate lexicon for Mandarin Chinese."""
    syllables = OrderedDict()

    for C in [''] + INITIALS:
        for V in FINALS:
            for R in [''] if not with_erhua else ['', 'r']:
                for T in [''] if not with_tone else ['1', '2', '3', '4', '5']:
                    result = rule(C, V, R, T)
                    if result:
                        syllables[result] = f'{C} {V}{R}{T}'
    return syllables
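Hand-derived spot checks of the lexicon rules (not recorded output):

```python
lexicon = generate_lexicon(with_tone=True, with_erhua=False)
print(lexicon["shi4"])     # -> 'sh iii4' (iii is folded back to i in syllables)
print(lexicon["yu1"])      # -> ' v1'     (empty initial: v-finals take y/yu)
print("bong1" in lexicon)  # -> False     (b cannot combine with ong)
```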
@ -0,0 +1,235 @@
|
||||||
|
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
|
||||||
|
import jieba
|
||||||
|
from pypinyin import lazy_pinyin
|
||||||
|
from pypinyin import Style
|
||||||
|
|
||||||
|
|
||||||
|
class ModifiedTone():
    def __init__(self):
        # words whose final syllable is conventionally read with the
        # neutral tone (tone 5)
        self.must_neural_tone_words = {
            '麻烦', '麻利', '鸳鸯', '高粱', '骨头', '骆驼', '马虎', '首饰', '馒头', '馄饨', '风筝', '难为', '队伍',
            '阔气', '闺女', '门道', '锄头', '铺盖', '铃铛', '铁匠', '钥匙', '里脊', '里头', '部分', '那么', '道士',
            '造化', '迷糊', '连累', '这么', '这个', '运气', '过去', '软和', '转悠', '踏实', '跳蚤', '跟头', '趔趄',
            '财主', '豆腐', '讲究', '记性', '记号', '认识', '规矩', '见识', '裁缝', '补丁', '衣裳', '衣服', '衙门',
            '街坊', '行李', '行当', '蛤蟆', '蘑菇', '薄荷', '葫芦', '葡萄', '萝卜', '荸荠', '苗条', '苗头', '苍蝇',
            '芝麻', '舒服', '舒坦', '舌头', '自在', '膏药', '脾气', '脑袋', '脊梁', '能耐', '胳膊', '胭脂', '胡萝',
            '胡琴', '胡同', '聪明', '耽误', '耽搁', '耷拉', '耳朵', '老爷', '老实', '老婆', '老头', '老太', '翻腾',
            '罗嗦', '罐头', '编辑', '结实', '红火', '累赘', '糨糊', '糊涂', '精神', '粮食', '簸箕', '篱笆', '算计',
            '算盘', '答应', '笤帚', '笑语', '笑话', '窟窿', '窝囊', '窗户', '稳当', '稀罕', '称呼', '秧歌', '秀气',
            '秀才', '福气', '祖宗', '砚台', '码头', '石榴', '石头', '石匠', '知识', '眼睛', '眯缝', '眨巴', '眉毛',
            '相声', '盘算', '白净', '痢疾', '痛快', '疟疾', '疙瘩', '疏忽', '畜生', '生意', '甘蔗', '琵琶', '琢磨',
            '琉璃', '玻璃', '玫瑰', '玄乎', '狐狸', '状元', '特务', '牲口', '牙碜', '牌楼', '爽快', '爱人', '热闹',
            '烧饼', '烟筒', '烂糊', '点心', '炊帚', '灯笼', '火候', '漂亮', '滑溜', '溜达', '温和', '清楚', '消息',
            '浪头', '活泼', '比方', '正经', '欺负', '模糊', '槟榔', '棺材', '棒槌', '棉花', '核桃', '栅栏', '柴火',
            '架势', '枕头', '枇杷', '机灵', '本事', '木头', '木匠', '朋友', '月饼', '月亮', '暖和', '明白', '时候',
            '新鲜', '故事', '收拾', '收成', '提防', '挖苦', '挑剔', '指甲', '指头', '拾掇', '拳头', '拨弄', '招牌',
            '招呼', '抬举', '护士', '折腾', '扫帚', '打量', '打算', '打点', '打扮', '打听', '打发', '扎实', '扁担',
            '戒指', '懒得', '意识', '意思', '情形', '悟性', '怪物', '思量', '怎么', '念头', '念叨', '快活', '忙活',
            '志气', '心思', '得罪', '张罗', '弟兄', '开通', '应酬', '庄稼', '干事', '帮手', '帐篷', '希罕', '师父',
            '师傅', '巴结', '巴掌', '差事', '工夫', '岁数', '屁股', '尾巴', '少爷', '小气', '小伙', '将就', '对头',
            '对付', '寡妇', '家伙', '客气', '实在', '官司', '学问', '学生', '字号', '嫁妆', '媳妇', '媒人', '婆家',
            '娘家', '委屈', '姑娘', '姐夫', '妯娌', '妥当', '妖精', '奴才', '女婿', '头发', '太阳', '大爷', '大方',
            '大意', '大夫', '多少', '多么', '外甥', '壮实', '地道', '地方', '在乎', '困难', '嘴巴', '嘱咐', '嘟囔',
            '嘀咕', '喜欢', '喇嘛', '喇叭', '商量', '唾沫', '哑巴', '哈欠', '哆嗦', '咳嗽', '和尚', '告诉', '告示',
            '含糊', '吓唬', '后头', '名字', '名堂', '合同', '吆喝', '叫唤', '口袋', '厚道', '厉害', '千斤', '包袱',
            '包涵', '匀称', '勤快', '动静', '动弹', '功夫', '力气', '前头', '刺猬', '刺激', '别扭', '利落', '利索',
            '利害', '分析', '出息', '凑合', '凉快', '冷战', '冤枉', '冒失', '养活', '关系', '先生', '兄弟', '便宜',
            '使唤', '佩服', '作坊', '体面', '位置', '似的', '伙计', '休息', '什么', '人家', '亲戚', '亲家', '交情',
            '云彩', '事情', '买卖', '主意', '丫头', '丧气', '两口', '东西', '东家', '世故', '不由', '不在', '下水',
            '下巴', '上头', '上司', '丈夫', '丈人', '一辈', '那个'}

    def _neural_tone(self, word, pos, sub_finals):
        """Rewrite syllables that should carry the neutral tone (tone 5)."""
        ge_idx = word.find("个")
        # sentence-final modal particles, eg: 吧, 呢, 啊, 嘛
        if len(word) == 1 and word in "吧呢啊嘛" and pos == 'y':
            sub_finals[-1] = sub_finals[-1][:-1] + "5"
        # structural particles 的地得
        elif len(word) == 1 and word in "的地得" and pos in {"ud", "uj", "uv"}:
            sub_finals[-1] = sub_finals[-1][:-1] + "5"
        # aspect particles, eg: 走了, 看着, 去过
        elif len(word) == 1 and word in "了着过" and pos in {"ul", "uz", "ug"}:
            sub_finals[-1] = sub_finals[-1][:-1] + "5"
        # suffixes 们/子, eg: 我们, 桌子
        elif len(word) > 1 and word[-1] in "们子" and pos in {"r", "n"}:
            sub_finals[-1] = sub_finals[-1][:-1] + "5"
        # locatives, eg: 桌上, 地下, 家里
        elif len(word) > 1 and word[-1] in "上下里" and pos in {"s", "l", "f"}:
            sub_finals[-1] = sub_finals[-1][:-1] + "5"
        # directional complements, eg: 上来, 下去
        elif len(word) > 1 and word[-1] in "来去" and pos[0] in {"v"}:
            sub_finals[-1] = sub_finals[-1][:-1] + "5"
        # 个 as a classifier after a numeral, eg: 三个
        elif ge_idx >= 1 and word[ge_idx - 1].isnumeric():
            sub_finals[ge_idx] = sub_finals[ge_idx][:-1] + "5"
        # reduplicated nouns and verbs, eg: 奶奶, 试试
        elif len(word) >= 2 and word[-1] == word[-2] and pos[0] in {"n", "v"}:
            sub_finals[-1] = sub_finals[-1][:-1] + "5"
        # words conventionally read with a neutral final tone in Chinese
        elif word in self.must_neural_tone_words or word[-2:] in self.must_neural_tone_words:
            sub_finals[-1] = sub_finals[-1][:-1] + "5"

        return sub_finals

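    # Worked example (a sketch, not part of the original file):
    #   _neural_tone("石头", "n", ["i2", "ou2"]) -> ["i2", "ou5"]
    #   石头 is in must_neural_tone_words, so its last syllable is neutralized.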
    def _bu_tone(self, word, sub_finals):
        # "不" before tone 4 becomes bu2, eg: 不怕
        if len(word) > 1 and word[0] == "不" and sub_finals[1][-1] == "4":
            sub_finals[0] = sub_finals[0][:-1] + "2"
        # "不" sandwiched in a verb compound is neutral, eg: 看不懂
        elif len(word) == 3 and word[1] == "不":
            sub_finals[1] = sub_finals[1][:-1] + "5"

        return sub_finals

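    # Worked examples (a sketch, not part of the original file):
    #   _bu_tone("不怕", ["u4", "a4"]) -> ["u2", "a4"]
    #   _bu_tone("看不懂", ["an4", "u4", "ong3"]) -> ["an4", "u5", "ong3"]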
    def _yi_tone(self, word, sub_finals):
        # "一" in number sequences keeps yi1, eg: 一零零
        if len(word) > 1 and word[0] == "一" and all([item.isnumeric() for item in word]):
            return sub_finals
        # "一" before tone 4 becomes yi2, eg: 一段
        elif len(word) > 1 and word[0] == "一" and sub_finals[1][-1] == "4":
            sub_finals[0] = sub_finals[0][:-1] + "2"
        # "一" before non-tone-4 becomes yi4, eg: 一天
        elif len(word) > 1 and word[0] == "一" and sub_finals[1][-1] != "4":
            sub_finals[0] = sub_finals[0][:-1] + "4"
        # "一" between reduplicated words becomes yi5, eg: 看一看
        elif len(word) == 3 and word[1] == "一" and word[0] == word[-1]:
            sub_finals[1] = sub_finals[1][:-1] + "5"
        # "一" in an ordinal stays yi1, eg: 第一
        elif word.startswith("第一"):
            sub_finals[1] = sub_finals[1][:-1] + "1"
        return sub_finals

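    # Worked examples (a sketch, not part of the original file):
    #   _yi_tone("一段", ["i1", "uan4"]) -> ["i2", "uan4"]
    #   _yi_tone("一天", ["i1", "ian1"]) -> ["i4", "ian1"]
    #   _yi_tone("看一看", ["an4", "i1", "an4"]) -> ["an4", "i5", "an4"]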
    # sentences like 我给你讲个故事 are not handled yet
    def _three_tone(self, word, sub_finals):
        if len(word) == 2 and self._all_tone_three(sub_finals):
            sub_finals[0] = sub_finals[0][:-1] + "2"
        elif len(word) == 3:
            word_list = jieba.cut_for_search(word)
            word_list = sorted(word_list, key=lambda i: len(i), reverse=False)
            new_word_list = []
            first_subword = word_list[0]
            first_begin_idx = word.find(first_subword)
            if first_begin_idx == 0:
                second_subword = word[len(first_subword):]
                new_word_list = [first_subword, second_subword]
            else:
                second_subword = word[:-len(first_subword)]
                new_word_list = [second_subword, first_subword]
            if self._all_tone_three(sub_finals):
                # disyllabic + monosyllabic, eg: 蒙古/包
                if len(new_word_list[0]) == 2:
                    sub_finals[0] = sub_finals[0][:-1] + "2"
                    sub_finals[1] = sub_finals[1][:-1] + "2"
                # monosyllabic + disyllabic, eg: 纸/老虎
                elif len(new_word_list[0]) == 1:
                    sub_finals[1] = sub_finals[1][:-1] + "2"
            else:
                sub_finals_list = [sub_finals[:len(new_word_list[0])], sub_finals[len(new_word_list[0]):]]
                if len(sub_finals_list) == 2:
                    for i, sub in enumerate(sub_finals_list):
                        # eg: 所有/人
                        if self._all_tone_three(sub) and len(sub) == 2:
                            sub_finals_list[i][0] = sub_finals_list[i][0][:-1] + "2"
                        # eg: 好/喜欢
                        elif i == 1 and not self._all_tone_three(sub) and sub_finals_list[i][0][-1] == "3" and \
                                sub_finals_list[0][-1][-1] == "3":
                            sub_finals_list[0][-1] = sub_finals_list[0][-1][:-1] + "2"
                    sub_finals = sum(sub_finals_list, [])
        # split a 4-character idiom into two 2-character halves
        elif len(word) == 4:
            sub_finals_list = [sub_finals[:2], sub_finals[2:]]
            sub_finals = []
            for sub in sub_finals_list:
                if self._all_tone_three(sub):
                    sub[0] = sub[0][:-1] + "2"
                sub_finals += sub

        return sub_finals

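    # Worked example (a sketch, not part of the original file):
    #   _three_tone("老虎", ["ao3", "u3"]) -> ["ao2", "u3"]
    #   two adjacent third tones: the first is raised to tone 2.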
    def _all_tone_three(self, finals):
        return all(x[-1] == "3" for x in finals)

# merge "不" and the word behind it
|
||||||
|
def _merge_bu(self, seg):
|
||||||
|
new_seg = []
|
||||||
|
last_word = ""
|
||||||
|
for word, pos in seg:
|
||||||
|
if last_word == "不":
|
||||||
|
word = last_word + word
|
||||||
|
if word != "不":
|
||||||
|
new_seg.append((word, pos))
|
||||||
|
last_word = word[:]
|
||||||
|
if last_word == "不":
|
||||||
|
new_seg.append((last_word, 'd'))
|
||||||
|
last_word = ""
|
||||||
|
seg = new_seg
|
||||||
|
return seg
|
||||||
|
|
||||||
|
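    # Worked example (a sketch, not part of the original file):
    #   _merge_bu([("不", "d"), ("怕", "v")]) -> [("不怕", "v")]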
    # function 1: merge "一" with the identical words on its left and right, eg: "看", "一", "看" -> "看一看"
    # function 2: merge a single "一" with the word behind it
    def _merge_yi(self, seg):
        new_seg = []
        # function 1
        for i, (word, pos) in enumerate(seg):
            if i - 1 >= 0 and word == "一" and i + 1 < len(seg) and seg[i - 1][0] == seg[i + 1][0] and seg[i - 1][1] == "v":
                new_seg[i - 1][0] = new_seg[i - 1][0] + "一" + new_seg[i - 1][0]
            else:
                if i - 2 >= 0 and seg[i - 1][0] == "一" and seg[i - 2][0] == word and pos == "v":
                    continue
                else:
                    new_seg.append([word, pos])
        seg = new_seg
        new_seg = []
        # function 2
        for i, (word, pos) in enumerate(seg):
            if new_seg and new_seg[-1][0] == "一":
                new_seg[-1][0] = new_seg[-1][0] + word
            else:
                new_seg.append([word, pos])

        seg = new_seg
        return seg

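    # Worked example (a sketch, not part of the original file):
    #   _merge_yi([("看", "v"), ("一", "m"), ("看", "v")]) -> [["看一看", "v"]]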
    # merge adjacent all-tone-3 words so _three_tone can treat them as one unit
    def _merge_continuous_three_tones(self, seg):
        new_seg = []
        sub_finals_list = [
            lazy_pinyin(word, neutral_tone_with_five=True, style=Style.FINALS_TONE3)
            for (word, pos) in seg
        ]
        assert len(sub_finals_list) == len(seg)
        merge_last = [False] * len(seg)
        for i, (word, pos) in enumerate(seg):
            if i - 1 >= 0 and self._all_tone_three(sub_finals_list[i - 1]) and self._all_tone_three(
                    sub_finals_list[i]) and not merge_last[i - 1]:
                if len(seg[i - 1][0]) + len(seg[i][0]) <= 3:
                    new_seg[-1][0] = new_seg[-1][0] + seg[i][0]
                    merge_last[i] = True
                else:
                    new_seg.append([word, pos])
            else:
                new_seg.append([word, pos])
        seg = new_seg
        return seg

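    # Worked example (a sketch, not part of the original file):
    #   [("纸", "n"), ("老虎", "n")] -> [["纸老虎", "n"]]
    #   both words are entirely tone 3 and the merge stays within 3 characters.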
    def pre_merge_for_modify(self, seg):
        seg = self._merge_bu(seg)
        seg = self._merge_yi(seg)
        seg = self._merge_continuous_three_tones(seg)
        return seg

    def modified_tone(self, word, pos, finals):
        finals = self._bu_tone(word, finals)
        finals = self._yi_tone(word, finals)
        finals = self._neural_tone(word, pos, finals)
        finals = self._three_tone(word, finals)
        return finals
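End to end, the class might be driven like this (a minimal sketch: the sentence and the wiring are illustrative, and only the `lazy_pinyin` call mirrors what the module itself uses in `_merge_continuous_three_tones`):

```python
import jieba.posseg as psg
from pypinyin import lazy_pinyin
from pypinyin import Style

tone_modifier = ModifiedTone()
sentence = "你要不要看一看"  # illustrative input
seg = [(s.word, s.flag) for s in psg.lcut(sentence)]
seg = tone_modifier.pre_merge_for_modify(seg)
for word, pos in seg:
    finals = lazy_pinyin(word, neutral_tone_with_five=True, style=Style.FINALS_TONE3)
    finals = tone_modifier.modified_tone(word, pos, finals)
    print(word, finals)
```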
@@ -415,14 +415,7 @@ class FastSpeech2(nn.Layer):
         """
         x, y = text, speech
         d, p, e = durations, pitch, energy
-        x = paddle.to_tensor(text)
-        # add eos at the end of the sequence
-        x = np.pad(text.numpy(),
-                   pad_width=((0, 1)),
-                   mode="constant",
-                   constant_values=self.eos)
-        x = paddle.to_tensor(x)
-
         # setup batch axis
         ilens = paddle.to_tensor(
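The block removed above appended the eos id to the phone-id sequence inside the acoustic model; with the frontend modularized, that step presumably moves out of `inference`. A sketch of what the removed padding did (the ids and the eos value are illustrative):

```python
import numpy as np

text_ids = np.array([12, 7, 33])  # phone ids produced by the frontend
eos = 256                         # stands in for self.eos
padded = np.pad(text_ids, pad_width=(0, 1), mode="constant", constant_values=eos)
# padded -> array([ 12,   7,  33, 256])
```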