Fix slice in SAR head
parent: f4b3d49bd4
commit: 902fffcc7e
@@ -235,7 +235,8 @@ class ParallelSARDecoder(BaseDecoder):
         # cal mask of attention weight
         for i, valid_ratio in enumerate(valid_ratios):
             valid_width = min(w, math.ceil(w * valid_ratio))
-            attn_weight[i, :, :, valid_width:, :] = float('-inf')
+            if valid_width < w:
+                attn_weight[i, :, :, valid_width:, :] = float('-inf')
 
         attn_weight = paddle.reshape(attn_weight, [bsz, T, -1])
         attn_weight = F.softmax(attn_weight, axis=-1)
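The added guard prevents the '-inf' assignment from running when valid_width equals w, i.e. when attn_weight[i, :, :, valid_width:, :] would be an empty slice. Below is a minimal, self-contained sketch of the masking pattern after the fix; the variable names follow the diff context, but the tensor layout [bsz, T, h, w, 1] and the concrete sizes are assumptions made for illustration, not taken from the repository.

import math

import paddle
import paddle.nn.functional as F

# Illustrative sizes (assumptions, not from the repository).
bsz, T, h, w = 2, 5, 4, 8
valid_ratios = [1.0, 0.6]  # fraction of each image's width that holds real content

# Assumed layout [bsz, T, h, w, 1], consistent with the slicing in the diff.
attn_weight = paddle.randn([bsz, T, h, w, 1])

# Mask the attention weight: padded positions beyond the valid width are set
# to '-inf' so that softmax gives them zero probability.
for i, valid_ratio in enumerate(valid_ratios):
    valid_width = min(w, math.ceil(w * valid_ratio))
    # The guard from this commit: skip the write when valid_width == w,
    # which would otherwise target an empty slice.
    if valid_width < w:
        attn_weight[i, :, :, valid_width:, :] = float('-inf')

attn_weight = paddle.reshape(attn_weight, [bsz, T, -1])
attn_weight = F.softmax(attn_weight, axis=-1)
print(attn_weight.shape)  # [2, 5, 32] with the sizes above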