Add files via upload
parent 7aa8d802d0, commit 031c6a75fb

@@ -0,0 +1,55 @@
## Quick Start

### Requirements

> python == 3.8

- tokenizers == 0.10.3
- torch == 1.8.0
- regex == 2021.4.4
- transformers == 4.7.0
- tqdm == 4.49.0
- activations == 0.1.0
- dataclasses == 0.6
- file_utils == 0.0.1
- flax == 0.3.4
- utils == 1.0.1
- deepke

### Clone the repository

```
git clone git@github.com:zjunlp/DeepKE.git
```

### Install with pip

First create a Python virtual environment and activate it.

- Install the dependencies: `pip install -r requirements.txt`

### Train and predict with the data

- Data: place the training data in the `data` folder. The model uses the [SEMEVAL](https://semeval2.fbk.eu/semeval2.php?location=tasks#T11) dataset, which comes from Task 8 of SemEval-2010, "Multi-Way Classification of Semantic Relations Between Pairs of Nominals".

- SEMEVAL provides the following files (a short loading sketch follows this list):

  - `rel2id.json`: mapping from relation labels to IDs

  - `temp.txt`: relation label templates

  - `test.txt`: test set

  - `train.txt`: training set

  - `val.txt`: validation set

- Training: the model loading and saving paths, as well as the other settings, can be changed in the `.yaml` files under `conf` (a configuration-inspection sketch also follows this list).

  - Few-shot training on SEMEVAL: `python run.py`

  - The log path of each run can be configured via `log_dir` in the `.yaml` file.

- Prediction: `python predict.py`
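
The two label files above are small enough to inspect directly. The sketch below only assumes the formats shown later in this commit (`rel2id.json` is a flat relation-to-ID map; each `temp.txt` line is a numeric flag, a relation name, and its template words) and that both files sit directly under `data/`; adjust the paths to your actual `data_dir`.

```python
import json

# rel2id.json: {"Other": 0, "Component-Whole(e2,e1)": 1, ...}
with open("data/rel2id.json", encoding="utf-8") as f:
    rel2id = json.load(f)
id2rel = {idx: rel for rel, idx in rel2id.items()}   # handy for decoding predictions

# temp.txt lines look like:
# "0 Member-Collection(e1,e2) member member of collection collection"
templates = {}
with open("data/temp.txt", encoding="utf-8") as f:
    for line in f:
        parts = line.split()
        if len(parts) < 3:
            continue
        flag, relation, words = parts[0], parts[1], parts[2:]  # leading flag (0/2) kept but not interpreted here
        templates[relation] = words

print(id2rel[3])                              # Member-Collection(e1,e2)
print(templates["Member-Collection(e1,e2)"])  # ['member', 'member', 'of', 'collection', 'collection']
```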
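
The settings live in the `conf` directory (the YAML files further down in this commit) and are composed by Hydra when `run.py` / `predict.py` start. As a minimal sketch, the snippet below assumes `hydra-core` is installed and mirrors the legacy file-style `config_path` used by those scripts; the file name `inspect_conf.py` is only illustrative. Values can also be overridden on the command line in the usual Hydra way, e.g. `python run.py batch_size=8 num_train_epochs=10`.

```python
# inspect_conf.py (illustrative): show what run.py / predict.py will see.
import os

import hydra


@hydra.main(config_path="conf/config.yaml")   # same legacy-style path as run.py / predict.py
def main(cfg):
    # With `hydra/output: custom`, Hydra switches the working directory to
    # logs/<timestamp>, so relative paths such as save_path resolve there.
    print("working dir:", os.getcwd())
    print("model:      ", cfg.model_name_or_path)
    print("data_dir:   ", cfg.data_dir)
    print("log_dir:    ", cfg.log_dir)
    print("save_path:  ", cfg.save_path)


if __name__ == "__main__":
    main()
```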
## Model

KnowPrompt

@@ -0,0 +1,3 @@
defaults:
  - hydra/output: custom
  - train

@@ -0,0 +1,11 @@
hydra:

  run:
    # Output directory for normal runs
    dir: logs/${now:%Y-%m-%d_%H-%M-%S}

  sweep:
    # Output directory for sweep runs
    dir: logs/${now:%Y-%m-%d_%H-%M-%S}
    # Output sub directory for sweep runs.
    subdir: ${hydra.job.num}_${hydra.job.id}

@@ -0,0 +1,84 @@
accelerator: None
accumulate_grad_batches: 1
amp_backend: 'native'
amp_level: 'O2'
auto_lr_find: False
auto_scale_batch_size: False
auto_select_gpus: False
batch_size: 16
benchmark: False
check_val_every_n_epoch: 3
checkpoint_callback: True
data_class: 'REDataset'
data_dir: 'data/k-shot/8-1'
default_root_dir: None
deterministic: False
devices: None
distributed_backend: None
fast_dev_run: False
flush_logs_every_n_steps: 100
gpus: None
gradient_accumulation_steps: 1
gradient_clip_algorithm: 'norm'
gradient_clip_val: 0.0
ipus: None
limit_predict_batches: 1.0
limit_test_batches: 1.0
limit_train_batches: 1.0
limit_val_batches: 1.0
litmodel_class: 'BertLitModel'
load_checkpoint: None
log_dir: './model_bert.log'
log_every_n_steps: 50
log_gpu_memory: None
logger: True
lr: 3e-05
lr_2: 3e-05
max_epochs: 30
max_seq_length: 256
max_steps: None
max_time: None
min_epochs: None
min_steps: None
model_class: 'BertForMaskedLM'
model_name_or_path: 'bert-large-uncased'
move_metrics_to_cpu: False
multiple_trainloader_mode: 'max_size_cycle'
num_nodes: 1
num_processes: 1
num_sanity_val_steps: 2
num_train_epochs: 30
num_workers: 8
optimizer: 'AdamW'
overfit_batches: 0.0
plugins: None
precision: 32
prepare_data_per_node: True
process_position: 0
profiler: None
progress_bar_refresh_rate: None
ptune_k: 7
reload_dataloaders_every_epoch: False
reload_dataloaders_every_n_epochs: 0
replace_sampler_ddp: True
resume_from_checkpoint: None
save_path: './model_bert.pt'
seed: 666
stochastic_weight_avg: False
sync_batchnorm: False
t_lambda: 0.001
task_name: 'wiki80'
terminate_on_nan: False
tpu_cores: None
track_grad_norm: -1
train_from_saved_model: ''
truncated_bptt_steps: None
two_steps: False
use_prompt: True
val_check_interval: 1.0
wandb: False
weight_decay: 0.01
weights_save_path: None
weights_summary: 'top'
load_path: './model_bert.pt'
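
For orientation, a small worked example of how the numeric values above feed the learning-rate schedule set up in `run.py` further down; the batch count per epoch is a made-up placeholder, since the k-shot split itself is not shown in this commit.

```python
# Hypothetical: suppose the k-shot training split yields 8 batches per epoch.
batches_per_epoch = 8                    # stand-in for len(data.train_dataloader())
gradient_accumulation_steps = 1          # from the config above
num_train_epochs = 30                    # from the config above

# Mirrors the computation in run.py:
num_training_steps = batches_per_epoch // gradient_accumulation_steps * num_train_epochs
num_warmup_steps = num_training_steps * 0.1   # linear warmup over the first 10% of steps
print(num_training_steps, num_warmup_steps)   # 240 24.0
```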

@@ -0,0 +1 @@
{"Component-Whole(e2,e1)": 1, "Other": 0, "Instrument-Agency(e2,e1)": 2, "Member-Collection(e1,e2)": 3, "Cause-Effect(e2,e1)": 4, "Entity-Destination(e1,e2)": 5, "Content-Container(e1,e2)": 6, "Message-Topic(e1,e2)": 7, "Product-Producer(e2,e1)": 8, "Member-Collection(e2,e1)": 9, "Entity-Origin(e1,e2)": 10, "Cause-Effect(e1,e2)": 11, "Component-Whole(e1,e2)": 12, "Message-Topic(e2,e1)": 13, "Product-Producer(e1,e2)": 14, "Entity-Origin(e2,e1)": 15, "Content-Container(e2,e1)": 16, "Instrument-Agency(e1,e2)": 17, "Entity-Destination(e2,e1)": 18}

@@ -0,0 +1,19 @@
0 Other nothing has nothing to nothing
0 Member-Collection(e1,e2) member member of collection collection
0 Entity-Origin(e1,e2) entity entity of origin origin
0 Cause-Effect(e1,e2) cause cause of effect effect
0 Component-Whole(e1,e2) component component of whole whole
0 Product-Producer(e1,e2) product product of producer producer
0 Instrument-Agency(e1,e2) instrument instrument of agency agency
0 Entity-Destination(e1,e2) entity entity of destination destination
0 Content-Container(e1,e2) content content of container container
0 Message-Topic(e1,e2) message message of topic topic
2 Cause-Effect(e2,e1) effect effect of cause cause
2 Product-Producer(e2,e1) producer producer of product product
2 Component-Whole(e2,e1) whole whole of component component
2 Instrument-Agency(e2,e1) agency agency of instrument instrument
2 Member-Collection(e2,e1) collection collection of member member
2 Message-Topic(e2,e1) topic topic of message message
2 Entity-Origin(e2,e1) origin origin of entity entity
2 Content-Container(e2,e1) container container of content content
2 Entity-Destination(e2,e1) destination destination of entity entity

File diff suppressed because it is too large
File diff suppressed because it is too large
File diff suppressed because it is too large

@@ -0,0 +1,88 @@
from logging import debug

import numpy as np
import torch
from torch.utils.data.dataloader import DataLoader
import yaml
import time
import hydra
from lit_models import TransformerLitModelTwoSteps
from transformers import AutoConfig, AutoModel
from transformers.optimization import get_linear_schedule_with_warmup
import os
from tqdm import tqdm

# Expected to provide REDataset, BertLitModel and BertForMaskedLM used below.
from deepkerefew import *

os.environ["TOKENIZERS_PARALLELISM"] = "false"


# In order to ensure reproducible experiments, we must set random seeds.


def logging(log_dir, s, print_=True, log_=True):
    # Print to stdout and/or append to the log file configured via log_dir.
    if print_:
        print(s)
    if log_dir != '' and log_:
        with open(log_dir, 'a+') as f_log:
            f_log.write(s + '\n')


def test(args, model, lit_model, data):
    # Evaluate on the test dataloader and log the resulting F1 score.
    model.eval()
    with torch.no_grad():
        test_loss = []
        for test_index, test_batch in enumerate(tqdm(data.test_dataloader())):
            loss = lit_model.test_step(test_batch, test_index)
            test_loss.append(loss)
        f1 = lit_model.test_epoch_end(test_loss)
        logging(args.log_dir,
            '| test_result: {}'.format(f1))
        logging(args.log_dir, '-' * 89)


@hydra.main(config_path="conf/config.yaml")
def main(cfg):
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    data = REDataset(cfg)
    data_config = data.get_data_config()

    config = AutoConfig.from_pretrained(cfg.model_name_or_path)
    config.num_labels = data_config["num_labels"]

    model = BertForMaskedLM.from_pretrained(cfg.model_name_or_path, config=config)

    # Load the checkpoint saved by run.py (cfg.load_path, see the conf YAML above).
    if cfg.load_path != '':
        model.load_state_dict(torch.load(cfg.load_path)["checkpoint"])
        print("load saved model from {}.".format(cfg.load_path))

    # if torch.cuda.device_count() > 1:
    #     print("Let's use", torch.cuda.device_count(), "GPUs!")
    #     model = torch.nn.DataParallel(model, device_ids=list(range(torch.cuda.device_count())))
    model.to(device)

    cur_model = model.module if hasattr(model, 'module') else model

    # Resize the embedding matrix to the tokenizer size (prompt/label tokens may have been added).
    if "gpt" in cfg.model_name_or_path or "roberta" in cfg.model_name_or_path:
        tokenizer = data.get_tokenizer()
        cur_model.resize_token_embeddings(len(tokenizer))
        cur_model.update_word_idx(len(tokenizer))
        if "Use" in cfg.model_class:
            continous_prompt = [a[0] for a in tokenizer([f"[T{i}]" for i in range(1, 3)], add_special_tokens=False)['input_ids']]
            continous_label_word = [a[0] for a in tokenizer([f"[class{i}]" for i in range(1, data.num_labels + 1)], add_special_tokens=False)['input_ids']]
            discrete_prompt = [a[0] for a in tokenizer(['It', 'was'], add_special_tokens=False)['input_ids']]
            dataset_name = cfg.data_dir.split("/")[1]
            model.init_unused_weights(continous_prompt, continous_label_word, discrete_prompt, label_path=f"{cfg.model_name_or_path}_{dataset_name}.pt")

    lit_model = BertLitModel(cfg=cfg, model=model, tokenizer=data.tokenizer, device=device)
    data.setup()

    test(cfg, model, lit_model, data)


if __name__ == "__main__":
    main()

@@ -0,0 +1,11 @@
numpy==1.20.3
tokenizers==0.10.3
torch==1.8.0
regex==2021.4.4
transformers==4.7.0
tqdm==4.49.0
activations==0.1.0
dataclasses==0.6
file_utils==0.0.1
flax==0.3.4
utils==1.0.1

@@ -0,0 +1,145 @@
from logging import debug

import numpy as np
import torch
from torch.utils.data.dataloader import DataLoader
import yaml
import time
import hydra
from lit_models import TransformerLitModelTwoSteps
from transformers import AutoConfig, AutoModel
from transformers.optimization import get_linear_schedule_with_warmup
import os
from tqdm import tqdm

# Expected to provide REDataset, BertLitModel, BertForMaskedLM,
# get_label_word and generate_k_shot used below.
from deepkerefew import *

os.environ["TOKENIZERS_PARALLELISM"] = "false"


# In order to ensure reproducible experiments, we must set random seeds.


def logging(log_dir, s, print_=True, log_=True):
    # Print to stdout and/or append to the log file configured via log_dir.
    if print_:
        print(s)
    if log_dir != '' and log_:
        with open(log_dir, 'a+') as f_log:
            f_log.write(s + '\n')


@hydra.main(config_path="conf/config.yaml")
def main(cfg):
    # Prepare the label-word file and the k-shot split before training.
    get_label_word()
    generate_k_shot()

    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    data = REDataset(cfg)
    data_config = data.get_data_config()

    config = AutoConfig.from_pretrained(cfg.model_name_or_path)
    config.num_labels = data_config["num_labels"]

    model = BertForMaskedLM.from_pretrained(cfg.model_name_or_path, config=config)

    if cfg.train_from_saved_model != '':
        model.load_state_dict(torch.load(cfg.train_from_saved_model)["checkpoint"])
        print("load saved model from {}.".format(cfg.train_from_saved_model))

    # if torch.cuda.device_count() > 1:
    #     print("Let's use", torch.cuda.device_count(), "GPUs!")
    #     model = torch.nn.DataParallel(model, device_ids=list(range(torch.cuda.device_count())))
    model.to(device)

    cur_model = model.module if hasattr(model, 'module') else model

    # Resize the embedding matrix to the tokenizer size (prompt/label tokens may have been added).
    if "gpt" in cfg.model_name_or_path or "roberta" in cfg.model_name_or_path:
        tokenizer = data.get_tokenizer()
        cur_model.resize_token_embeddings(len(tokenizer))
        cur_model.update_word_idx(len(tokenizer))
        if "Use" in cfg.model_class:
            continous_prompt = [a[0] for a in tokenizer([f"[T{i}]" for i in range(1, 3)], add_special_tokens=False)['input_ids']]
            continous_label_word = [a[0] for a in tokenizer([f"[class{i}]" for i in range(1, data.num_labels + 1)], add_special_tokens=False)['input_ids']]
            discrete_prompt = [a[0] for a in tokenizer(['It', 'was'], add_special_tokens=False)['input_ids']]
            dataset_name = cfg.data_dir.split("/")[1]
            model.init_unused_weights(continous_prompt, continous_label_word, discrete_prompt, label_path=f"{cfg.model_name_or_path}_{dataset_name}.pt")

    lit_model = BertLitModel(cfg=cfg, model=model, tokenizer=data.tokenizer, device=device)
    if cfg.train_from_saved_model != '':
        lit_model.best_f1 = torch.load(cfg.train_from_saved_model)["best_f1"]
    data.tokenizer.save_pretrained('test')
    data.setup()

    optimizer = lit_model.configure_optimizers()
    if cfg.train_from_saved_model != '':
        optimizer.load_state_dict(torch.load(cfg.train_from_saved_model)["optimizer"])
        print("load saved optimizer from {}.".format(cfg.train_from_saved_model))

    # Linear warmup over the first 10% of steps, then linear decay.
    num_training_steps = len(data.train_dataloader()) // cfg.gradient_accumulation_steps * cfg.num_train_epochs
    scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=num_training_steps * 0.1, num_training_steps=num_training_steps)
    log_step = 100

    logging(cfg.log_dir, '-' * 89, print_=False)
    logging(cfg.log_dir, time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()) + ' INFO : START TO TRAIN ', print_=False)
    logging(cfg.log_dir, '-' * 89, print_=False)

    for epoch in range(cfg.num_train_epochs):
        model.train()
        num_batch = len(data.train_dataloader())
        total_loss = 0
        log_loss = 0
        for index, train_batch in enumerate(tqdm(data.train_dataloader())):
            loss = lit_model.training_step(train_batch, index)
            total_loss += loss.item()
            log_loss += loss.item()
            loss.backward()

            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()

            if log_step > 0 and (index + 1) % log_step == 0:
                cur_loss = log_loss / log_step
                logging(cfg.log_dir,
                    '| epoch {:2d} | step {:4d} | lr {} | train loss {:5.3f}'.format(
                        epoch, (index + 1), scheduler.get_last_lr(), cur_loss * 1000),
                    print_=False)
                log_loss = 0
        avrg_loss = total_loss / num_batch
        logging(cfg.log_dir,
            '| epoch {:2d} | train loss {:5.3f}'.format(
                epoch, avrg_loss * 1000))

        # Validate after every epoch and keep the checkpoint with the best F1.
        model.eval()
        with torch.no_grad():
            val_loss = []
            for val_index, val_batch in enumerate(tqdm(data.val_dataloader())):
                loss = lit_model.validation_step(val_batch, val_index)
                val_loss.append(loss)
            f1, best, best_f1 = lit_model.validation_epoch_end(val_loss)
            logging(cfg.log_dir, '-' * 89)
            logging(cfg.log_dir,
                '| epoch {:2d} | dev_result: {}'.format(epoch, f1))
            logging(cfg.log_dir, '-' * 89)
            logging(cfg.log_dir,
                '| best_f1: {}'.format(best_f1))
            logging(cfg.log_dir, '-' * 89)
            if cfg.save_path != "" and best != -1:
                save_path = cfg.save_path
                torch.save({
                    'epoch': epoch,
                    'checkpoint': cur_model.state_dict(),
                    'best_f1': best_f1,
                    'optimizer': optimizer.state_dict()
                }, save_path,
                    _use_new_zipfile_serialization=False)
                logging(cfg.log_dir,
                    '| successfully save model at: {}'.format(save_path))
                logging(cfg.log_dir, '-' * 89)


if __name__ == "__main__":
    main()