fix bugs
This commit is contained in:
parent
ce8fad5412
commit
bb5f445212
|
@ -114,9 +114,9 @@ def train(args, config):
|
||||||
loss.numpy()[0],
|
loss.numpy()[0],
|
||||||
causal_mel_loss.numpy()[0],
|
causal_mel_loss.numpy()[0],
|
||||||
non_causal_mel_loss.numpy()[0]))
|
non_causal_mel_loss.numpy()[0]))
|
||||||
writer.add_scalar("loss/causal_mel_loss", causal_mel_loss.numpy()[0], global_step=global_step)
|
writer.add_scalar("loss/causal_mel_loss", causal_mel_loss.numpy()[0], step=global_step)
|
||||||
writer.add_scalar("loss/non_causal_mel_loss", non_causal_mel_loss.numpy()[0], global_step=global_step)
|
writer.add_scalar("loss/non_causal_mel_loss", non_causal_mel_loss.numpy()[0], step=global_step)
|
||||||
writer.add_scalar("loss/loss", loss.numpy()[0], global_step=global_step)
|
writer.add_scalar("loss/loss", loss.numpy()[0], step=global_step)
|
||||||
|
|
||||||
if global_step % config["report_interval"] == 0:
|
if global_step % config["report_interval"] == 0:
|
||||||
text_length = int(text_lengths.numpy()[0])
|
text_length = int(text_lengths.numpy()[0])
|
||||||
|
@ -124,37 +124,37 @@ def train(args, config):
|
||||||
|
|
||||||
tag = "train_mel/ground-truth"
|
tag = "train_mel/ground-truth"
|
||||||
img = cm.viridis(normalize(mels.numpy()[0, :num_frame].T))
|
img = cm.viridis(normalize(mels.numpy()[0, :num_frame].T))
|
||||||
writer.add_image(tag, img, global_step=global_step, dataformats="HWC")
|
writer.add_image(tag, img, step=global_step)
|
||||||
|
|
||||||
tag = "train_mel/decoded"
|
tag = "train_mel/decoded"
|
||||||
img = cm.viridis(normalize(decoded.numpy()[0, :num_frame].T))
|
img = cm.viridis(normalize(decoded.numpy()[0, :num_frame].T))
|
||||||
writer.add_image(tag, img, global_step=global_step, dataformats="HWC")
|
writer.add_image(tag, img, step=global_step)
|
||||||
|
|
||||||
tag = "train_mel/refined"
|
tag = "train_mel/refined"
|
||||||
img = cm.viridis(normalize(refined.numpy()[0, :num_frame].T))
|
img = cm.viridis(normalize(refined.numpy()[0, :num_frame].T))
|
||||||
writer.add_image(tag, img, global_step=global_step, dataformats="HWC")
|
writer.add_image(tag, img, step=global_step)
|
||||||
|
|
||||||
vocoder = WaveflowVocoder()
|
vocoder = WaveflowVocoder()
|
||||||
vocoder.model.eval()
|
vocoder.model.eval()
|
||||||
|
|
||||||
tag = "train_audio/ground-truth-waveflow"
|
tag = "train_audio/ground-truth-waveflow"
|
||||||
wav = vocoder(F.transpose(mels[0:1, :num_frame, :], (0, 2, 1)))
|
wav = vocoder(F.transpose(mels[0:1, :num_frame, :], (0, 2, 1)))
|
||||||
writer.add_audio(tag, wav.numpy()[0], global_step=global_step, sample_rate=22050)
|
writer.add_audio(tag, wav.numpy()[0], step=global_step, sample_rate=22050)
|
||||||
|
|
||||||
tag = "train_audio/decoded-waveflow"
|
tag = "train_audio/decoded-waveflow"
|
||||||
wav = vocoder(F.transpose(decoded[0:1, :num_frame, :], (0, 2, 1)))
|
wav = vocoder(F.transpose(decoded[0:1, :num_frame, :], (0, 2, 1)))
|
||||||
writer.add_audio(tag, wav.numpy()[0], global_step=global_step, sample_rate=22050)
|
writer.add_audio(tag, wav.numpy()[0], step=global_step, sample_rate=22050)
|
||||||
|
|
||||||
tag = "train_audio/refined-waveflow"
|
tag = "train_audio/refined-waveflow"
|
||||||
wav = vocoder(F.transpose(refined[0:1, :num_frame, :], (0, 2, 1)))
|
wav = vocoder(F.transpose(refined[0:1, :num_frame, :], (0, 2, 1)))
|
||||||
writer.add_audio(tag, wav.numpy()[0], global_step=global_step, sample_rate=22050)
|
writer.add_audio(tag, wav.numpy()[0], step=global_step, sample_rate=22050)
|
||||||
|
|
||||||
attentions_np = attentions.numpy()
|
attentions_np = attentions.numpy()
|
||||||
attentions_np = attentions_np[:, 0, :num_frame // 4 , :text_length]
|
attentions_np = attentions_np[:, 0, :num_frame // 4 , :text_length]
|
||||||
for i, attention_layer in enumerate(np.rot90(attentions_np, axes=(1,2))):
|
for i, attention_layer in enumerate(np.rot90(attentions_np, axes=(1,2))):
|
||||||
tag = "train_attention/layer_{}".format(i)
|
tag = "train_attention/layer_{}".format(i)
|
||||||
img = cm.viridis(normalize(attention_layer))
|
img = cm.viridis(normalize(attention_layer))
|
||||||
writer.add_image(tag, img, global_step=global_step, dataformats="HWC")
|
writer.add_image(tag, img, step=global_step, dataformats="HWC")
|
||||||
|
|
||||||
if global_step % config["save_interval"] == 0:
|
if global_step % config["save_interval"] == 0:
|
||||||
save_parameters(writer.logdir, global_step, model, optim)
|
save_parameters(writer.logdir, global_step, model, optim)
|
||||||
|
|
Loading…
Reference in New Issue