This commit is contained in:
leo 2019-12-06 10:43:54 +08:00
parent 0786b2b30e
commit a0cbc2a681
1 changed files with 1 additions and 1 deletions

View File

@ -37,7 +37,7 @@ def collate_fn(cfg):
x['tail_pos'] = torch.tensor(tail_pos)
if cfg.model_name == 'cnn' and cfg.use_pcnn:
x['pcnn_mask'] = torch.tensor(pcnn_mask)
if cfg.model == 'gcn':
if cfg.model_name == 'gcn':
# 没找到合适的做 parsing tree 的工具,暂时随机初始化
B, L = len(batch), max_len
adj = torch.empty(B, L, L).random_(2)