This commit is contained in:
tlk-dsg 2021-12-14 16:20:31 +08:00
parent cf2b207cec
commit 446063806c
4 changed files with 6 additions and 6 deletions

View File

@ -60,7 +60,7 @@ def main(cfg):
model.to(device) model.to(device)
wandb.watch(model, log="all") wandb.watch(model, log="all")
lit_model = BertLitModel(args=cfg, model=model, tokenizer=data.tokenizer) lit_model = BertLitModel(args=cfg, model=model, device=device, tokenizer=data.tokenizer)
data.setup() data.setup()

View File

@ -2,7 +2,7 @@ from setuptools import setup, find_packages
setup( setup(
name='deepke', # 打包后的包文件名 name='deepke', # 打包后的包文件名
version='0.2.94', #版本号 version='0.2.95', #版本号
keywords=["pip", "RE","NER","AE"], # 关键字 keywords=["pip", "RE","NER","AE"], # 关键字
description='DeepKE 是基于 Pytorch 的深度学习中文关系抽取处理套件。', # 说明 description='DeepKE 是基于 Pytorch 的深度学习中文关系抽取处理套件。', # 说明
long_description="client", #详细说明 long_description="client", #详细说明

View File

@ -22,11 +22,11 @@ class BaseLitModel(nn.Module):
Generic PyTorch-Lightning class that must be initialized with a PyTorch module. Generic PyTorch-Lightning class that must be initialized with a PyTorch module.
""" """
def __init__(self, model, args): def __init__(self, model, device, args):
super().__init__() super().__init__()
self.model = model self.model = model
self.cur_model = model.module if hasattr(model, 'module') else model self.cur_model = model.module if hasattr(model, 'module') else model
self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') self.device = device
self.args = args self.args = args
optimizer = self.args.get("optimizer", OPTIMIZER) optimizer = self.args.get("optimizer", OPTIMIZER)

View File

@ -36,8 +36,8 @@ class BertLitModel(BaseLitModel):
""" """
use AutoModelForMaskedLM, and select the output by another layer in the lit model use AutoModelForMaskedLM, and select the output by another layer in the lit model
""" """
def __init__(self, model, args, tokenizer): def __init__(self, model, device ,args, tokenizer):
super().__init__(model, args) super().__init__(model, device, args)
self.tokenizer = tokenizer self.tokenizer = tokenizer
with open(f"{args.data_dir}/rel2id.json","r") as file: with open(f"{args.data_dir}/rel2id.json","r") as file: