fix merge_sentences bug

This commit is contained in:
TianYuan 2021-08-16 11:58:36 +00:00
parent 48c65f4ab5
commit 9ca5ce0128
3 changed files with 44 additions and 20 deletions

View File

@ -48,9 +48,7 @@ class Frontend():
tone_ids = [self.vocab_tones[item] for item in tones]
return np.array(tone_ids, np.int64)
def get_input_ids(self, sentence, get_tone_ids=False):
phonemes = self.frontend.get_phonemes(sentence)
result = {}
def _get_phone_tone(self, phonemes, get_tone_ids=False):
phones = []
tones = []
if get_tone_ids and self.vocab_tones:
@ -76,12 +74,7 @@ class Frontend():
else:
phones.append(full_phone)
tones.append('0')
tone_ids = self._t2id(tones)
tone_ids = paddle.to_tensor(tone_ids)
result["tone_ids"] = tone_ids
else:
phones = []
for phone in phonemes:
# if the merged erhua not in the vocab
# assume that the input is ['iaor3'] and 'iaor' not in self.vocab_phones, change ['iaor3'] to ['iao3','er2']
@ -92,8 +85,30 @@ class Frontend():
phones.append("er2")
else:
phones.append(phone)
return phones, tones
def get_input_ids(self, sentence, merge_sentences=True,
get_tone_ids=False):
phonemes = self.frontend.get_phonemes(
sentence, merge_sentences=merge_sentences)
result = {}
phones = []
tones = []
temp_phone_ids = []
temp_tone_ids = []
for part_phonemes in phonemes:
phones, tones = self._get_phone_tone(
part_phonemes, get_tone_ids=get_tone_ids)
if tones:
tone_ids = self._t2id(tones)
tone_ids = paddle.to_tensor(tone_ids)
temp_tone_ids.append(tone_ids)
if phones:
phone_ids = self._p2id(phones)
phone_ids = paddle.to_tensor(phone_ids)
result["phone_ids"] = phone_ids
temp_phone_ids.append(phone_ids)
if temp_tone_ids:
result["tone_ids"] = temp_tone_ids
if temp_phone_ids:
result["phone_ids"] = temp_phone_ids
return result

View File

@ -72,19 +72,25 @@ def evaluate(args, fastspeech2_config, pwg_config):
std = paddle.to_tensor(std)
pwg_normalizer = ZScore(mu, std)
fastspeech2_inferencce = FastSpeech2Inference(fastspeech2_normalizer,
model)
fastspeech2_inference = FastSpeech2Inference(fastspeech2_normalizer, model)
pwg_inference = PWGInference(pwg_normalizer, vocoder)
output_dir = Path(args.output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
for utt_id, sentence in sentences:
input_ids = frontend.get_input_ids(sentence)
input_ids = frontend.get_input_ids(sentence, merge_sentences=True)
phone_ids = input_ids["phone_ids"]
flags = 0
for part_phone_ids in phone_ids:
with paddle.no_grad():
mel = fastspeech2_inferencce(phone_ids)
wav = pwg_inference(mel)
mel = fastspeech2_inference(part_phone_ids)
temp_wav = pwg_inference(mel)
if flags == 0:
wav = temp_wav
flags = 1
else:
wav = paddle.concat([wav, temp_wav])
sf.write(
str(output_dir / (utt_id + ".wav")),
wav.numpy(),

View File

@ -116,7 +116,9 @@ class Frontend():
phones.append('sp')
phones_list.append(phones)
if merge_sentences:
phones_list = sum(phones_list, [])
merge_list = sum(phones_list, [])
phones_list = []
phones_list.append(merge_list)
return phones_list
def _merge_erhua(self, initials, finals, word, pos):
@ -136,7 +138,8 @@ class Frontend():
new_initials.append(initials[i])
return new_initials, new_finals
def get_phonemes(self, sentence, with_erhua=True):
def get_phonemes(self, sentence, merge_sentences=True, with_erhua=True):
sentences = self.text_normalizer.normalize(sentence)
phonemes = self._g2p(sentences, with_erhua=with_erhua)
phonemes = self._g2p(
sentences, merge_sentences=merge_sentences, with_erhua=with_erhua)
return phonemes