fix bug
parent cf2b207cec
commit 446063806c
@@ -60,7 +60,7 @@ def main(cfg):
     model.to(device)
     wandb.watch(model, log="all")

-    lit_model = BertLitModel(args=cfg, model=model, tokenizer=data.tokenizer)
+    lit_model = BertLitModel(args=cfg, model=model, device=device, tokenizer=data.tokenizer)


     data.setup()
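Note: a minimal sketch of how `device` is assumed to be defined earlier in main(cfg) before being handed to BertLitModel; the cuda:0 fallback mirrors the check that BaseLitModel used to run internally (see the later hunk).

import torch

# Assumed earlier in main(cfg): pick the device once and pass it down.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")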
setup.py (2 lines changed)
@@ -2,7 +2,7 @@ from setuptools import setup, find_packages

 setup(
     name='deepke',  # package name after packaging
-    version='0.2.94',  # version number
+    version='0.2.95',  # version number
     keywords=["pip", "RE", "NER", "AE"],  # keywords
     description='DeepKE 是基于 Pytorch 的深度学习中文关系抽取处理套件。',  # description: "DeepKE is a PyTorch-based deep learning toolkit for Chinese relation extraction."
     long_description="client",  # long description
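Note: with the version bumped to 0.2.95, the released package would presumably be pinned as pip install deepke==0.2.95 (assuming the usual PyPI release flow for deepke).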
@@ -22,11 +22,11 @@ class BaseLitModel(nn.Module):
     Generic PyTorch-Lightning class that must be initialized with a PyTorch module.
     """

-    def __init__(self, model, args):
+    def __init__(self, model, device, args):
         super().__init__()
         self.model = model
         self.cur_model = model.module if hasattr(model, 'module') else model
-        self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
+        self.device = device
         self.args = args

         optimizer = self.args.get("optimizer", OPTIMIZER)
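Note: a minimal sketch of the injection pattern this hunk introduces, using a stand-in class rather than DeepKE's BaseLitModel; the caller now decides the device once instead of the module probing torch.cuda inside its constructor.

import torch
import torch.nn as nn

class TinyLitModel(nn.Module):
    # Stand-in for BaseLitModel; only the device-injection part is shown.
    def __init__(self, model, device, args):
        super().__init__()
        self.model = model
        self.device = device  # injected by the caller, no CUDA probe here
        self.args = args

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
lit = TinyLitModel(nn.Linear(4, 2), device, args={})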
@@ -36,8 +36,8 @@ class BertLitModel(BaseLitModel):
     """
     use AutoModelForMaskedLM, and select the output by another layer in the lit model
     """
-    def __init__(self, model, args, tokenizer):
-        super().__init__(model, args)
+    def __init__(self, model, device, args, tokenizer):
+        super().__init__(model, device, args)
         self.tokenizer = tokenizer

         with open(f"{args.data_dir}/rel2id.json","r") as file:
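Note: since device is inserted ahead of args, any positional call to BertLitModel has to switch to the new order; the keyword form used in main(cfg) above is unaffected. A hedged illustration (cfg, model, data are names carried over from the first hunk, not new definitions):

# new positional order: model, device, args, tokenizer
lit_model = BertLitModel(model, device, cfg, data.tokenizer)

# keyword form, as used in main(cfg)
lit_model = BertLitModel(args=cfg, model=model, device=device, tokenizer=data.tokenizer)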