fix bug
This commit is contained in:
parent
0786b2b30e
commit
a0cbc2a681
|
@ -37,7 +37,7 @@ def collate_fn(cfg):
|
|||
x['tail_pos'] = torch.tensor(tail_pos)
|
||||
if cfg.model_name == 'cnn' and cfg.use_pcnn:
|
||||
x['pcnn_mask'] = torch.tensor(pcnn_mask)
|
||||
if cfg.model == 'gcn':
|
||||
if cfg.model_name == 'gcn':
|
||||
# 没找到合适的做 parsing tree 的工具,暂时随机初始化
|
||||
B, L = len(batch), max_len
|
||||
adj = torch.empty(B, L, L).random_(2)
|
||||
|
|
Loading…
Reference in New Issue