modified the name of vocoder
This commit is contained in:
parent
147f7be251
commit
cb3cfd621b
|
@ -42,13 +42,13 @@ def synthesis(text_input, args):
|
|||
with dg.guard(place):
|
||||
with fluid.unique_name.guard():
|
||||
model = TransformerTTS(cfg)
|
||||
model.set_dict(load_checkpoint(str(args.transformer_step), os.path.join(args.checkpoint_path, "nostop_token/transformer")))
|
||||
model.set_dict(load_checkpoint(str(args.transformer_step), os.path.join(args.checkpoint_path, "transformer")))
|
||||
model.eval()
|
||||
|
||||
with fluid.unique_name.guard():
|
||||
model_postnet = Vocoder(cfg, args.batch_size)
|
||||
model_postnet.set_dict(load_checkpoint(str(args.postnet_step), os.path.join(args.checkpoint_path, "postnet")))
|
||||
model_postnet.eval()
|
||||
model_vocoder = Vocoder(cfg, args.batch_size)
|
||||
model_vocoder.set_dict(load_checkpoint(str(args.vocoder_step), os.path.join(args.checkpoint_path, "vocoder")))
|
||||
model_vocoder.eval()
|
||||
# init input
|
||||
text = np.asarray(text_to_sequence(text_input))
|
||||
text = fluid.layers.unsqueeze(dg.to_variable(text),[0])
|
||||
|
@ -64,7 +64,7 @@ def synthesis(text_input, args):
|
|||
pos_mel = fluid.layers.unsqueeze(dg.to_variable(pos_mel),[0])
|
||||
mel_pred, postnet_pred, attn_probs, stop_preds, attn_enc, attn_dec = model(text, mel_input, pos_text, pos_mel)
|
||||
mel_input = fluid.layers.concat([mel_input, postnet_pred[:,-1:,:]], axis=1)
|
||||
mag_pred = model_postnet(postnet_pred)
|
||||
mag_pred = model_vocoder(postnet_pred)
|
||||
|
||||
_ljspeech_processor = audio.AudioProcessor(
|
||||
sample_rate=cfg['audio']['sr'],
|
||||
|
|
|
@ -38,7 +38,7 @@ def main(args):
|
|||
|
||||
if not os.path.exists(args.log_dir):
|
||||
os.mkdir(args.log_dir)
|
||||
path = os.path.join(args.log_dir,'postnet')
|
||||
path = os.path.join(args.log_dir,'vocoder')
|
||||
|
||||
writer = SummaryWriter(path) if local_rank == 0 else None
|
||||
|
||||
|
@ -51,7 +51,7 @@ def main(args):
|
|||
|
||||
|
||||
if args.checkpoint_path is not None:
|
||||
model_dict, opti_dict = load_checkpoint(str(args.vocoder_step), os.path.join(args.checkpoint_path, "postnet"))
|
||||
model_dict, opti_dict = load_checkpoint(str(args.vocoder_step), os.path.join(args.checkpoint_path, "vocoder"))
|
||||
model.set_dict(model_dict)
|
||||
optimizer.set_dict(opti_dict)
|
||||
global_step = args.vocoder_step
|
||||
|
@ -92,7 +92,7 @@ def main(args):
|
|||
if global_step % args.save_step == 0:
|
||||
if not os.path.exists(args.save_path):
|
||||
os.mkdir(args.save_path)
|
||||
save_path = os.path.join(args.save_path,'postnet/%d' % global_step)
|
||||
save_path = os.path.join(args.save_path,'vocoder/%d' % global_step)
|
||||
dg.save_dygraph(model.state_dict(), save_path)
|
||||
dg.save_dygraph(optimizer.state_dict(), save_path)
|
||||
|
||||
|
@ -100,7 +100,7 @@ def main(args):
|
|||
writer.close()
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser(description="Train postnet model")
|
||||
parser = argparse.ArgumentParser(description="Train vocoder model")
|
||||
add_config_options_to_parser(parser)
|
||||
args = parser.parse_args()
|
||||
# Print the whole config setting.
|
||||
|
|
Loading…
Reference in New Issue