update gcn
This commit is contained in:
parent
39df279b55
commit
33a777cddf
|
@ -119,6 +119,10 @@ def main(cfg):
|
|||
if cfg.model_name == 'cnn':
|
||||
if cfg.use_pcnn:
|
||||
x['pcnn_mask'] = torch.tensor([data[0]['entities_pos']])
|
||||
if cfg.model_name == 'gcn':
|
||||
# 没找到合适的做 parsing tree 的工具,暂时随机初始化
|
||||
adj = torch.empty(1,data[0]['seq_len'],data[0]['seq_len']).random_(2)
|
||||
x['adj'] = adj
|
||||
|
||||
for key in x.keys():
|
||||
x[key] = x[key].to(device)
|
||||
|
|
Loading…
Reference in New Issue