Merge pull request #35 from ShenYuhan/fix_bug

fix bugs of vdl
This commit is contained in:
Li Fuchen 2020-08-25 17:41:39 +08:00 committed by GitHub
commit 1db01ccc90
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 10 additions and 10 deletions

View File

@ -114,9 +114,9 @@ def train(args, config):
loss.numpy()[0], loss.numpy()[0],
causal_mel_loss.numpy()[0], causal_mel_loss.numpy()[0],
non_causal_mel_loss.numpy()[0])) non_causal_mel_loss.numpy()[0]))
writer.add_scalar("loss/causal_mel_loss", causal_mel_loss.numpy()[0], global_step=global_step) writer.add_scalar("loss/causal_mel_loss", causal_mel_loss.numpy()[0], step=global_step)
writer.add_scalar("loss/non_causal_mel_loss", non_causal_mel_loss.numpy()[0], global_step=global_step) writer.add_scalar("loss/non_causal_mel_loss", non_causal_mel_loss.numpy()[0], step=global_step)
writer.add_scalar("loss/loss", loss.numpy()[0], global_step=global_step) writer.add_scalar("loss/loss", loss.numpy()[0], step=global_step)
if global_step % config["report_interval"] == 0: if global_step % config["report_interval"] == 0:
text_length = int(text_lengths.numpy()[0]) text_length = int(text_lengths.numpy()[0])
@ -124,37 +124,37 @@ def train(args, config):
tag = "train_mel/ground-truth" tag = "train_mel/ground-truth"
img = cm.viridis(normalize(mels.numpy()[0, :num_frame].T)) img = cm.viridis(normalize(mels.numpy()[0, :num_frame].T))
writer.add_image(tag, img, global_step=global_step, dataformats="HWC") writer.add_image(tag, img, step=global_step)
tag = "train_mel/decoded" tag = "train_mel/decoded"
img = cm.viridis(normalize(decoded.numpy()[0, :num_frame].T)) img = cm.viridis(normalize(decoded.numpy()[0, :num_frame].T))
writer.add_image(tag, img, global_step=global_step, dataformats="HWC") writer.add_image(tag, img, step=global_step)
tag = "train_mel/refined" tag = "train_mel/refined"
img = cm.viridis(normalize(refined.numpy()[0, :num_frame].T)) img = cm.viridis(normalize(refined.numpy()[0, :num_frame].T))
writer.add_image(tag, img, global_step=global_step, dataformats="HWC") writer.add_image(tag, img, step=global_step)
vocoder = WaveflowVocoder() vocoder = WaveflowVocoder()
vocoder.model.eval() vocoder.model.eval()
tag = "train_audio/ground-truth-waveflow" tag = "train_audio/ground-truth-waveflow"
wav = vocoder(F.transpose(mels[0:1, :num_frame, :], (0, 2, 1))) wav = vocoder(F.transpose(mels[0:1, :num_frame, :], (0, 2, 1)))
writer.add_audio(tag, wav.numpy()[0], global_step=global_step, sample_rate=22050) writer.add_audio(tag, wav.numpy()[0], step=global_step, sample_rate=22050)
tag = "train_audio/decoded-waveflow" tag = "train_audio/decoded-waveflow"
wav = vocoder(F.transpose(decoded[0:1, :num_frame, :], (0, 2, 1))) wav = vocoder(F.transpose(decoded[0:1, :num_frame, :], (0, 2, 1)))
writer.add_audio(tag, wav.numpy()[0], global_step=global_step, sample_rate=22050) writer.add_audio(tag, wav.numpy()[0], step=global_step, sample_rate=22050)
tag = "train_audio/refined-waveflow" tag = "train_audio/refined-waveflow"
wav = vocoder(F.transpose(refined[0:1, :num_frame, :], (0, 2, 1))) wav = vocoder(F.transpose(refined[0:1, :num_frame, :], (0, 2, 1)))
writer.add_audio(tag, wav.numpy()[0], global_step=global_step, sample_rate=22050) writer.add_audio(tag, wav.numpy()[0], step=global_step, sample_rate=22050)
attentions_np = attentions.numpy() attentions_np = attentions.numpy()
attentions_np = attentions_np[:, 0, :num_frame // 4 , :text_length] attentions_np = attentions_np[:, 0, :num_frame // 4 , :text_length]
for i, attention_layer in enumerate(np.rot90(attentions_np, axes=(1,2))): for i, attention_layer in enumerate(np.rot90(attentions_np, axes=(1,2))):
tag = "train_attention/layer_{}".format(i) tag = "train_attention/layer_{}".format(i)
img = cm.viridis(normalize(attention_layer)) img = cm.viridis(normalize(attention_layer))
writer.add_image(tag, img, global_step=global_step, dataformats="HWC") writer.add_image(tag, img, step=global_step, dataformats="HWC")
if global_step % config["save_interval"] == 0: if global_step % config["save_interval"] == 0:
save_parameters(writer.logdir, global_step, model, optim) save_parameters(writer.logdir, global_step, model, optim)