fix scripts

TianYuan 2021-08-03 10:10:39 +00:00
parent 6aeb56301f
commit a141d39b38
19 changed files with 228 additions and 149 deletions

View File

@ -30,8 +30,8 @@ Download pretrained parallel wavegan model from [parallel_wavegan_baker_ckpt_1.0
```bash
unzip parallel_wavegan_baker_ckpt_1.0.zip
```
`synthesize.sh` can synthesize waveform for `metadata.jsonl`.
`synthesize_e2e.sh` can synthesize waveform for text list.
`synthesize.sh` can synthesize waveforms from `metadata.jsonl`.
`synthesize_e2e.sh` can synthesize waveforms from a text list.
```bash
./synthesize.sh

View File

@ -21,33 +21,33 @@ from parakeet.frontend.cn_frontend import Frontend as cnFrontend
class Frontend():
def __init__(self, phone_vocab_path=None, tone_vocab_path=None):
self.frontend = cnFrontend()
self.voc_phones = {}
self.voc_tones = {}
self.vocab_phones = {}
self.vocab_tones = {}
if phone_vocab_path:
with open(phone_vocab_path, 'rt') as f:
phn_id = [line.strip().split() for line in f.readlines()]
for phn, id in phn_id:
self.voc_phones[phn] = int(id)
self.vocab_phones[phn] = int(id)
if tone_vocab_path:
with open(tone_vocab_path, 'rt') as f:
tone_id = [line.strip().split() for line in f.readlines()]
for tone, id in tone_id:
self.voc_tones[tone] = int(id)
self.vocab_tones[tone] = int(id)
def _p2id(self, phonemes):
# replace unk phone with sp
phonemes = [
phn if phn in self.voc_phones else "sp" for phn in phonemes
phn if phn in self.vocab_phones else "sp" for phn in phonemes
]
phone_ids = [self.voc_phones[item] for item in phonemes]
phone_ids = [self.vocab_phones[item] for item in phonemes]
return np.array(phone_ids, np.int64)
def _t2id(self, tones):
# replace unk tone with "0"
tones = [
tone if tone in self.voc_tones else "0" for tone in tones
tone if tone in self.vocab_tones else "0" for tone in tones
]
tone_ids = [self.voc_tones[item] for item in tones]
tone_ids = [self.vocab_tones[item] for item in tones]
return np.array(tone_ids, np.int64)
def get_input_ids(self, sentence, get_tone_ids=False):
@ -55,7 +55,7 @@ class Frontend():
result = {}
phones = []
tones = []
if get_tone_ids and self.voc_tones:
if get_tone_ids and self.vocab_tones:
for full_phone in phonemes:
# split tone from finals
match = re.match(r'^(\w+)([012345])$', full_phone)
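For orientation, a minimal usage sketch of this wrapper; the module path, the sample sentence, and the `phone_ids` result key are assumptions, since the body of `get_input_ids` that assembles `result` is truncated above.
```python
# a hedged sketch, not the repository's own test code
from frontend import Frontend  # assumed module path

frontend = Frontend(phone_vocab_path="dump/phone_id_map.txt")
result = frontend.get_input_ids("你好,欢迎使用语音合成服务。")
phone_ids = result["phone_ids"]  # assumed key name
```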

View File

@ -21,20 +21,29 @@ from config import get_cfg_default
class LogMelFBank():
def __init__(self, conf):
self.sr = conf.fs
def __init__(self,
sr=24000,
n_fft=2048,
hop_length=300,
win_length=1200,
window="hann",
n_mels=80,
fmin=80,
fmax=7600,
eps=1e-10):
self.sr = sr
# stft
self.n_fft = conf.n_fft
self.win_length = conf.win_length
self.hop_length = conf.n_shift
self.window = conf.window
self.n_fft = n_fft
self.win_length = win_length
self.hop_length = hop_length
self.window = window
self.center = True
self.pad_mode = "reflect"
# mel
self.n_mels = conf.n_mels
self.fmin = conf.fmin
self.fmax = conf.fmax
self.n_mels = n_mels
self.fmin = fmin
self.fmax = fmax
self.mel_filter = self._create_mel_filter()
@ -66,6 +75,10 @@ class LogMelFBank():
mel = np.dot(self.mel_filter, S)
return mel
# We use a different definition of log-spec for TTS and ASR
# TTS: log_10(abs(stft))
# ASR: log_e(power(stft))
def get_log_mel_fbank(self, wav):
mel = self._mel_spectrogram(wav)
mel = np.clip(mel, a_min=1e-10, a_max=float("inf"))
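To make the comment above concrete, a minimal sketch of the two conventions; the function names are illustrative and the parameters mirror the defaults above:
```python
import librosa
import numpy as np

def tts_log_spec(wav, n_fft=2048, hop_length=300, win_length=1200):
    # TTS convention: log base 10 of the magnitude spectrogram
    S = np.abs(librosa.stft(wav, n_fft=n_fft, hop_length=hop_length,
                            win_length=win_length))
    return np.log10(np.clip(S, a_min=1e-10, a_max=None))

def asr_log_spec(wav, n_fft=2048, hop_length=300, win_length=1200):
    # ASR convention: natural log of the power spectrogram
    P = np.abs(librosa.stft(wav, n_fft=n_fft, hop_length=hop_length,
                            win_length=win_length)) ** 2
    return np.log(np.clip(P, a_min=1e-10, a_max=None))
```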
@ -75,12 +88,17 @@ class LogMelFBank():
class Pitch():
def __init__(self, conf):
def __init__(self,
sr=24000,
hop_length=300,
f0min=80,
f0max=7600
):
self.sr = conf.fs
self.hop_length = conf.n_shift
self.f0min = conf.f0min
self.f0max = conf.f0max
self.sr = sr
self.hop_length = hop_length
self.f0min = f0min
self.f0max = f0max
def _convert_to_continuous_f0(self, f0: np.ndarray) -> np.ndarray:
if (f0 == 0).all():
@ -132,6 +150,7 @@ class Pitch():
arr[mask] = 0
avg_arr = np.mean(arr, axis=0) if len(arr) != 0 else np.array(0)
arr_list.append(avg_arr)
# shape (T,1)
arr_list = np.expand_dims(np.array(arr_list), 0).T
return arr_list
@ -149,15 +168,22 @@ class Pitch():
class Energy():
def __init__(self, conf):
def __init__(self,
sr=24000,
n_fft=2048,
hop_length=300,
win_length=1200,
window="hann",
center=True,
pad_mode="reflect"):
self.sr = conf.fs
self.n_fft = conf.n_fft
self.win_length = conf.win_length
self.hop_length = conf.n_shift
self.window = conf.window
self.center = True
self.pad_mode = "reflect"
self.sr = sr
self.n_fft = n_fft
self.win_length = win_length
self.hop_length = hop_length
self.window = window
self.center = center
self.pad_mode = pad_mode
def _stft(self, wav):
D = librosa.core.stft(
@ -173,7 +199,7 @@ class Energy():
def _calculate_energy(self, input):
input = input.astype(np.float32)
input_stft = self._stft(input)
input_power = np.abs(input_stft)**2
input_power = np.abs(input_stft) ** 2
energy = np.sqrt(
np.clip(
np.sum(input_power, axis=0), a_min=1.0e-10, a_max=float(
@ -187,6 +213,7 @@ class Energy():
arr = input[start:end]
avg_arr = np.mean(arr, axis=0) if len(arr) != 0 else np.array(0)
arr_list.append(avg_arr)
# shape (T,1)
arr_list = np.expand_dims(np.array(arr_list), 0).T
return arr_list
@ -201,19 +228,34 @@ if __name__ == "__main__":
C = get_cfg_default()
filename = "../raw_data/data/format.1/000001.flac"
wav, _ = librosa.load(filename, sr=C.fs)
mel_extractor = LogMelFBank(C)
mel_extractor = LogMelFBank(
sr=C.fs,
n_fft=C.n_fft,
hop_length=C.n_shift,
win_length=C.win_length,
window=C.window,
n_mels=C.n_mels,
fmin=C.fmin,
fmax=C.fmax, )
mel = mel_extractor.get_log_mel_fbank(wav)
print(mel)
print(mel.shape)
pitch_extractor = Pitch(C)
pitch_extractor = Pitch(sr=C.fs,
hop_length=C.n_shift,
f0min=C.f0min,
f0max=C.f0max)
duration = "2 8 8 8 12 11 10 13 11 10 18 9 12 10 12 11 5"
duration = np.array([int(x) for x in duration.split(" ")])
avg_f0 = pitch_extractor.get_pitch(wav, duration=duration)
print(avg_f0)
print(avg_f0.shape)
energy_extractor = Energy(C)
energy_extractor = Energy(sr=C.fs,
n_fft=C.n_fft,
hop_length=C.n_shift,
win_length=C.win_length,
window=C.window)
duration = "2 8 8 8 12 11 10 13 11 10 18 9 12 10 12 11 5"
duration = np.array([int(x) for x in duration.split(" ")])
avg_energy = energy_extractor.get_energy(wav, duration=duration)

View File

@ -60,7 +60,7 @@ def main():
required=True,
help="energy statistics file.")
parser.add_argument(
"--phones",
"--phones-dict",
type=str,
default="phone_id_map.txt ",
help="phone vocabulary file.")
@ -128,11 +128,11 @@ def main():
energy_scaler.scale_ = np.load(args.energy_stats)[1]
energy_scaler.n_features_in_ = energy_scaler.mean_.shape[0]
voc_phones = {}
with open(args.phones, 'rt') as f:
vocab_phones = {}
with open(args.phones_dict, 'rt') as f:
phn_id = [line.strip().split() for line in f.readlines()]
for phn, id in phn_id:
voc_phones[phn] = int(id)
vocab_phones[phn] = int(id)
# process each file
output_metadata = []
@ -160,7 +160,7 @@ def main():
energy_dir.mkdir(parents=True, exist_ok=True)
energy_path = energy_dir / f"{utt_id}_energy.npy"
np.save(energy_path, energy.astype(np.float32), allow_pickle=False)
phone_ids = [voc_phones[p] for p in item['phones']]
phone_ids = [vocab_phones[p] for p in item['phones']]
record = {
"utt_id": item['utt_id'],
"text": phone_ids,

View File

@ -96,18 +96,18 @@ def get_input_token(sentence, output_path):
output_path : str or path
path to save phone_id_map
'''
phn_emb = set()
phn_token = set()
for utt in sentence:
for phn in sentence[utt][0]:
if phn != "<eos>":
phn_emb.add(phn)
phn_emb = list(phn_emb)
phn_emb.sort()
phn_emb = ["<pad>", "<unk>"] + phn_emb
phn_emb += ["", "", "", "", "<eos>"]
phn_token.add(phn)
phn_token = list(phn_token)
phn_token.sort()
phn_token = ["<pad>", "<unk>"] + phn_token
phn_token += ["", "", "", "", "<eos>"]
f = open(output_path, 'w')
for i, phn in enumerate(phn_emb):
for i, phn in enumerate(phn_token):
f.write(phn + ' ' + str(i) + '\n')
f.close()
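The resulting `phone_id_map.txt` is a two-column `<phone> <id>` text file with `<pad>` and `<unk>` pinned to the first ids; an illustrative excerpt (the phone entries are hypothetical):
```
<pad> 0
<unk> 1
a1 2
ai1 3
```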
@ -284,8 +284,10 @@ def main():
help="logging level. higher is more logging. (default=1)")
parser.add_argument(
"--num-cpu", type=int, default=1, help="number of process.")
def str2bool(s):
return s.lower() == 'true'
parser.add_argument(
"--cut-sil", type=str2bool, default=True, help="whether cut sil in the edge of audio")
args = parser.parse_args()
@ -298,7 +300,7 @@ def main():
if args.verbose > 1:
print(vars(args))
print(C)
root_dir = Path(args.rootdir).expanduser()
dumpdir = Path(args.dumpdir).expanduser()
dumpdir.mkdir(parents=True, exist_ok=True)
@ -325,12 +327,27 @@ def main():
test_dump_dir.mkdir(parents=True, exist_ok=True)
# Extractor
mel_extractor = LogMelFBank(C)
pitch_extractor = Pitch(C)
energy_extractor = Energy(C)
mel_extractor = LogMelFBank(
sr=C.fs,
n_fft=C.n_fft,
hop_length=C.n_shift,
win_length=C.win_length,
window=C.window,
n_mels=C.n_mels,
fmin=C.fmin,
fmax=C.fmax)
pitch_extractor = Pitch(sr=C.fs,
hop_length=C.n_shift,
f0min=C.f0min,
f0max=C.f0max)
energy_extractor = Energy(sr=C.fs,
n_fft=C.n_fft,
hop_length=C.n_shift,
win_length=C.win_length,
window=C.window)
# process the 3 sections
process_sentences(
C,
train_wav_files,
@ -350,7 +367,7 @@ def main():
pitch_extractor,
energy_extractor,
cut_sil=args.cut_sil)
process_sentences(
C,
test_wav_files,

View File

@ -12,7 +12,7 @@ python3 compute_statistics.py --metadata=dump/train/raw/metadata.jsonl --field-n
python3 compute_statistics.py --metadata=dump/train/raw/metadata.jsonl --field-name="energy"
# normalize and convert phone to id; dev and test should use train's stats
python3 normalize.py --metadata=dump/train/raw/metadata.jsonl --dumpdir=dump/train/norm --speech-stats=dump/train/speech_stats.npy --pitch-stats=dump/train/pitch_stats.npy --energy-stats=dump/train/energy_stats.npy --phones dump/phone_id_map.txt
python3 normalize.py --metadata=dump/dev/raw/metadata.jsonl --dumpdir=dump/dev/norm --speech-stats=dump/train/speech_stats.npy --pitch-stats=dump/train/pitch_stats.npy --energy-stats=dump/train/energy_stats.npy --phones dump/phone_id_map.txt
python3 normalize.py --metadata=dump/test/raw/metadata.jsonl --dumpdir=dump/test/norm --speech-stats=dump/train/speech_stats.npy --pitch-stats=dump/train/pitch_stats.npy --energy-stats=dump/train/energy_stats.npy --phones dump/phone_id_map.txt
python3 normalize.py --metadata=dump/train/raw/metadata.jsonl --dumpdir=dump/train/norm --speech-stats=dump/train/speech_stats.npy --pitch-stats=dump/train/pitch_stats.npy --energy-stats=dump/train/energy_stats.npy --phones-dict dump/phone_id_map.txt
python3 normalize.py --metadata=dump/dev/raw/metadata.jsonl --dumpdir=dump/dev/norm --speech-stats=dump/train/speech_stats.npy --pitch-stats=dump/train/pitch_stats.npy --energy-stats=dump/train/energy_stats.npy --phones-dict dump/phone_id_map.txt
python3 normalize.py --metadata=dump/test/raw/metadata.jsonl --dumpdir=dump/test/norm --speech-stats=dump/train/speech_stats.npy --pitch-stats=dump/train/pitch_stats.npy --energy-stats=dump/train/energy_stats.npy --phones-dict dump/phone_id_map.txt

View File

@ -6,4 +6,4 @@ python3 train.py \
--config=conf/default.yaml \
--output-dir=exp/default \
--nprocs=1 \
--phones=dump/phone_id_map.txt
--phones-dict=dump/phone_id_map.txt

View File

@ -37,7 +37,7 @@ def evaluate(args, fastspeech2_config, pwg_config):
test_metadata = list(reader)
test_dataset = DataTable(data=test_metadata, fields=["utt_id", "text"])
with open(args.phones, "r") as f:
with open(args.phones_dict, "r") as f:
phn_id = [line.strip().split() for line in f.readlines()]
vocab_size = len(phn_id)
print("vocab_size:", vocab_size)
@ -119,7 +119,7 @@ def main():
help="mean and standard deviation used to normalize spectrogram when training parallel wavegan."
)
parser.add_argument(
"--phones",
"--phones-dict",
type=str,
default="phone_id_map.txt ",
help="phone vocabulary file.")

View File

@ -11,4 +11,4 @@ python3 synthesize.py \
--test-metadata=dump/test/norm/metadata.jsonl \
--output-dir=exp/debug/test \
--device="gpu" \
--phones=dump/phone_id_map.txt
--phones-dict=dump/phone_id_map.txt

View File

@ -39,7 +39,7 @@ def evaluate(args, fastspeech2_config, pwg_config):
utt_id, sentence = line.strip().split()
sentences.append((utt_id, sentence))
with open(args.phones, "r") as f:
with open(args.phones_dict, "r") as f:
phn_id = [line.strip().split() for line in f.readlines()]
vocab_size = len(phn_id)
print("vocab_size:", vocab_size)
@ -57,7 +57,7 @@ def evaluate(args, fastspeech2_config, pwg_config):
vocoder.eval()
print("model done!")
frontend = Frontend(args.phones)
frontend = Frontend(args.phones_dict)
print("frontend done!")
stat = np.load(args.fastspeech2_stat)
@ -124,7 +124,7 @@ def main():
help="mean and standard deviation used to normalize spectrogram when training parallel wavegan."
)
parser.add_argument(
"--phones",
"--phones-dict",
type=str,
default="phone_id_map.txt ",
help="phone vocabulary file.")

View File

@ -11,4 +11,4 @@ python3 synthesize_e2e.py \
--text=sentences.txt \
--output-dir=exp/debug/test_e2e \
--device="gpu" \
--phones=dump/phone_id_map.txt
--phones-dict=dump/phone_id_map.txt

View File

@ -131,7 +131,7 @@ def train_sp(args, config):
num_workers=config.num_workers)
print("dataloaders done!")
with open(args.phones, "r") as f:
with open(args.phones_dict, "r") as f:
phn_id = [line.strip().split() for line in f.readlines()]
vocab_size = len(phn_id)
print("vocab_size:", vocab_size)
@ -182,7 +182,7 @@ def main():
"--nprocs", type=int, default=1, help="number of processes")
parser.add_argument("--verbose", type=int, default=1, help="verbose")
parser.add_argument(
"--phones",
"--phones-dict",
type=str,
default="phone_id_map.txt ",
help="phone vocabulary file.")

View File

@ -18,8 +18,8 @@ import numpy as np
import paddle
import re
from g2pM import G2pM
from parakeet.frontend.modified_tone import ModifiedTone
from parakeet.frontend.cn_normalization.normalization import Normalizer
from parakeet.frontend.tone_sandhi import ToneSandhi
from parakeet.frontend.cn_normalization.text_normlization import TextNormalizer
from pypinyin import lazy_pinyin, Style
from parakeet.frontend.generate_lexicon import generate_lexicon
@ -27,8 +27,8 @@ from parakeet.frontend.generate_lexicon import generate_lexicon
class Frontend():
def __init__(self, g2p_model="pypinyin"):
self.tone_modifier = ModifiedTone()
self.normalizer = Normalizer()
self.tone_modifier = ToneSandhi()
self.text_normalizer = TextNormalizer()
self.punc = ":,;。?!“”‘’':,;.?!"
# g2p_model can be pypinyin and g2pM
self.g2p_model = g2p_model
@ -65,6 +65,7 @@ class Frontend():
initials.append('')
finals.append(initial_final_list[1])
else:
# if it's not pinyin (possibly punctuation) or no conversion is required, pass it through
initials.append(pinyin)
finals.append(pinyin)
return initials, finals
@ -96,7 +97,7 @@ class Frontend():
phones.append(c)
if v and v not in self.punc:
phones.append(v)
# add sp between sentence
# add sp between sentences (replace the trailing punc with sp)
if initials[-1] in self.punc:
phones.append('sp')
phones_list.append(phones)
@ -105,6 +106,6 @@ class Frontend():
return phones_list
def get_phonemes(self, sentence):
sentences = self.normalizer.normalize(sentence)
sentences = self.text_normalizer.normalize(sentence)
phonemes = self._g2p(sentences)
return phonemes
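A minimal end-to-end sketch of this frontend (the sample sentence is arbitrary):
```python
# a hedged sketch, assuming this file's module path in parakeet
from parakeet.frontend.cn_frontend import Frontend

frontend = Frontend(g2p_model="pypinyin")
# normalize the text, then run g2p with tone sandhi applied
phonemes = frontend.get_phonemes("明天有62%的概率降雨")
print(phonemes)  # one list of phones per normalized sentence
```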

View File

@ -0,0 +1,11 @@
Supported NSW (Non-Standard-Word) normalization
|NSW type|raw|normalized|
|-|-|-|
|cardinal|这块黄金重达324.75克|这块黄金重达三百二十四点七五克|
|date|她出生于86年8月18日她弟弟出生于1995年3月1日|她出生于八六年八月十八日 她弟弟出生于一九九五年三月一日|
|digit|电影中梁朝伟扮演的陈永仁的编号27149|电影中梁朝伟扮演的陈永仁的编号二七一四九|
|fraction|现场有7/12的观众投出了赞成票|现场有十二分之七的观众投出了赞成票|
|money|随便来几个价格12块534.5元20.1万|随便来几个价格十二块五 三十四点五元 二十点一万|
|percentage|明天有62%的概率降雨|明天有百分之六十二的概率降雨|
|telephone|这是固话0421-33441122<br>这是手机+86 18544139121|这是固话零四二一三三四四一一二二<br>这是手机八六一八五四四一三九一二一|
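A minimal usage sketch of the normalizer behind this table; the expected output follows the `cardinal` row above:
```python
from parakeet.frontend.cn_normalization.text_normlization import TextNormalizer

normalizer = TextNormalizer()
# normalize() splits the input into sentences and verbalizes each NSW
print(normalizer.normalize("这块黄金重达324.75克"))
# expected: ['这块黄金重达三百二十四点七五克']
```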

View File

@ -12,4 +12,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from parakeet.frontend.cn_normalization.normalization import *
from parakeet.frontend.cn_normalization.text_normlization import *

View File

@ -21,12 +21,12 @@ from .chronology import replace_time, replace_date, replace_date2
from .constants import F2H_ASCII_LETTERS, F2H_DIGITS, F2H_SPACE
from .num import RE_NUMBER, RE_FRAC, RE_PERCENTAGE, RE_RANGE, RE_INTEGER, RE_DEFAULT_NUM
from .num import replace_number, replace_frac, replace_percentage, replace_range, replace_default_num
from .phone import RE_MOBILE_PHONE, RE_TELEPHONE, replace_phone
from .phonecode import RE_MOBILE_PHONE, RE_TELEPHONE, replace_phone
from .quantifier import RE_TEMPERATURE
from .quantifier import replace_temperature
class Normalizer():
class TextNormalizer():
def __init__(self):
self.SENTENCE_SPLITOR = re.compile(r'([:,;。?!,;?!][”’]?)')
self._t2s_converter = opencc.OpenCC("t2s.json")

View File

@ -18,7 +18,7 @@ from pypinyin import lazy_pinyin
from pypinyin import Style
class ModifiedTone():
class ToneSandhi():
def __init__(self):
self.must_neural_tone_words = {'麻烦', '麻利', '鸳鸯', '高粱', '骨头', '骆驼', '马虎', '首饰', '馒头', '馄饨', '风筝', '难为', '队伍',
'阔气', '闺女', '门道', '锄头', '铺盖', '铃铛', '铁匠', '钥匙', '里脊', '里头', '部分', '那么', '道士',
@ -51,67 +51,71 @@ class ModifiedTone():
'云彩', '事情', '买卖', '主意', '丫头', '丧气', '两口', '东西', '东家', '世故', '不由', '不在', '下水',
'下巴', '上头', '上司', '丈夫', '丈人', '一辈', '那个'}
def _neural_tone(self, word, pos, sub_finals):
# the meaning of jieba pos tag: https://blog.csdn.net/weixin_44174352/article/details/113731041
# e.g.
# word: "家里"
# pos: "s"
# finals: ['ia1', 'i3']
def _neural_sandhi(self, word, pos, finals):
ge_idx = word.find("")
if len(word) == 1 and word in "吧呢啊嘛" and pos == 'y':
sub_finals[-1] = sub_finals[-1][:-1] + "5"
finals[-1] = finals[-1][:-1] + "5"
elif len(word) == 1 and word in "的地得" and pos in {"ud", "uj", "uv"}:
sub_finals[-1] = sub_finals[-1][:-1] + "5"
# eg: 走了, 看着, 去过
finals[-1] = finals[-1][:-1] + "5"
# e.g. 走了, 看着, 去过
elif len(word) == 1 and word in "了着过" and pos in {"ul", "uz", "ug"}:
sub_finals[-1] = sub_finals[-1][:-1] + "5"
finals[-1] = finals[-1][:-1] + "5"
elif len(word) > 1 and word[-1] in "们子" and pos in {"r", "n"}:
sub_finals[-1] = sub_finals[-1][:-1] + "5"
# eg: 桌上, 地下, 家里
finals[-1] = finals[-1][:-1] + "5"
# e.g. 桌上, 地下, 家里
elif len(word) > 1 and word[-1] in "上下里" and pos in {"s", "l", "f"}:
sub_finals[-1] = sub_finals[-1][:-1] + "5"
# eg: 上来, 下去
finals[-1] = finals[-1][:-1] + "5"
# e.g. 上来, 下去
elif len(word) > 1 and word[-1] in "来去" and pos[0] in {"v"}:
sub_finals[-1] = sub_finals[-1][:-1] + "5"
finals[-1] = finals[-1][:-1] + "5"
# "个" used as a measure word
elif ge_idx >= 1 and word[ge_idx - 1].isnumeric():
sub_finals[ge_idx] = sub_finals[ge_idx][:-1] + "5"
# reduplication words for n. and v. eg: 奶奶, 试试
finals[ge_idx] = finals[ge_idx][:-1] + "5"
# reduplication words for n. and v. e.g. 奶奶, 试试
elif len(word) >= 2 and word[-1] == word[-2] and pos[0] in {"n", "v"}:
sub_finals[-1] = sub_finals[-1][:-1] + "5"
finals[-1] = finals[-1][:-1] + "5"
# conventional tone5 in Chinese
elif word in self.must_neural_tone_words or word[-2:] in self.must_neural_tone_words:
sub_finals[-1] = sub_finals[-1][:-1] + "5"
finals[-1] = finals[-1][:-1] + "5"
return sub_finals
return finals
def _bu_tone(self, word, sub_finals):
# "不" before tone4 should be bu2, eg: 不怕
if len(word) > 1 and word[0] == "" and sub_finals[1][-1] == "4":
sub_finals[0] = sub_finals[0][:-1] + "2"
# eg: 看不懂
def _bu_sandhi(self, word, finals):
# "不" before tone4 should be bu2, e.g. 不怕
if len(word) > 1 and word[0] == "" and finals[1][-1] == "4":
finals[0] = finals[0][:-1] + "2"
# e.g. 看不懂
elif len(word) == 3 and word[1] == "":
sub_finals[1] = sub_finals[1][:-1] + "5"
finals[1] = finals[1][:-1] + "5"
return sub_finals
return finals
def _yi_tone(self, word, sub_finals):
# "一" in number sequences, eg: 一零零
def _yi_sandhi(self, word, finals):
# "一" in number sequences, e.g. 一零零
if len(word) > 1 and word[0] == "" and all([item.isnumeric() for item in word]):
return sub_finals
# "一" before tone4 should be yi2, eg: 一段
elif len(word) > 1 and word[0] == "" and sub_finals[1][-1] == "4":
sub_finals[0] = sub_finals[0][:-1] + "2"
# "一" before non-tone4 should be yi4, eg: 一天
elif len(word) > 1 and word[0] == "" and sub_finals[1][-1]!= "4":
sub_finals[0] = sub_finals[0][:-1] + "4"
# "一" beturn reduplication words shold be yi5, eg: 看一看
return finals
# "一" before tone4 should be yi2, e.g. 一段
elif len(word) > 1 and word[0] == "" and finals[1][-1] == "4":
finals[0] = finals[0][:-1] + "2"
# "一" before non-tone4 should be yi4, e.g. 一天
elif len(word) > 1 and word[0] == "" and finals[1][-1]!= "4":
finals[0] = finals[0][:-1] + "4"
# "一" between reduplication words shold be yi5, e.g. 看一看
elif len(word) == 3 and word[1] == "" and word[0] == word[-1]:
sub_finals[1] = sub_finals[1][:-1] + "5"
# when "一" is oedinal word, it should be yi1
finals[1] = finals[1][:-1] + "5"
# when "一" is ordinal word, it should be yi1
elif word.startswith("第一"):
sub_finals[1] = sub_finals[1][:-1] + "1"
return sub_finals
finals[1] = finals[1][:-1] + "1"
return finals
# sentences like 我给你讲个故事 are not handled yet
def _three_tone(self, word, sub_finals):
if len(word) == 2 and self._all_tone_three(sub_finals):
sub_finals[0] = sub_finals[0][:-1] + "2"
def _three_sandhi(self, word, finals):
if len(word) == 2 and self._all_tone_three(finals):
finals[0] = finals[0][:-1] + "2"
elif len(word) == 3:
word_list = jieba.cut_for_search(word)
word_list = sorted(word_list, key=lambda i: len(i), reverse=False)
@ -125,42 +129,43 @@ class ModifiedTone():
second_subword = word[:-len(first_subword)]
new_word_list = [second_subword, first_subword]
if self._all_tone_three(sub_finals):
# disyllabic + monosyllabic, eg: 蒙古/包
if self._all_tone_three(finals):
# disyllabic + monosyllabic, e.g. 蒙古/包
if len(new_word_list[0]) == 2:
sub_finals[0] = sub_finals[0][:-1] + "2"
sub_finals[1] = sub_finals[1][:-1] + "2"
# monosyllabic + disyllabic, eg: 纸/老虎
finals[0] = finals[0][:-1] + "2"
finals[1] = finals[1][:-1] + "2"
# monosyllabic + disyllabic, e.g. 纸/老虎
elif len(new_word_list[0]) == 1:
sub_finals[1] = sub_finals[1][:-1] + "2"
finals[1] = finals[1][:-1] + "2"
else:
sub_finals_list = [sub_finals[:len(new_word_list[0])], sub_finals[len(new_word_list[0]):]]
if len(sub_finals_list) == 2:
for i, sub in enumerate(sub_finals_list):
# eg: 所有/人
finals_list = [finals[:len(new_word_list[0])], finals[len(new_word_list[0]):]]
if len(finals_list) == 2:
for i, sub in enumerate(finals_list):
# e.g. 所有/人
if self._all_tone_three(sub) and len(sub) == 2:
sub_finals_list[i][0] = sub_finals_list[i][0][:-1] + "2"
# eg: 好/喜欢
elif i == 1 and not self._all_tone_three(sub) and sub_finals_list[i][0][-1] == "3" and \
sub_finals_list[0][-1][-1] == "3":
finals_list[i][0] = finals_list[i][0][:-1] + "2"
# e.g. 好/喜欢
elif i == 1 and not self._all_tone_three(sub) and finals_list[i][0][-1] == "3" and \
finals_list[0][-1][-1] == "3":
sub_finals_list[0][-1] = sub_finals_list[0][-1][:-1] + "2"
sub_finals = sum(sub_finals_list, [])
finals_list[0][-1] = finals_list[0][-1][:-1] + "2"
finals = sum(finals_list, [])
# split idiom into two words whose length is 2
elif len(word) == 4:
sub_finals_list = [sub_finals[:2], sub_finals[2:]]
sub_finals = []
for sub in sub_finals_list:
finals_list = [finals[:2], finals[2:]]
finals = []
for sub in finals_list:
if self._all_tone_three(sub):
sub[0] = sub[0][:-1] + "2"
sub_finals += sub
finals += sub
return sub_finals
return finals
def _all_tone_three(self, finals):
return all(x[-1] == "3" for x in finals)
# merge "不" and the word behind it
# if don't merge, "不" sometimes appears alone according to jieba, which may occur sandhi error
def _merge_bu(self, seg):
new_seg = []
last_word = ""
@ -176,8 +181,12 @@ class ModifiedTone():
seg = new_seg
return seg
# function 1: merge "一" and reduplication words in it's left and right,eg: "看","一","看" ->"看一看"
# function 2: merge single "一" and the word behind it
# function 1: merge "一" and reduplication words in it's left and right, e.g. "听","一","听" ->"听一听"
# function 2: merge single "一" and the word behind it
# if don't merge, "一" sometimes appears alone according to jieba, which may occur sandhi error
# e.g.
# input seg: [('听', 'v'), ('一', 'm'), ('听', 'v')]
# output seg: [['听一听', 'v']]
def _merge_yi(self, seg):
new_seg = []
# function 1
@ -198,7 +207,6 @@ class ModifiedTone():
new_seg[-1][0] = new_seg[-1][0] + word
else:
new_seg.append([word, pos])
seg = new_seg
return seg
@ -228,8 +236,8 @@ class ModifiedTone():
return seg
def modified_tone(self, word, pos, finals):
finals = self._bu_tone(word, finals)
finals = self._yi_tone(word, finals)
finals = self._neural_tone(word, pos, finals)
finals = self._three_tone(word, finals)
finals = self._bu_sandhi(word, finals)
finals = self._yi_sandhi(word, finals)
finals = self._neural_sandhi(word, pos, finals)
finals = self._three_sandhi(word, finals)
return finals
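A minimal usage sketch of the renamed entry point; the word/pos pair would normally come from jieba's pos tagging, and the expected output follows the two-syllable third-tone rule above:
```python
from parakeet.frontend.tone_sandhi import ToneSandhi

sandhi = ToneSandhi()
# 蒙古: both syllables are tone 3, so the first should shift to tone 2
finals = sandhi.modified_tone("蒙古", "ns", ["eng3", "u3"])
print(finals)  # expected: ['eng2', 'u3']
```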

View File

@ -205,6 +205,7 @@ class FastSpeech2(nn.Layer):
attention_heads=aheads,
linear_units=dunits,
num_blocks=dlayers,
# the decoder needs no input layer before pos_enc_class (the encoder uses an embedding layer here)
input_layer=None,
dropout_rate=transformer_dec_dropout_rate,
positional_dropout_rate=transformer_dec_positional_dropout_rate,
@ -286,7 +287,7 @@ class FastSpeech2(nn.Layer):
speech_lengths, modified if reduction_factor > 1
"""
xs = paddle.to_tensor(text)
xs = text
ilens = text_lengths
ys, ds, ps, es = speech, durations, pitch, energy
olens = speech_lengths
@ -354,7 +355,7 @@ class FastSpeech2(nn.Layer):
# (B, Lmax, adim)
hs = self.length_regulator(hs, ds)
# forward decoder
# forward decoder
if olens is not None and not is_inference:
if self.reduction_factor > 1:
olens_in = paddle.to_tensor(
@ -415,7 +416,6 @@ class FastSpeech2(nn.Layer):
"""
x, y = text, speech
d, p, e = durations, pitch, energy
x = paddle.to_tensor(text)
# setup batch axis
ilens = paddle.to_tensor(