Add files via upload

parent 46ec625506
commit 98bf963b22
@@ -138,7 +138,9 @@
 "cell_type": "code",
 "execution_count": null,
 "source": [
-"docred_rel2id = json.load(open('./data/rel2id.json', 'r'))\n",
+"rel2id = json.load(open('./data/rel2id.json', 'r'))\n",
+"id2rel = {value: key for key, value in rel2id.items()}\n",
+"\n",
 "\n",
 "def chunks(l, n):\n",
 "    res = []\n",
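The renamed `rel2id` and the new `id2rel` inverse are used throughout the rest of the notebook for encoding labels and decoding predictions. A minimal sketch of that round trip, assuming `./data/rel2id.json` maps relation labels (Wikidata property IDs such as `P17`) to integer class ids with the no-relation class at index 0:

```python
import json

# Assumed file layout: {"Na": 0, "P17": 1, ...} -- the file itself is not shown in the diff.
rel2id = json.load(open('./data/rel2id.json', 'r'))
id2rel = {value: key for key, value in rel2id.items()}

r = rel2id['P17']    # label -> class id, as used when building label vectors
print(id2rel[r])     # class id -> label, as used when decoding predictions
```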
@@ -243,7 +245,7 @@
 "        if \"labels\" in sample:\n",
 "            for label in sample['labels']:\n",
 "                evidence = label['evidence']\n",
-"                r = int(docred_rel2id[label['r']])\n",
+"                r = int(rel2id[label['r']])\n",
 "                if (label['h'], label['t']) not in train_triple:\n",
 "                    train_triple[(label['h'], label['t'])] = [\n",
 "                        {'relation': r, 'evidence': evidence}]\n",
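This hunk builds `train_triple`, keyed by (head, tail) entity indices, from the gold annotations. A toy, self-contained sketch of the same accumulation; the sample data, the toy `rel2id`, and the `else` branch (appending to an already-seen pair) are assumptions, since the hunk only shows the first insertion:

```python
# Toy DocRED-style sample with two relations on the same entity pair.
rel2id = {"Na": 0, "P17": 1, "P131": 2}      # illustrative mapping only
sample = {"labels": [
    {"h": 0, "t": 2, "r": "P17",  "evidence": [1]},
    {"h": 0, "t": 2, "r": "P131", "evidence": [1, 3]},
]}

train_triple = {}
if "labels" in sample:
    for label in sample['labels']:
        evidence = label['evidence']
        r = int(rel2id[label['r']])
        if (label['h'], label['t']) not in train_triple:
            train_triple[(label['h'], label['t'])] = [
                {'relation': r, 'evidence': evidence}]
        else:  # assumed: later relations on the same pair are appended
            train_triple[(label['h'], label['t'])].append(
                {'relation': r, 'evidence': evidence})

print(train_triple)
# {(0, 2): [{'relation': 1, 'evidence': [1]}, {'relation': 2, 'evidence': [1, 3]}]}
```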
@@ -264,7 +266,7 @@
 "        relations, hts = [], []\n",
 "        # Get positive samples from dataset\n",
 "        for h, t in train_triple.keys():\n",
-"            relation = [0] * len(docred_rel2id)\n",
+"            relation = [0] * len(rel2id)\n",
 "            for mention in train_triple[h, t]:\n",
 "                relation[mention[\"relation\"]] = 1\n",
 "                evidence = mention[\"evidence\"]\n",
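Each positive pair then becomes a multi-hot vector over the relation classes, since one entity pair can hold several relations. A toy continuation of the sketch above; the trailing `relations.append` / `hts.append` calls are assumed from the surrounding code rather than shown in this hunk:

```python
# Continuing the toy train_triple from the previous sketch.
rel2id = {"Na": 0, "P17": 1, "P131": 2}
train_triple = {(0, 2): [{"relation": 1, "evidence": [1]},
                         {"relation": 2, "evidence": [1, 3]}]}

relations, hts = [], []
for h, t in train_triple.keys():
    relation = [0] * len(rel2id)            # one slot per relation class
    for mention in train_triple[h, t]:
        relation[mention["relation"]] = 1   # multi-hot: several relations may co-occur
        evidence = mention["evidence"]
    relations.append(relation)              # assumed from the surrounding code
    hts.append([h, t])

print(relations)   # [[0, 1, 1]]
print(hts)         # [[0, 2]]
```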
@@ -276,7 +278,7 @@
 "        for h in range(len(entities)):\n",
 "            for t in range(len(entities)):\n",
 "                if h != t and [h, t] not in hts:\n",
-"                    relation = [1] + [0] * (len(docred_rel2id) - 1)\n",
+"                    relation = [1] + [0] * (len(rel2id) - 1)\n",
 "                    relations.append(relation)\n",
 "                    hts.append([h, t])\n",
 "                    neg_samples += 1\n",
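Unlabelled ordered pairs are filled in as negatives with the vector `[1, 0, ..., 0]`, i.e. class index 0 acts as the no-relation (NA) class. A toy sketch under the same assumptions as above:

```python
# Toy negative sampling: every ordered entity pair without a gold label
# gets the "no relation" vector, with class index 0 reserved for NA.
rel2id = {"Na": 0, "P17": 1, "P131": 2}
entities = ["e0", "e1", "e2"]            # placeholder entity list
relations, hts = [[0, 1, 1]], [[0, 2]]   # positives from the step above
neg_samples = 0

for h in range(len(entities)):
    for t in range(len(entities)):
        if h != t and [h, t] not in hts:
            relation = [1] + [0] * (len(rel2id) - 1)
            relations.append(relation)
            hts.append([h, t])
            neg_samples += 1

print(neg_samples)     # 5 of the 6 ordered pairs are negatives here
print(len(relations))  # one label vector per (h, t) pair, i.e. 6
```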
@@ -843,24 +845,26 @@
 "    )\n",
 "    return res\n",
 "\n",
-"def report(cfg, model, features):\n",
+"def report(args, model, features):\n",
 "\n",
-"    dataloader = DataLoader(features, batch_size=cfg.test_batch_size, shuffle=False, collate_fn=collate_fn, drop_last=False)\n",
+"    dataloader = DataLoader(features, batch_size=args.test_batch_size, shuffle=False, collate_fn=collate_fn, drop_last=False)\n",
 "    preds = []\n",
 "    for batch in dataloader:\n",
 "        model.eval()\n",
-"        inputs = {'input_ids': batch[0].to(cfg.device),\n",
-"                  'attention_mask': batch[1].to(cfg.device),\n",
+"\n",
+"        inputs = {'input_ids': batch[0].to(args.device),\n",
+"                  'attention_mask': batch[1].to(args.device),\n",
 "                  'entity_pos': batch[3],\n",
 "                  'hts': batch[4],\n",
 "                  }\n",
 "\n",
 "        with torch.no_grad():\n",
-"            pred, *_ = model(**inputs)\n",
+"            pred = model(**inputs)\n",
 "            pred = pred.cpu().numpy()\n",
 "            pred[np.isnan(pred)] = 0\n",
 "            preds.append(pred)\n",
-"    preds = np.array(preds).astype(np.float32)\n",
+"\n",
+"    preds = np.concatenate(preds, axis=0).astype(np.float32)\n",
 "    preds = to_official(preds, features)\n",
 "    return preds\n",
 "\n",
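Two of the changes in this hunk are semantic rather than cosmetic: `pred = model(**inputs)` assumes the model's forward now returns only the predictions at inference time, and `np.concatenate(preds, axis=0)` replaces `np.array(preds)` because the per-batch prediction arrays do not all have the same first dimension (each document contributes a different number of (h, t) pairs, and the last batch is usually smaller), so they cannot be stacked into one rectangular array. A small sketch of the latter point; the class count of 97 and the batch shapes are illustrative assumptions:

```python
import numpy as np

# Per-batch prediction arrays of unequal length, as report() would collect them.
preds = [np.zeros((8, 97)), np.zeros((8, 97)), np.zeros((3, 97))]

# Concatenating along axis 0 gives one row per entity pair across all batches,
# which is the flat layout that to_official() consumes.
flat = np.concatenate(preds, axis=0).astype(np.float32)
print(flat.shape)   # (19, 97)
```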
@@ -1164,4 +1168,4 @@
 },
 "nbformat": 4,
 "nbformat_minor": 5
-}
+}