fix bug
This commit is contained in:
parent
0786b2b30e
commit
a0cbc2a681
|
@ -37,7 +37,7 @@ def collate_fn(cfg):
|
||||||
x['tail_pos'] = torch.tensor(tail_pos)
|
x['tail_pos'] = torch.tensor(tail_pos)
|
||||||
if cfg.model_name == 'cnn' and cfg.use_pcnn:
|
if cfg.model_name == 'cnn' and cfg.use_pcnn:
|
||||||
x['pcnn_mask'] = torch.tensor(pcnn_mask)
|
x['pcnn_mask'] = torch.tensor(pcnn_mask)
|
||||||
if cfg.model == 'gcn':
|
if cfg.model_name == 'gcn':
|
||||||
# 没找到合适的做 parsing tree 的工具,暂时随机初始化
|
# 没找到合适的做 parsing tree 的工具,暂时随机初始化
|
||||||
B, L = len(batch), max_len
|
B, L = len(batch), max_len
|
||||||
adj = torch.empty(B, L, L).random_(2)
|
adj = torch.empty(B, L, L).random_(2)
|
||||||
|
|
Loading…
Reference in New Issue