Merge branch 'reborn' into exp
This commit is contained in:
commit
1f71f65c28
|
@ -0,0 +1,27 @@
|
|||
import numpy as np
|
||||
import matplotlib
|
||||
from matplotlib import cm, pyplot
|
||||
|
||||
def pack_attention_images(attention_weights, rotate=False):
|
||||
attention_weights = np.pad(attention_weights,
|
||||
[(0, 0), (1, 1), (1, 1)],
|
||||
mode="constant",
|
||||
constant_values=1.)
|
||||
if rotate:
|
||||
attention_weights = np.rot90(attention_weights, axes=(1, 2))
|
||||
n, h, w = attention_weights.shape
|
||||
|
||||
ratio = h / w
|
||||
if ratio < 1:
|
||||
cols = max(int(np.sqrt(n / ratio)), 1)
|
||||
rows = int(np.ceil(n / cols))
|
||||
else:
|
||||
rows = max(int(np.sqrt(n / ratio)), 1)
|
||||
cols = int(np.ceil(n / rows))
|
||||
extras = rows * cols - n
|
||||
#print(rows, cols, extras)
|
||||
total = np.append(attention_weights, np.zeros([extras, h, w]), axis=0)
|
||||
total = np.reshape(total, [rows, cols, h, w])
|
||||
img = np.block([[total[i, j] for j in range(cols)] for i in range(rows)])
|
||||
return img
|
||||
|
Loading…
Reference in New Issue