diff --git a/tools/predict.py b/tools/predict.py index db75059..4e3078d 100644 --- a/tools/predict.py +++ b/tools/predict.py @@ -119,6 +119,10 @@ def main(cfg): if cfg.model_name == 'cnn': if cfg.use_pcnn: x['pcnn_mask'] = torch.tensor([data[0]['entities_pos']]) + if cfg.model_name == 'gcn': + # 没找到合适的做 parsing tree 的工具,暂时随机初始化 + adj = torch.empty(1,data[0]['seq_len'],data[0]['seq_len']).random_(2) + x['adj'] = adj for key in x.keys(): x[key] = x[key].to(device)