Merge branch 'reborn' into exp

This commit is contained in:
chenfeiyu 2020-10-30 21:44:06 +08:00
commit 1f71f65c28
1 changed files with 27 additions and 0 deletions

27
parakeet/utils/display.py Normal file
View File

@ -0,0 +1,27 @@
import numpy as np
import matplotlib
from matplotlib import cm, pyplot
def pack_attention_images(attention_weights, rotate=False):
attention_weights = np.pad(attention_weights,
[(0, 0), (1, 1), (1, 1)],
mode="constant",
constant_values=1.)
if rotate:
attention_weights = np.rot90(attention_weights, axes=(1, 2))
n, h, w = attention_weights.shape
ratio = h / w
if ratio < 1:
cols = max(int(np.sqrt(n / ratio)), 1)
rows = int(np.ceil(n / cols))
else:
rows = max(int(np.sqrt(n / ratio)), 1)
cols = int(np.ceil(n / rows))
extras = rows * cols - n
#print(rows, cols, extras)
total = np.append(attention_weights, np.zeros([extras, h, w]), axis=0)
total = np.reshape(total, [rows, cols, h, w])
img = np.block([[total[i, j] for j in range(cols)] for i in range(rows)])
return img