From 99fdd10b5dfbc6812132a3f821b3712376b43cd8 Mon Sep 17 00:00:00 2001 From: lfchener Date: Fri, 11 Dec 2020 11:55:45 +0000 Subject: [PATCH] add plot alignment function --- parakeet/models/tacotron2.py | 4 +-- parakeet/utils/display.py | 59 +++++++++++++++++++++++++++++++----- 2 files changed, 54 insertions(+), 9 deletions(-) diff --git a/parakeet/models/tacotron2.py b/parakeet/models/tacotron2.py index 912cbab..e2eef3e 100644 --- a/parakeet/models/tacotron2.py +++ b/parakeet/models/tacotron2.py @@ -29,7 +29,7 @@ class DecoderPreNet(nn.Layer): d_input: int, d_hidden: int, d_output: int, - dropout_rate: int=0.2): + dropout_rate: float=0.2): super().__init__() self.dropout_rate = dropout_rate @@ -49,7 +49,7 @@ class DecoderPostNet(nn.Layer): kernel_size: int=5, padding: int=0, num_layers: int=5, - dropout=0.1): + dropout: float=0.1): super().__init__() self.dropout = dropout self.num_layers = num_layers diff --git a/parakeet/utils/display.py b/parakeet/utils/display.py index 1d01f97..bd94789 100644 --- a/parakeet/utils/display.py +++ b/parakeet/utils/display.py @@ -1,14 +1,32 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import numpy as np import matplotlib +matplotlib.use("Agg") +import matplotlib.pylab as plt from matplotlib import cm, pyplot -__all__ = ["pack_attention_images", "add_attention_plots", "min_max_normalize"] +__all__ = [ + "pack_attention_images", "add_attention_plots", "plot_alignment", + "min_max_normalize" +] def pack_attention_images(attention_weights, rotate=False): # add a box - attention_weights = np.pad(attention_weights, - [(0, 0), (1, 1), (1, 1)], + attention_weights = np.pad(attention_weights, [(0, 0), (1, 1), (1, 1)], mode="constant", constant_values=1.) if rotate: @@ -29,14 +47,41 @@ def pack_attention_images(attention_weights, rotate=False): img = np.block([[total[i, j] for j in range(cols)] for i in range(rows)]) return img + +def plot_alignment(alignment, title=None): + fig, ax = plt.subplots(figsize=(6, 4)) + im = ax.imshow( + alignment, aspect='auto', origin='lower', interpolation='none') + fig.colorbar(im, ax=ax) + xlabel = 'Decoder timestep' + if title is not None: + xlabel += '\n\n' + title + plt.xlabel(xlabel) + plt.ylabel('Encoder timestep') + plt.tight_layout() + + fig.canvas.draw() + data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8) + data = data.reshape(fig.canvas.get_width_height()[::-1] + (3, )) + plt.close() + return data + + def add_attention_plots(writer, tag, attention_weights, global_step): + img = plot_alignment(attention_weights.numpy().T) + writer.add_image(tag, img, global_step, dataformats="HWC") + + +def add_multi_attention_plots(writer, tag, attention_weights, global_step): attns = [attn[0].numpy() for attn in attention_weights] for i, attn in enumerate(attns): img = pack_attention_images(attn) - writer.add_image(f"{tag}/{i}", - cm.plasma(img), - global_step=global_step, - dataformats="HWC") + writer.add_image( + f"{tag}/{i}", + cm.plasma(img), + global_step=global_step, + dataformats="HWC") + def min_max_normalize(v): return (v - v.min()) / (v.max() - v.min())