Merge branch 'develop' into baker

This commit is contained in:
iclementine 2021-04-02 11:23:21 +08:00
commit 15b205d6e0
4 changed files with 19 additions and 77 deletions

View File

@ -86,15 +86,18 @@ class Experiment(ExperimentBase):
valid_losses[k].append(float(v))
attention_weights = outputs["alignments"]
display.add_attention_plots(self.visualizer,
f"valid_sentence_{i}_alignments",
attention_weights[0], self.iteration)
display.add_spectrogram_plots(
self.visualizer, f"valid_sentence_{i}_target_spectrogram",
mels[0], self.iteration)
display.add_spectrogram_plots(
self.visualizer, f"valid_sentence_{i}_predicted_spectrogram",
outputs['mel_outputs_postnet'][0], self.iteration)
self.visualizer.add_figure(
f"valid_sentence_{i}_alignments",
display.plot_alignment(attention_weights[0].numpy()),
self.iteration)
self.visualizer.add_figure(
f"valid_sentence_{i}_target_spectrogram",
display.plot_spectrogram(mels[0].numpy().T),
self.iteration)
self.visualizer.add_figure(
f"valid_sentence_{i}_predicted_spectrogram",
display.plot_spectrogram(outputs['mel_outputs_postnet'][0].numpy().T),
self.iteration)
# write visual log
valid_losses = {k: np.mean(v) for k, v in valid_losses.items()}

View File

@ -19,7 +19,7 @@ import numpy as np
import paddle
from paddle import distributed as dist
from paddle.io import DataLoader, DistributedBatchSampler
from tensorboardX import SummaryWriter
from visualdl import LogWriter
from collections import defaultdict
import time

View File

@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import matplotlib
import librosa
@ -21,44 +22,11 @@ import matplotlib.pylab as plt
from matplotlib import cm, pyplot
__all__ = [
"pack_attention_images",
"add_attention_plots",
"plot_alignment",
"min_max_normalize",
"add_spectrogram_plots",
"plot_spectrogram",
]
def pack_attention_images(attention_weights, rotate=False):
# add a box
attention_weights = np.pad(attention_weights, [(0, 0), (1, 1), (1, 1)],
mode="constant",
constant_values=1.)
if rotate:
attention_weights = np.rot90(attention_weights, axes=(1, 2))
n, h, w = attention_weights.shape
ratio = h / w
if ratio < 1:
cols = max(int(np.sqrt(n / ratio)), 1)
rows = int(np.ceil(n / cols))
else:
rows = max(int(np.sqrt(n / ratio)), 1)
cols = int(np.ceil(n / rows))
extras = rows * cols - n
#print(rows, cols, extras)
total = np.append(attention_weights, np.zeros([extras, h, w]), axis=0)
total = np.reshape(total, [rows, cols, h, w])
img = np.block([[total[i, j] for j in range(cols)] for i in range(rows)])
return img
def save_figure_to_numpy(fig):
# save it to a numpy array.
data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
data = data.reshape(fig.canvas.get_width_height()[::-1] + (3, ))
return data
def plot_alignment(alignment, title=None):
fig, ax = plt.subplots(figsize=(6, 4))
@ -73,42 +41,13 @@ def plot_alignment(alignment, title=None):
plt.tight_layout()
return fig
fig.canvas.draw()
data = save_figure_to_numpy(fig)
plt.close()
return data
def add_attention_plots(writer, tag, attention_weights, global_step):
img = plot_alignment(attention_weights.numpy().T)
writer.add_image(tag, img, global_step, dataformats="HWC")
def add_multi_attention_plots(writer, tag, attention_weights, global_step):
attns = [attn[0].numpy() for attn in attention_weights]
for i, attn in enumerate(attns):
img = pack_attention_images(attn)
writer.add_image(
f"{tag}/{i}",
cm.plasma(img),
global_step=global_step,
dataformats="HWC")
def add_spectrogram_plots(writer, tag, spec, global_step):
spec = spec.numpy().T
def plot_spectrogram(spec):
# spec: [C, T] librosa convention
fig, ax = plt.subplots(figsize=(12, 3))
im = ax.imshow(spec, aspect="auto", origin="lower", interpolation='none')
plt.colorbar(im, ax=ax)
plt.xlabel("Frames")
plt.ylabel("Channels")
plt.tight_layout()
fig.canvas.draw()
data = save_figure_to_numpy(fig)
plt.close()
writer.add_image(tag, data, global_step, dataformats="HWC")
def min_max_normalize(v):
return (v - v.min()) / (v.max() - v.min())
return fig

View File

@ -9,7 +9,7 @@ import subprocess
import platform
COPYRIGHT = '''
Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.