fix bug
This commit is contained in:
parent cf2b207cec
commit 446063806c
@@ -60,7 +60,7 @@ def main(cfg):
     model.to(device)
     wandb.watch(model, log="all")

-    lit_model = BertLitModel(args=cfg, model=model, tokenizer=data.tokenizer)
+    lit_model = BertLitModel(args=cfg, model=model, device=device, tokenizer=data.tokenizer)


     data.setup()
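The hunk above already calls model.to(device), so a device object exists in main(cfg) before the changed line; with this commit that same object is also handed to the lit model instead of letting BaseLitModel pick cuda:0 on its own. A minimal sketch of the caller-side wiring, assuming device is derived from CUDA availability (its definition is not shown in the hunk):

import torch

# Assumed definition earlier in main(cfg); the hunk only shows model.to(device).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model.to(device)
wandb.watch(model, log="all")

# New in this commit: the caller's device is passed explicitly to the lit model.
lit_model = BertLitModel(args=cfg, model=model, device=device, tokenizer=data.tokenizer)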
setup.py (2 changed lines)
@@ -2,7 +2,7 @@ from setuptools import setup, find_packages

 setup(
     name='deepke',  # name of the packaged distribution
-    version='0.2.94',  # version number
+    version='0.2.95',  # version number
     keywords=["pip", "RE","NER","AE"],  # keywords
     description='DeepKE 是基于 Pytorch 的深度学习中文关系抽取处理套件。',  # description: a PyTorch-based deep learning toolkit for Chinese relation extraction
     long_description="client",  # long description
@@ -22,11 +22,11 @@ class BaseLitModel(nn.Module):
     Generic PyTorch-Lightning class that must be initialized with a PyTorch module.
     """

-    def __init__(self, model, args):
+    def __init__(self, model, device, args):
         super().__init__()
         self.model = model
         self.cur_model = model.module if hasattr(model, 'module') else model
-        self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
+        self.device = device
         self.args = args

         optimizer = self.args.get("optimizer", OPTIMIZER)
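Storing the caller's device rather than re-deriving it means every tensor move inside the lit model now follows the device chosen once in main(cfg). A hedged sketch of how self.device is typically consumed downstream; the method name, batch layout, and loss below are illustrative assumptions, not code from this commit:

def training_step(self, batch, batch_idx):
    # Move the batch onto the device selected by the caller in main(cfg),
    # instead of the previously hardcoded cuda:0.
    input_ids, labels = (t.to(self.device) for t in batch)
    logits = self.model(input_ids)
    # Assumed loss for illustration; the real loss lives outside this hunk.
    return torch.nn.functional.cross_entropy(logits, labels)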
@@ -36,8 +36,8 @@ class BertLitModel(BaseLitModel):
     """
     use AutoModelForMaskedLM, and select the output by another layer in the lit model
     """
-    def __init__(self, model, args, tokenizer):
-        super().__init__(model, args)
+    def __init__(self, model, device, args, tokenizer):
+        super().__init__(model, device, args)
         self.tokenizer = tokenizer

         with open(f"{args.data_dir}/rel2id.json","r") as file:
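Taken together, the three hunks make the device flow caller -> BertLitModel -> BaseLitModel, so one torch.device object is the single source of truth end to end. A minimal sanity-check sketch, assuming cfg provides the fields referenced above (notably data_dir containing rel2id.json) and that the elided parts of the constructors need no extra arguments:

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
lit_model = BertLitModel(model=model, device=device, args=cfg, tokenizer=data.tokenizer)
# BertLitModel forwards device to BaseLitModel, which stores it unchanged.
assert lit_model.device is device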