add plot alignment function

lfchener 2020-12-11 11:55:45 +00:00
parent 4de58f4a99
commit 99fdd10b5d
2 changed files with 54 additions and 9 deletions


@@ -29,7 +29,7 @@ class DecoderPreNet(nn.Layer):
                  d_input: int,
                  d_hidden: int,
                  d_output: int,
-                 dropout_rate: int=0.2):
+                 dropout_rate: float=0.2):
         super().__init__()
         self.dropout_rate = dropout_rate
@@ -49,7 +49,7 @@ class DecoderPostNet(nn.Layer):
                  kernel_size: int=5,
                  padding: int=0,
                  num_layers: int=5,
-                 dropout=0.1):
+                 dropout: float=0.1):
         super().__init__()
         self.dropout = dropout
         self.num_layers = num_layers
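
The first file's changes are annotation-only fixes: both dropout defaults are floats, so `dropout_rate: int=0.2` contradicted its own default and `dropout=0.1` was unannotated. A minimal sketch of what a static checker such as mypy reports for the old signature (the function names here are illustrative, not from the commit):

    def prenet(dropout_rate: int = 0.2):
        # mypy: Incompatible default for argument "dropout_rate"
        # (default has type "float", argument has type "int")
        ...

    def prenet_fixed(dropout_rate: float = 0.2):  # annotation matches the default
        ...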


@@ -1,14 +1,32 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 import numpy as np
 import matplotlib
+matplotlib.use("Agg")
+import matplotlib.pylab as plt
 from matplotlib import cm, pyplot
 
-__all__ = ["pack_attention_images", "add_attention_plots", "min_max_normalize"]
+__all__ = [
+    "pack_attention_images", "add_attention_plots", "plot_alignment",
+    "min_max_normalize"
+]
 
 
 def pack_attention_images(attention_weights, rotate=False):
     # add a box
-    attention_weights = np.pad(attention_weights,
-                               [(0, 0), (1, 1), (1, 1)],
+    attention_weights = np.pad(attention_weights, [(0, 0), (1, 1), (1, 1)],
                                mode="constant",
                                constant_values=1.)
     if rotate:
@@ -29,14 +47,41 @@ def pack_attention_images(attention_weights, rotate=False):
     img = np.block([[total[i, j] for j in range(cols)] for i in range(rows)])
     return img
 
 
+def plot_alignment(alignment, title=None):
+    fig, ax = plt.subplots(figsize=(6, 4))
+    im = ax.imshow(
+        alignment, aspect='auto', origin='lower', interpolation='none')
+    fig.colorbar(im, ax=ax)
+    xlabel = 'Decoder timestep'
+    if title is not None:
+        xlabel += '\n\n' + title
+    plt.xlabel(xlabel)
+    plt.ylabel('Encoder timestep')
+    plt.tight_layout()
+    fig.canvas.draw()
+    data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
+    data = data.reshape(fig.canvas.get_width_height()[::-1] + (3, ))
+    plt.close()
+    return data
+
+
 def add_attention_plots(writer, tag, attention_weights, global_step):
+    img = plot_alignment(attention_weights.numpy().T)
+    writer.add_image(tag, img, global_step, dataformats="HWC")
+
+
+def add_multi_attention_plots(writer, tag, attention_weights, global_step):
     attns = [attn[0].numpy() for attn in attention_weights]
     for i, attn in enumerate(attns):
         img = pack_attention_images(attn)
-        writer.add_image(f"{tag}/{i}",
-                         cm.plasma(img),
-                         global_step=global_step,
-                         dataformats="HWC")
+        writer.add_image(
+            f"{tag}/{i}",
+            cm.plasma(img),
+            global_step=global_step,
+            dataformats="HWC")
 
 
 def min_max_normalize(v):
     return (v - v.min()) / (v.max() - v.min())
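
For context, a short usage sketch of the new helpers. The array shapes, the title string, and the step value are illustrative assumptions, and the `writer` expected by `add_attention_plots` is any logger exposing a TensorBoard-style `add_image(tag, img, step, dataformats=...)`:

    import numpy as np

    # hypothetical (decoder_steps, encoder_steps) attention matrix; real
    # callers pass attention_weights.numpy() from a framework tensor
    alignment = np.random.uniform(size=(60, 45)).astype("float32")

    # plot_alignment puts encoder steps on the y-axis and decoder steps on
    # the x-axis, hence the transpose that add_attention_plots applies
    img = plot_alignment(alignment.T, title="step 1000")
    print(img.shape)  # (height, width, 3) uint8 RGB, i.e. dataformats="HWC"

    # for multi-head attention, min_max_normalize rescales raw weights to
    # [0, 1] so a colormap like cm.plasma spans the full range, and
    # pack_attention_images frames each per-head map with a 1-pixel border
    # of 1.0 values and tiles them into one grid image
    heads = np.random.uniform(size=(8, 60, 45)).astype("float32")
    grid = pack_attention_images(min_max_normalize(heads), rotate=True)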