diff --git a/parakeet/utils/display.py b/parakeet/utils/display.py new file mode 100644 index 0000000..e32aaa7 --- /dev/null +++ b/parakeet/utils/display.py @@ -0,0 +1,27 @@ +import numpy as np +import matplotlib +from matplotlib import cm, pyplot + +def pack_attention_images(attention_weights, rotate=False): + attention_weights = np.pad(attention_weights, + [(0, 0), (1, 1), (1, 1)], + mode="constant", + constant_values=1.) + if rotate: + attention_weights = np.rot90(attention_weights, axes=(1, 2)) + n, h, w = attention_weights.shape + + ratio = h / w + if ratio < 1: + cols = max(int(np.sqrt(n / ratio)), 1) + rows = int(np.ceil(n / cols)) + else: + rows = max(int(np.sqrt(n / ratio)), 1) + cols = int(np.ceil(n / rows)) + extras = rows * cols - n + #print(rows, cols, extras) + total = np.append(attention_weights, np.zeros([extras, h, w]), axis=0) + total = np.reshape(total, [rows, cols, h, w]) + img = np.block([[total[i, j] for j in range(cols)] for i in range(rows)]) + return img +