add plot alignment function
This commit is contained in:
parent
4de58f4a99
commit
99fdd10b5d
|
@ -29,7 +29,7 @@ class DecoderPreNet(nn.Layer):
|
||||||
d_input: int,
|
d_input: int,
|
||||||
d_hidden: int,
|
d_hidden: int,
|
||||||
d_output: int,
|
d_output: int,
|
||||||
dropout_rate: int=0.2):
|
dropout_rate: float=0.2):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
self.dropout_rate = dropout_rate
|
self.dropout_rate = dropout_rate
|
||||||
|
@ -49,7 +49,7 @@ class DecoderPostNet(nn.Layer):
|
||||||
kernel_size: int=5,
|
kernel_size: int=5,
|
||||||
padding: int=0,
|
padding: int=0,
|
||||||
num_layers: int=5,
|
num_layers: int=5,
|
||||||
dropout=0.1):
|
dropout: float=0.1):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.dropout = dropout
|
self.dropout = dropout
|
||||||
self.num_layers = num_layers
|
self.num_layers = num_layers
|
||||||
|
|
|
@ -1,14 +1,32 @@
|
||||||
|
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import matplotlib
|
import matplotlib
|
||||||
|
matplotlib.use("Agg")
|
||||||
|
import matplotlib.pylab as plt
|
||||||
from matplotlib import cm, pyplot
|
from matplotlib import cm, pyplot
|
||||||
|
|
||||||
__all__ = ["pack_attention_images", "add_attention_plots", "min_max_normalize"]
|
__all__ = [
|
||||||
|
"pack_attention_images", "add_attention_plots", "plot_alignment",
|
||||||
|
"min_max_normalize"
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
def pack_attention_images(attention_weights, rotate=False):
|
def pack_attention_images(attention_weights, rotate=False):
|
||||||
# add a box
|
# add a box
|
||||||
attention_weights = np.pad(attention_weights,
|
attention_weights = np.pad(attention_weights, [(0, 0), (1, 1), (1, 1)],
|
||||||
[(0, 0), (1, 1), (1, 1)],
|
|
||||||
mode="constant",
|
mode="constant",
|
||||||
constant_values=1.)
|
constant_values=1.)
|
||||||
if rotate:
|
if rotate:
|
||||||
|
@ -29,14 +47,41 @@ def pack_attention_images(attention_weights, rotate=False):
|
||||||
img = np.block([[total[i, j] for j in range(cols)] for i in range(rows)])
|
img = np.block([[total[i, j] for j in range(cols)] for i in range(rows)])
|
||||||
return img
|
return img
|
||||||
|
|
||||||
|
|
||||||
|
def plot_alignment(alignment, title=None):
|
||||||
|
fig, ax = plt.subplots(figsize=(6, 4))
|
||||||
|
im = ax.imshow(
|
||||||
|
alignment, aspect='auto', origin='lower', interpolation='none')
|
||||||
|
fig.colorbar(im, ax=ax)
|
||||||
|
xlabel = 'Decoder timestep'
|
||||||
|
if title is not None:
|
||||||
|
xlabel += '\n\n' + title
|
||||||
|
plt.xlabel(xlabel)
|
||||||
|
plt.ylabel('Encoder timestep')
|
||||||
|
plt.tight_layout()
|
||||||
|
|
||||||
|
fig.canvas.draw()
|
||||||
|
data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
|
||||||
|
data = data.reshape(fig.canvas.get_width_height()[::-1] + (3, ))
|
||||||
|
plt.close()
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
def add_attention_plots(writer, tag, attention_weights, global_step):
|
def add_attention_plots(writer, tag, attention_weights, global_step):
|
||||||
|
img = plot_alignment(attention_weights.numpy().T)
|
||||||
|
writer.add_image(tag, img, global_step, dataformats="HWC")
|
||||||
|
|
||||||
|
|
||||||
|
def add_multi_attention_plots(writer, tag, attention_weights, global_step):
|
||||||
attns = [attn[0].numpy() for attn in attention_weights]
|
attns = [attn[0].numpy() for attn in attention_weights]
|
||||||
for i, attn in enumerate(attns):
|
for i, attn in enumerate(attns):
|
||||||
img = pack_attention_images(attn)
|
img = pack_attention_images(attn)
|
||||||
writer.add_image(f"{tag}/{i}",
|
writer.add_image(
|
||||||
cm.plasma(img),
|
f"{tag}/{i}",
|
||||||
global_step=global_step,
|
cm.plasma(img),
|
||||||
dataformats="HWC")
|
global_step=global_step,
|
||||||
|
dataformats="HWC")
|
||||||
|
|
||||||
|
|
||||||
def min_max_normalize(v):
|
def min_max_normalize(v):
|
||||||
return (v - v.min()) / (v.max() - v.min())
|
return (v - v.min()) / (v.max() - v.min())
|
||||||
|
|
Loading…
Reference in New Issue