Update rec_nrtr_optim_head.py
This commit is contained in:
parent
c635925895
commit
c8094e6575
|
@ -216,7 +216,7 @@ class TransformerOptim(nn.Layer):
|
|||
new_shape = (n_curr_active_inst * n_bm, *d_hs)
|
||||
|
||||
beamed_tensor = beamed_tensor.reshape(
|
||||
[n_prev_active_inst, -1]) #contiguous()
|
||||
[n_prev_active_inst, -1])
|
||||
beamed_tensor = beamed_tensor.index_select(
|
||||
paddle.to_tensor(curr_active_inst_idx), axis=0)
|
||||
beamed_tensor = beamed_tensor.reshape([*new_shape])
|
||||
|
@ -337,7 +337,7 @@ class TransformerOptim(nn.Layer):
|
|||
n_inst, len_s, d_h = src_enc.shape
|
||||
src_enc = paddle.concat([src_enc for i in range(n_bm)], axis=1)
|
||||
src_enc = src_enc.reshape([n_inst * n_bm, len_s, d_h]).transpose(
|
||||
[1, 0, 2]) #repeat(1, n_bm, 1)
|
||||
[1, 0, 2])
|
||||
#-- Prepare beams
|
||||
inst_dec_beams = [Beam(n_bm) for _ in range(n_inst)]
|
||||
|
||||
|
|
Loading…
Reference in New Issue