From 446063806c265431aa8795503938dd189103a781 Mon Sep 17 00:00:00 2001 From: tlk-dsg <467460833@qq.com> Date: Tue, 14 Dec 2021 16:20:31 +0800 Subject: [PATCH] fix bug --- example/re/few-shot/run.py | 2 +- setup.py | 2 +- src/deepke/relation_extraction/few_shot/lit_models/base.py | 4 ++-- .../relation_extraction/few_shot/lit_models/transformer.py | 4 ++-- 4 files changed, 6 insertions(+), 6 deletions(-) diff --git a/example/re/few-shot/run.py b/example/re/few-shot/run.py index 809634b..eef1960 100644 --- a/example/re/few-shot/run.py +++ b/example/re/few-shot/run.py @@ -60,7 +60,7 @@ def main(cfg): model.to(device) wandb.watch(model, log="all") - lit_model = BertLitModel(args=cfg, model=model, tokenizer=data.tokenizer) + lit_model = BertLitModel(args=cfg, model=model, device=device, tokenizer=data.tokenizer) data.setup() diff --git a/setup.py b/setup.py index 01a5891..804c9c1 100644 --- a/setup.py +++ b/setup.py @@ -2,7 +2,7 @@ from setuptools import setup, find_packages setup( name='deepke', # 打包后的包文件名 - version='0.2.94', #版本号 + version='0.2.95', #版本号 keywords=["pip", "RE","NER","AE"], # 关键字 description='DeepKE 是基于 Pytorch 的深度学习中文关系抽取处理套件。', # 说明 long_description="client", #详细说明 diff --git a/src/deepke/relation_extraction/few_shot/lit_models/base.py b/src/deepke/relation_extraction/few_shot/lit_models/base.py index d98e082..df5cb5a 100644 --- a/src/deepke/relation_extraction/few_shot/lit_models/base.py +++ b/src/deepke/relation_extraction/few_shot/lit_models/base.py @@ -22,11 +22,11 @@ class BaseLitModel(nn.Module): Generic PyTorch-Lightning class that must be initialized with a PyTorch module. """ - def __init__(self, model, args): + def __init__(self, model, device, args): super().__init__() self.model = model self.cur_model = model.module if hasattr(model, 'module') else model - self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') + self.device = device self.args = args optimizer = self.args.get("optimizer", OPTIMIZER) diff --git a/src/deepke/relation_extraction/few_shot/lit_models/transformer.py b/src/deepke/relation_extraction/few_shot/lit_models/transformer.py index 8c86680..9dff611 100644 --- a/src/deepke/relation_extraction/few_shot/lit_models/transformer.py +++ b/src/deepke/relation_extraction/few_shot/lit_models/transformer.py @@ -36,8 +36,8 @@ class BertLitModel(BaseLitModel): """ use AutoModelForMaskedLM, and select the output by another layer in the lit model """ - def __init__(self, model, args, tokenizer): - super().__init__(model, args) + def __init__(self, model, device ,args, tokenizer): + super().__init__(model, device, args) self.tokenizer = tokenizer with open(f"{args.data_dir}/rel2id.json","r") as file: