diff --git a/tutorial-notebooks/re/document/tutorial.ipynb b/tutorial-notebooks/re/document/tutorial.ipynb
index 9eccc03..2564ea7 100644
--- a/tutorial-notebooks/re/document/tutorial.ipynb
+++ b/tutorial-notebooks/re/document/tutorial.ipynb
@@ -138,7 +138,9 @@
    "cell_type": "code",
    "execution_count": null,
    "source": [
-    "docred_rel2id = json.load(open('./data/rel2id.json', 'r'))\n",
+    "rel2id = json.load(open('./data/rel2id.json', 'r'))\n",
+    "id2rel = {value: key for key, value in rel2id.items()}\n",
+    "\n",
     "\n",
     "def chunks(l, n):\n",
     "    res = []\n",
@@ -243,7 +245,7 @@
     "        if \"labels\" in sample:\n",
     "            for label in sample['labels']:\n",
     "                evidence = label['evidence']\n",
-    "                r = int(docred_rel2id[label['r']])\n",
+    "                r = int(rel2id[label['r']])\n",
     "                if (label['h'], label['t']) not in train_triple:\n",
     "                    train_triple[(label['h'], label['t'])] = [\n",
     "                        {'relation': r, 'evidence': evidence}]\n",
@@ -264,7 +266,7 @@
     "        relations, hts = [], []\n",
     "        # Get positive samples from dataset\n",
     "        for h, t in train_triple.keys():\n",
-    "            relation = [0] * len(docred_rel2id)\n",
+    "            relation = [0] * len(rel2id)\n",
     "            for mention in train_triple[h, t]:\n",
     "                relation[mention[\"relation\"]] = 1\n",
     "                evidence = mention[\"evidence\"]\n",
@@ -276,7 +278,7 @@
     "        for h in range(len(entities)):\n",
     "            for t in range(len(entities)):\n",
     "                if h != t and [h, t] not in hts:\n",
-    "                    relation = [1] + [0] * (len(docred_rel2id) - 1)\n",
+    "                    relation = [1] + [0] * (len(rel2id) - 1)\n",
     "                    relations.append(relation)\n",
     "                    hts.append([h, t])\n",
     "                    neg_samples += 1\n",
@@ -843,24 +845,26 @@
     "    )\n",
     "    return res\n",
     "\n",
-    "def report(cfg, model, features):\n",
+    "def report(args, model, features):\n",
     "\n",
-    "    dataloader = DataLoader(features, batch_size=cfg.test_batch_size, shuffle=False, collate_fn=collate_fn, drop_last=False)\n",
+    "    dataloader = DataLoader(features, batch_size=args.test_batch_size, shuffle=False, collate_fn=collate_fn, drop_last=False)\n",
     "    preds = []\n",
     "    for batch in dataloader:\n",
     "        model.eval()\n",
-    "        inputs = {'input_ids': batch[0].to(cfg.device),\n",
-    "                  'attention_mask': batch[1].to(cfg.device),\n",
+    "\n",
+    "        inputs = {'input_ids': batch[0].to(args.device),\n",
+    "                  'attention_mask': batch[1].to(args.device),\n",
     "                  'entity_pos': batch[3],\n",
     "                  'hts': batch[4],\n",
     "                  }\n",
     "\n",
     "        with torch.no_grad():\n",
-    "            pred, *_ = model(**inputs)\n",
+    "            pred = model(**inputs)\n",
     "            pred = pred.cpu().numpy()\n",
     "            pred[np.isnan(pred)] = 0\n",
     "            preds.append(pred)\n",
-    "    preds = np.array(preds).astype(np.float32)\n",
+    "\n",
+    "    preds = np.concatenate(preds, axis=0).astype(np.float32)\n",
     "    preds = to_official(preds, features)\n",
     "    return preds\n",
     "\n",
@@ -1164,4 +1168,4 @@
  },
  "nbformat": 4,
  "nbformat_minor": 5
-}
+}
\ No newline at end of file