modified the name of vocoder

This commit is contained in:
lifuchen 2020-02-18 03:49:36 +00:00
parent 147f7be251
commit cb3cfd621b
2 changed files with 9 additions and 9 deletions

View File

@ -42,13 +42,13 @@ def synthesis(text_input, args):
with dg.guard(place):
with fluid.unique_name.guard():
model = TransformerTTS(cfg)
model.set_dict(load_checkpoint(str(args.transformer_step), os.path.join(args.checkpoint_path, "nostop_token/transformer")))
model.set_dict(load_checkpoint(str(args.transformer_step), os.path.join(args.checkpoint_path, "transformer")))
model.eval()
with fluid.unique_name.guard():
model_postnet = Vocoder(cfg, args.batch_size)
model_postnet.set_dict(load_checkpoint(str(args.postnet_step), os.path.join(args.checkpoint_path, "postnet")))
model_postnet.eval()
model_vocoder = Vocoder(cfg, args.batch_size)
model_vocoder.set_dict(load_checkpoint(str(args.vocoder_step), os.path.join(args.checkpoint_path, "vocoder")))
model_vocoder.eval()
# init input
text = np.asarray(text_to_sequence(text_input))
text = fluid.layers.unsqueeze(dg.to_variable(text),[0])
@ -64,7 +64,7 @@ def synthesis(text_input, args):
pos_mel = fluid.layers.unsqueeze(dg.to_variable(pos_mel),[0])
mel_pred, postnet_pred, attn_probs, stop_preds, attn_enc, attn_dec = model(text, mel_input, pos_text, pos_mel)
mel_input = fluid.layers.concat([mel_input, postnet_pred[:,-1:,:]], axis=1)
mag_pred = model_postnet(postnet_pred)
mag_pred = model_vocoder(postnet_pred)
_ljspeech_processor = audio.AudioProcessor(
sample_rate=cfg['audio']['sr'],

View File

@ -38,7 +38,7 @@ def main(args):
if not os.path.exists(args.log_dir):
os.mkdir(args.log_dir)
path = os.path.join(args.log_dir,'postnet')
path = os.path.join(args.log_dir,'vocoder')
writer = SummaryWriter(path) if local_rank == 0 else None
@ -51,7 +51,7 @@ def main(args):
if args.checkpoint_path is not None:
model_dict, opti_dict = load_checkpoint(str(args.vocoder_step), os.path.join(args.checkpoint_path, "postnet"))
model_dict, opti_dict = load_checkpoint(str(args.vocoder_step), os.path.join(args.checkpoint_path, "vocoder"))
model.set_dict(model_dict)
optimizer.set_dict(opti_dict)
global_step = args.vocoder_step
@ -92,7 +92,7 @@ def main(args):
if global_step % args.save_step == 0:
if not os.path.exists(args.save_path):
os.mkdir(args.save_path)
save_path = os.path.join(args.save_path,'postnet/%d' % global_step)
save_path = os.path.join(args.save_path,'vocoder/%d' % global_step)
dg.save_dygraph(model.state_dict(), save_path)
dg.save_dygraph(optimizer.state_dict(), save_path)
@ -100,7 +100,7 @@ def main(args):
writer.close()
if __name__ == '__main__':
parser = argparse.ArgumentParser(description="Train postnet model")
parser = argparse.ArgumentParser(description="Train vocoder model")
add_config_options_to_parser(parser)
args = parser.parse_args()
# Print the whole config setting.