diff --git a/parakeet/utils/display.py b/parakeet/utils/display.py index bd94789..b552a94 100644 --- a/parakeet/utils/display.py +++ b/parakeet/utils/display.py @@ -19,8 +19,11 @@ import matplotlib.pylab as plt from matplotlib import cm, pyplot __all__ = [ - "pack_attention_images", "add_attention_plots", "plot_alignment", - "min_max_normalize" + "pack_attention_images", + "add_attention_plots", + "plot_alignment", + "min_max_normalize", + "add_spectrogram_plots", ] @@ -48,6 +51,13 @@ def pack_attention_images(attention_weights, rotate=False): return img +def save_figure_to_numpy(fig): + # save it to a numpy array. + data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8) + data = data.reshape(fig.canvas.get_width_height()[::-1] + (3, )) + return data + + def plot_alignment(alignment, title=None): fig, ax = plt.subplots(figsize=(6, 4)) im = ax.imshow( @@ -61,8 +71,7 @@ def plot_alignment(alignment, title=None): plt.tight_layout() fig.canvas.draw() - data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8) - data = data.reshape(fig.canvas.get_width_height()[::-1] + (3, )) + data = save_figure_to_numpy(fig) plt.close() return data @@ -83,5 +92,20 @@ def add_multi_attention_plots(writer, tag, attention_weights, global_step): dataformats="HWC") +def add_spectrogram_plots(writer, tag, spec, global_step): + spec = spec.numpy() + fig, ax = plt.subplots(figsize=(12, 3)) + im = ax.imshow(spec, aspect="auto", origin="lower", interpolation='none') + plt.colorbar(im, ax=ax) + plt.xlabel("Frames") + plt.ylabel("Channels") + plt.tight_layout() + + fig.canvas.draw() + data = save_figure_to_numpy(fig) + plt.close() + writer.add_image(tag, data, global_step, dataformats="HWC") + + def min_max_normalize(v): return (v - v.min()) / (v.max() - v.min())