diff --git a/examples/transformer_tts/synthesis.py b/examples/transformer_tts/synthesis.py index dc80dc7..fb1bd2f 100644 --- a/examples/transformer_tts/synthesis.py +++ b/examples/transformer_tts/synthesis.py @@ -42,13 +42,13 @@ def synthesis(text_input, args): with dg.guard(place): with fluid.unique_name.guard(): model = TransformerTTS(cfg) - model.set_dict(load_checkpoint(str(args.transformer_step), os.path.join(args.checkpoint_path, "nostop_token/transformer"))) + model.set_dict(load_checkpoint(str(args.transformer_step), os.path.join(args.checkpoint_path, "transformer"))) model.eval() with fluid.unique_name.guard(): - model_postnet = Vocoder(cfg, args.batch_size) - model_postnet.set_dict(load_checkpoint(str(args.postnet_step), os.path.join(args.checkpoint_path, "postnet"))) - model_postnet.eval() + model_vocoder = Vocoder(cfg, args.batch_size) + model_vocoder.set_dict(load_checkpoint(str(args.vocoder_step), os.path.join(args.checkpoint_path, "vocoder"))) + model_vocoder.eval() # init input text = np.asarray(text_to_sequence(text_input)) text = fluid.layers.unsqueeze(dg.to_variable(text),[0]) @@ -64,7 +64,7 @@ def synthesis(text_input, args): pos_mel = fluid.layers.unsqueeze(dg.to_variable(pos_mel),[0]) mel_pred, postnet_pred, attn_probs, stop_preds, attn_enc, attn_dec = model(text, mel_input, pos_text, pos_mel) mel_input = fluid.layers.concat([mel_input, postnet_pred[:,-1:,:]], axis=1) - mag_pred = model_postnet(postnet_pred) + mag_pred = model_vocoder(postnet_pred) _ljspeech_processor = audio.AudioProcessor( sample_rate=cfg['audio']['sr'], diff --git a/examples/transformer_tts/train_vocoder.py b/examples/transformer_tts/train_vocoder.py index cc32ca9..857fdf0 100644 --- a/examples/transformer_tts/train_vocoder.py +++ b/examples/transformer_tts/train_vocoder.py @@ -38,7 +38,7 @@ def main(args): if not os.path.exists(args.log_dir): os.mkdir(args.log_dir) - path = os.path.join(args.log_dir,'postnet') + path = os.path.join(args.log_dir,'vocoder') writer = SummaryWriter(path) if local_rank == 0 else None @@ -51,7 +51,7 @@ def main(args): if args.checkpoint_path is not None: - model_dict, opti_dict = load_checkpoint(str(args.vocoder_step), os.path.join(args.checkpoint_path, "postnet")) + model_dict, opti_dict = load_checkpoint(str(args.vocoder_step), os.path.join(args.checkpoint_path, "vocoder")) model.set_dict(model_dict) optimizer.set_dict(opti_dict) global_step = args.vocoder_step @@ -92,7 +92,7 @@ def main(args): if global_step % args.save_step == 0: if not os.path.exists(args.save_path): os.mkdir(args.save_path) - save_path = os.path.join(args.save_path,'postnet/%d' % global_step) + save_path = os.path.join(args.save_path,'vocoder/%d' % global_step) dg.save_dygraph(model.state_dict(), save_path) dg.save_dygraph(optimizer.state_dict(), save_path) @@ -100,7 +100,7 @@ def main(args): writer.close() if __name__ == '__main__': - parser = argparse.ArgumentParser(description="Train postnet model") + parser = argparse.ArgumentParser(description="Train vocoder model") add_config_options_to_parser(parser) args = parser.parse_args() # Print the whole config setting.