ParakeetEricRoss/parakeet/utils/display.py

28 lines
958 B
Python
Raw Normal View History

2020-10-30 21:36:11 +08:00
import numpy as np
import matplotlib
from matplotlib import cm, pyplot
def pack_attention_images(attention_weights, rotate=False):
attention_weights = np.pad(attention_weights,
[(0, 0), (1, 1), (1, 1)],
mode="constant",
constant_values=1.)
if rotate:
attention_weights = np.rot90(attention_weights, axes=(1, 2))
n, h, w = attention_weights.shape
ratio = h / w
if ratio < 1:
cols = max(int(np.sqrt(n / ratio)), 1)
rows = int(np.ceil(n / cols))
else:
rows = max(int(np.sqrt(n / ratio)), 1)
cols = int(np.ceil(n / rows))
extras = rows * cols - n
#print(rows, cols, extras)
total = np.append(attention_weights, np.zeros([extras, h, w]), axis=0)
total = np.reshape(total, [rows, cols, h, w])
img = np.block([[total[i, j] for j in range(cols)] for i in range(rows)])
return img