From 99fdd10b5dfbc6812132a3f821b3712376b43cd8 Mon Sep 17 00:00:00 2001
From: lfchener <lfchener@outlook.com>
Date: Fri, 11 Dec 2020 11:55:45 +0000
Subject: [PATCH] add plot alignment function

---
 parakeet/models/tacotron2.py |  4 +--
 parakeet/utils/display.py    | 59 +++++++++++++++++++++++++++++++-----
 2 files changed, 54 insertions(+), 9 deletions(-)

diff --git a/parakeet/models/tacotron2.py b/parakeet/models/tacotron2.py
index 912cbab..e2eef3e 100644
--- a/parakeet/models/tacotron2.py
+++ b/parakeet/models/tacotron2.py
@@ -29,7 +29,7 @@ class DecoderPreNet(nn.Layer):
                  d_input: int,
                  d_hidden: int,
                  d_output: int,
-                 dropout_rate: int=0.2):
+                 dropout_rate: float=0.2):
         super().__init__()
 
         self.dropout_rate = dropout_rate
@@ -49,7 +49,7 @@ class DecoderPostNet(nn.Layer):
                  kernel_size: int=5,
                  padding: int=0,
                  num_layers: int=5,
-                 dropout=0.1):
+                 dropout: float=0.1):
         super().__init__()
         self.dropout = dropout
         self.num_layers = num_layers
diff --git a/parakeet/utils/display.py b/parakeet/utils/display.py
index 1d01f97..bd94789 100644
--- a/parakeet/utils/display.py
+++ b/parakeet/utils/display.py
@@ -1,14 +1,32 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import numpy as np
 import matplotlib
+matplotlib.use("Agg")
+import matplotlib.pylab as plt
 from matplotlib import cm, pyplot
 
-__all__ = ["pack_attention_images", "add_attention_plots", "min_max_normalize"]
+__all__ = [
+    "pack_attention_images", "add_attention_plots", "plot_alignment",
+    "min_max_normalize"
+]
 
 
 def pack_attention_images(attention_weights, rotate=False):
     # add a box
-    attention_weights = np.pad(attention_weights, 
-                               [(0, 0), (1, 1), (1, 1)], 
+    attention_weights = np.pad(attention_weights, [(0, 0), (1, 1), (1, 1)],
                                mode="constant",
                                constant_values=1.)
     if rotate:
@@ -29,14 +47,41 @@ def pack_attention_images(attention_weights, rotate=False):
     img = np.block([[total[i, j] for j in range(cols)] for i in range(rows)])
     return img
 
+
+def plot_alignment(alignment, title=None):
+    fig, ax = plt.subplots(figsize=(6, 4))
+    im = ax.imshow(
+        alignment, aspect='auto', origin='lower', interpolation='none')
+    fig.colorbar(im, ax=ax)
+    xlabel = 'Decoder timestep'
+    if title is not None:
+        xlabel += '\n\n' + title
+    plt.xlabel(xlabel)
+    plt.ylabel('Encoder timestep')
+    plt.tight_layout()
+
+    fig.canvas.draw()
+    data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
+    data = data.reshape(fig.canvas.get_width_height()[::-1] + (3, ))
+    plt.close()
+    return data
+
+
 def add_attention_plots(writer, tag, attention_weights, global_step):
+    img = plot_alignment(attention_weights.numpy().T)
+    writer.add_image(tag, img, global_step, dataformats="HWC")
+
+
+def add_multi_attention_plots(writer, tag, attention_weights, global_step):
     attns = [attn[0].numpy() for attn in attention_weights]
     for i, attn in enumerate(attns):
         img = pack_attention_images(attn)
-        writer.add_image(f"{tag}/{i}", 
-                        cm.plasma(img),
-                        global_step=global_step,
-                        dataformats="HWC")
+        writer.add_image(
+            f"{tag}/{i}",
+            cm.plasma(img),
+            global_step=global_step,
+            dataformats="HWC")
+
 
 def min_max_normalize(v):
     return (v - v.min()) / (v.max() - v.min())