update gcn

This commit is contained in:
tlk-dsg 2021-05-20 14:12:11 +08:00
parent 39df279b55
commit 33a777cddf
1 changed files with 4 additions and 0 deletions

View File

@ -119,6 +119,10 @@ def main(cfg):
if cfg.model_name == 'cnn':
if cfg.use_pcnn:
x['pcnn_mask'] = torch.tensor([data[0]['entities_pos']])
if cfg.model_name == 'gcn':
# 没找到合适的做 parsing tree 的工具,暂时随机初始化
adj = torch.empty(1,data[0]['seq_len'],data[0]['seq_len']).random_(2)
x['adj'] = adj
for key in x.keys():
x[key] = x[key].to(device)