Merge branch 'develop' into baker
This commit is contained in:
commit
15b205d6e0
|
@ -86,15 +86,18 @@ class Experiment(ExperimentBase):
|
|||
valid_losses[k].append(float(v))
|
||||
|
||||
attention_weights = outputs["alignments"]
|
||||
display.add_attention_plots(self.visualizer,
|
||||
f"valid_sentence_{i}_alignments",
|
||||
attention_weights[0], self.iteration)
|
||||
display.add_spectrogram_plots(
|
||||
self.visualizer, f"valid_sentence_{i}_target_spectrogram",
|
||||
mels[0], self.iteration)
|
||||
display.add_spectrogram_plots(
|
||||
self.visualizer, f"valid_sentence_{i}_predicted_spectrogram",
|
||||
outputs['mel_outputs_postnet'][0], self.iteration)
|
||||
self.visualizer.add_figure(
|
||||
f"valid_sentence_{i}_alignments",
|
||||
display.plot_alignment(attention_weights[0].numpy()),
|
||||
self.iteration)
|
||||
self.visualizer.add_figure(
|
||||
f"valid_sentence_{i}_target_spectrogram",
|
||||
display.plot_spectrogram(mels[0].numpy().T),
|
||||
self.iteration)
|
||||
self.visualizer.add_figure(
|
||||
f"valid_sentence_{i}_predicted_spectrogram",
|
||||
display.plot_spectrogram(outputs['mel_outputs_postnet'][0].numpy().T),
|
||||
self.iteration)
|
||||
|
||||
# write visual log
|
||||
valid_losses = {k: np.mean(v) for k, v in valid_losses.items()}
|
||||
|
|
|
@ -19,7 +19,7 @@ import numpy as np
|
|||
import paddle
|
||||
from paddle import distributed as dist
|
||||
from paddle.io import DataLoader, DistributedBatchSampler
|
||||
from tensorboardX import SummaryWriter
|
||||
from visualdl import LogWriter
|
||||
from collections import defaultdict
|
||||
import time
|
||||
|
||||
|
|
|
@ -12,6 +12,7 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import numpy as np
|
||||
import matplotlib
|
||||
import librosa
|
||||
|
@ -21,44 +22,11 @@ import matplotlib.pylab as plt
|
|||
from matplotlib import cm, pyplot
|
||||
|
||||
__all__ = [
|
||||
"pack_attention_images",
|
||||
"add_attention_plots",
|
||||
"plot_alignment",
|
||||
"min_max_normalize",
|
||||
"add_spectrogram_plots",
|
||||
"plot_spectrogram",
|
||||
]
|
||||
|
||||
|
||||
def pack_attention_images(attention_weights, rotate=False):
|
||||
# add a box
|
||||
attention_weights = np.pad(attention_weights, [(0, 0), (1, 1), (1, 1)],
|
||||
mode="constant",
|
||||
constant_values=1.)
|
||||
if rotate:
|
||||
attention_weights = np.rot90(attention_weights, axes=(1, 2))
|
||||
n, h, w = attention_weights.shape
|
||||
|
||||
ratio = h / w
|
||||
if ratio < 1:
|
||||
cols = max(int(np.sqrt(n / ratio)), 1)
|
||||
rows = int(np.ceil(n / cols))
|
||||
else:
|
||||
rows = max(int(np.sqrt(n / ratio)), 1)
|
||||
cols = int(np.ceil(n / rows))
|
||||
extras = rows * cols - n
|
||||
#print(rows, cols, extras)
|
||||
total = np.append(attention_weights, np.zeros([extras, h, w]), axis=0)
|
||||
total = np.reshape(total, [rows, cols, h, w])
|
||||
img = np.block([[total[i, j] for j in range(cols)] for i in range(rows)])
|
||||
return img
|
||||
|
||||
|
||||
def save_figure_to_numpy(fig):
|
||||
# save it to a numpy array.
|
||||
data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
|
||||
data = data.reshape(fig.canvas.get_width_height()[::-1] + (3, ))
|
||||
return data
|
||||
|
||||
|
||||
def plot_alignment(alignment, title=None):
|
||||
fig, ax = plt.subplots(figsize=(6, 4))
|
||||
|
@ -73,42 +41,13 @@ def plot_alignment(alignment, title=None):
|
|||
plt.tight_layout()
|
||||
return fig
|
||||
|
||||
fig.canvas.draw()
|
||||
data = save_figure_to_numpy(fig)
|
||||
plt.close()
|
||||
return data
|
||||
|
||||
|
||||
def add_attention_plots(writer, tag, attention_weights, global_step):
|
||||
img = plot_alignment(attention_weights.numpy().T)
|
||||
writer.add_image(tag, img, global_step, dataformats="HWC")
|
||||
|
||||
|
||||
def add_multi_attention_plots(writer, tag, attention_weights, global_step):
|
||||
attns = [attn[0].numpy() for attn in attention_weights]
|
||||
for i, attn in enumerate(attns):
|
||||
img = pack_attention_images(attn)
|
||||
writer.add_image(
|
||||
f"{tag}/{i}",
|
||||
cm.plasma(img),
|
||||
global_step=global_step,
|
||||
dataformats="HWC")
|
||||
|
||||
|
||||
def add_spectrogram_plots(writer, tag, spec, global_step):
|
||||
spec = spec.numpy().T
|
||||
def plot_spectrogram(spec):
|
||||
# spec: [C, T] librosa convention
|
||||
fig, ax = plt.subplots(figsize=(12, 3))
|
||||
im = ax.imshow(spec, aspect="auto", origin="lower", interpolation='none')
|
||||
plt.colorbar(im, ax=ax)
|
||||
plt.xlabel("Frames")
|
||||
plt.ylabel("Channels")
|
||||
plt.tight_layout()
|
||||
|
||||
fig.canvas.draw()
|
||||
data = save_figure_to_numpy(fig)
|
||||
plt.close()
|
||||
writer.add_image(tag, data, global_step, dataformats="HWC")
|
||||
|
||||
|
||||
def min_max_normalize(v):
|
||||
return (v - v.min()) / (v.max() - v.min())
|
||||
return fig
|
||||
|
|
|
@ -9,7 +9,7 @@ import subprocess
|
|||
import platform
|
||||
|
||||
COPYRIGHT = '''
|
||||
Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
|
|
Loading…
Reference in New Issue