Update rec_nrtr_optim_head.py

This commit is contained in:
topduke 2021-08-19 19:08:23 +08:00 committed by GitHub
parent c635925895
commit c8094e6575
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 2 additions and 2 deletions

View File

@ -216,7 +216,7 @@ class TransformerOptim(nn.Layer):
new_shape = (n_curr_active_inst * n_bm, *d_hs)
beamed_tensor = beamed_tensor.reshape(
[n_prev_active_inst, -1]) #contiguous()
[n_prev_active_inst, -1])
beamed_tensor = beamed_tensor.index_select(
paddle.to_tensor(curr_active_inst_idx), axis=0)
beamed_tensor = beamed_tensor.reshape([*new_shape])
@ -337,7 +337,7 @@ class TransformerOptim(nn.Layer):
n_inst, len_s, d_h = src_enc.shape
src_enc = paddle.concat([src_enc for i in range(n_bm)], axis=1)
src_enc = src_enc.reshape([n_inst * n_bm, len_s, d_h]).transpose(
[1, 0, 2]) #repeat(1, n_bm, 1)
[1, 0, 2])
#-- Prepare beams
inst_dec_beams = [Beam(n_bm) for _ in range(n_inst)]