Add files via upload

This commit is contained in:
TimelordRi 2021-10-09 15:00:00 +08:00 committed by GitHub
parent 46ec625506
commit 98bf963b22
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 15 additions and 11 deletions

View File

@ -138,7 +138,9 @@
"cell_type": "code",
"execution_count": null,
"source": [
"docred_rel2id = json.load(open('./data/rel2id.json', 'r'))\n",
"rel2id = json.load(open('./data/rel2id.json', 'r'))\n",
"id2rel = {value: key for key, value in rel2id.items()}\n",
"\n",
"\n",
"def chunks(l, n):\n",
" res = []\n",
@ -243,7 +245,7 @@
" if \"labels\" in sample:\n",
" for label in sample['labels']:\n",
" evidence = label['evidence']\n",
" r = int(docred_rel2id[label['r']])\n",
" r = int(rel2id[label['r']])\n",
" if (label['h'], label['t']) not in train_triple:\n",
" train_triple[(label['h'], label['t'])] = [\n",
" {'relation': r, 'evidence': evidence}]\n",
@ -264,7 +266,7 @@
" relations, hts = [], []\n",
" # Get positive samples from dataset\n",
" for h, t in train_triple.keys():\n",
" relation = [0] * len(docred_rel2id)\n",
" relation = [0] * len(rel2id)\n",
" for mention in train_triple[h, t]:\n",
" relation[mention[\"relation\"]] = 1\n",
" evidence = mention[\"evidence\"]\n",
@ -276,7 +278,7 @@
" for h in range(len(entities)):\n",
" for t in range(len(entities)):\n",
" if h != t and [h, t] not in hts:\n",
" relation = [1] + [0] * (len(docred_rel2id) - 1)\n",
" relation = [1] + [0] * (len(rel2id) - 1)\n",
" relations.append(relation)\n",
" hts.append([h, t])\n",
" neg_samples += 1\n",
@ -843,24 +845,26 @@
" )\n",
" return res\n",
"\n",
"def report(cfg, model, features):\n",
"def report(args, model, features):\n",
"\n",
" dataloader = DataLoader(features, batch_size=cfg.test_batch_size, shuffle=False, collate_fn=collate_fn, drop_last=False)\n",
" dataloader = DataLoader(features, batch_size=args.test_batch_size, shuffle=False, collate_fn=collate_fn, drop_last=False)\n",
" preds = []\n",
" for batch in dataloader:\n",
" model.eval()\n",
" inputs = {'input_ids': batch[0].to(cfg.device),\n",
" 'attention_mask': batch[1].to(cfg.device),\n",
"\n",
" inputs = {'input_ids': batch[0].to(args.device),\n",
" 'attention_mask': batch[1].to(args.device),\n",
" 'entity_pos': batch[3],\n",
" 'hts': batch[4],\n",
" }\n",
"\n",
" with torch.no_grad():\n",
" pred, *_ = model(**inputs)\n",
" pred = model(**inputs)\n",
" pred = pred.cpu().numpy()\n",
" pred[np.isnan(pred)] = 0\n",
" preds.append(pred)\n",
" preds = np.array(preds).astype(np.float32)\n",
"\n",
" preds = np.concatenate(preds, axis=0).astype(np.float32)\n",
" preds = to_official(preds, features)\n",
" return preds\n",
"\n",
@ -1164,4 +1168,4 @@
},
"nbformat": 4,
"nbformat_minor": 5
}
}