parent
ad48c5f7b6
commit
7d487e3441
|
@ -0,0 +1,116 @@
|
||||||
|
# KnowPrompt
|
||||||
|
|
||||||
|
|
||||||
|
Code and datasets for our paper "KnowPrompt: Knowledge-aware Prompt-tuning with Synergistic Optimization for Relation Extraction"
|
||||||
|
|
||||||
|
Requirements
|
||||||
|
==========
|
||||||
|
To install requirements:
|
||||||
|
|
||||||
|
```
|
||||||
|
pip install -r requirements.txt
|
||||||
|
```
|
||||||
|
|
||||||
|
Datasets
|
||||||
|
==========
|
||||||
|
|
||||||
|
We provide all the datasets and prompts used in our experiments.
|
||||||
|
|
||||||
|
+ [[SEMEVAL]](dataset/semeval)
|
||||||
|
|
||||||
|
+ [[DialogRE]](../datasets/dialogue)
|
||||||
|
|
||||||
|
+ [[TACRED-Revisit]](../datasets/tacrev)
|
||||||
|
|
||||||
|
+ [[Re-TACRED]](../datasets/retacred)
|
||||||
|
|
||||||
|
+ [[Wiki80]](../datasets/wiki80)
|
||||||
|
|
||||||
|
The expected structure of files is:
|
||||||
|
|
||||||
|
|
||||||
|
```
|
||||||
|
knowprompt
|
||||||
|
|-- dataset
|
||||||
|
| |-- semeval
|
||||||
|
| | |-- train.txt
|
||||||
|
| | |-- dev.txt
|
||||||
|
| | |-- test.txt
|
||||||
|
| | |-- temp.txt
|
||||||
|
| | |-- rel2id.json
|
||||||
|
| |-- dialogue
|
||||||
|
| | |-- train.json
|
||||||
|
| | |-- dev.json
|
||||||
|
| | |-- test.json
|
||||||
|
| | |-- rel2id.json
|
||||||
|
| |-- wiki80
|
||||||
|
| | |-- train.txt
|
||||||
|
| | |-- dev.txt
|
||||||
|
| | |-- test.txt
|
||||||
|
| | |-- temp.txt
|
||||||
|
| | |-- rel2id.json
|
||||||
|
| |-- tacrev
|
||||||
|
| | |-- train.txt
|
||||||
|
| | |-- dev.txt
|
||||||
|
| | |-- test.txt
|
||||||
|
| | |-- temp.txt
|
||||||
|
| | |-- rel2id.json
|
||||||
|
| |-- retacred
|
||||||
|
| | |-- train.txt
|
||||||
|
| | |-- dev.txt
|
||||||
|
| | |-- test.txt
|
||||||
|
| | |-- temp.txt
|
||||||
|
| | |-- rel2id.json
|
||||||
|
|-- scripts
|
||||||
|
| |-- semeval.sh
|
||||||
|
| |-- dialogue.sh
|
||||||
|
| |-- ...
|
||||||
|
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
|
Run the experiments
|
||||||
|
==========
|
||||||
|
|
||||||
|
## Initialize the answer words
|
||||||
|
|
||||||
|
Use the comand below to get the answer words to use in the training.
|
||||||
|
|
||||||
|
```shell
|
||||||
|
python get_label_word.py --model_name_or_path bert-large-uncased --dataset_name semeval
|
||||||
|
```
|
||||||
|
|
||||||
|
The `{answer_words}.pt`will be saved in the dataset, you need to assign the `model_name_or_path` and `dataset_name` in the `get_label_word.py`.
|
||||||
|
|
||||||
|
## Split dataset
|
||||||
|
|
||||||
|
Download the data first, and put it to `dataset` folder. Run the comand below, and get the few shot dataset.
|
||||||
|
|
||||||
|
```shell
|
||||||
|
python generate_k_shot.py --data_dir ./dataset --k 8 --dataset semeval
|
||||||
|
cd dataset/semeval
|
||||||
|
cp rel2id.json val.txt test.txt ./k-shot/8-1
|
||||||
|
```
|
||||||
|
|
||||||
|
You need to modify the `k` and `dataset` to assign k-shot and dataset. Here we default seed as 1,2,3,4,5 to split each k-shot, you can revise it in the `generate_k_shot.py`
|
||||||
|
|
||||||
|
## Let's run
|
||||||
|
|
||||||
|
Our script code can automatically run the experiments in 8-shot, 16-shot, 32-shot and
|
||||||
|
standard supervised settings with both the procedures of train, eval and test. We just choose the random seed to be 1 as an example in our code. Actually you can perform multiple experments with different seeds.
|
||||||
|
|
||||||
|
#### Example for SEMEVAL
|
||||||
|
Train the KonwPrompt model on SEMEVAL with the following command:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
>> bash scripts/semeval.sh # for bert-large-uncased
|
||||||
|
```
|
||||||
|
As the scripts for `TACRED-Revist`, `Re-TACRED`, `Wiki80` included in our paper are also provided, you just need to run it like above example.
|
||||||
|
|
||||||
|
#### Example for DialogRE
|
||||||
|
As the data format of DialogRE is very different from other dataset, Class of processor is also different.
|
||||||
|
Train the KonwPrompt model on DialogRE with the following command:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
>> bash scripts/dialogue.sh # for bert-base-uncased
|
||||||
|
```
|
|
@ -0,0 +1,89 @@
|
||||||
|
import argparse
|
||||||
|
import os
|
||||||
|
import numpy as np
|
||||||
|
import pandas as pd
|
||||||
|
from pandas import DataFrame
|
||||||
|
import os
|
||||||
|
import json
|
||||||
|
from collections import Counter, OrderedDict
|
||||||
|
import logging
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def get_labels(path, name, negative_label="no_relation"):
|
||||||
|
"""See base class."""
|
||||||
|
|
||||||
|
count = Counter()
|
||||||
|
with open(path + "/" + name, "r") as f:
|
||||||
|
features = []
|
||||||
|
for line in f.readlines():
|
||||||
|
line = line.rstrip()
|
||||||
|
if len(line) > 0:
|
||||||
|
# count[line['relation']] += 1
|
||||||
|
features.append(eval(line))
|
||||||
|
|
||||||
|
# logger.info("label distribution as list: %d labels" % len(count))
|
||||||
|
# # Make sure the negative label is alwyas 0
|
||||||
|
# labels = []
|
||||||
|
# for label, count in count.most_common():
|
||||||
|
# logger.info("%s: %d 个 %.2f%%" % (label, count, count * 100.0 / len(dataset)))
|
||||||
|
# if label not in labels:
|
||||||
|
# labels.append(label)
|
||||||
|
return features
|
||||||
|
|
||||||
|
def main():
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument("--k", type=int, default=16,
|
||||||
|
help="Training examples for each class.")
|
||||||
|
# parser.add_argument("--task", type=str, nargs="+",
|
||||||
|
# default=['SST-2', 'sst-5', 'mr', 'cr', 'mpqa', 'subj', 'trec', 'CoLA', 'MRPC', 'QQP', 'STS-B', 'MNLI', 'SNLI', 'QNLI', 'RTE'],
|
||||||
|
# help="Task names")
|
||||||
|
parser.add_argument("--seed", type=int, nargs="+",
|
||||||
|
default=[1, 2, 3, 4, 5],
|
||||||
|
help="Random seeds")
|
||||||
|
|
||||||
|
parser.add_argument("--data_dir", type=str, default="../datasets/", help="Path to original data")
|
||||||
|
parser.add_argument("--dataset", type=str, default="tacred", help="Path to original data")
|
||||||
|
parser.add_argument("--data_file", type=str, default='train.txt', choices=['train.txt', 'val.txt'], help="k-shot or k-shot-10x (10x dev set)")
|
||||||
|
|
||||||
|
parser.add_argument("--mode", type=str, default='k-shot', choices=['k-shot', 'k-shot-10x'], help="k-shot or k-shot-10x (10x dev set)")
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
path = os.path.join(args.data_dir, args.dataset)
|
||||||
|
|
||||||
|
output_dir = os.path.join(path, args.mode)
|
||||||
|
dataset = get_labels(path, args.data_file)
|
||||||
|
|
||||||
|
for seed in args.seed:
|
||||||
|
|
||||||
|
# Other datasets
|
||||||
|
np.random.seed(seed)
|
||||||
|
np.random.shuffle(dataset)
|
||||||
|
|
||||||
|
# Set up dir
|
||||||
|
k = args.k
|
||||||
|
setting_dir = os.path.join(output_dir, f"{k}-{seed}")
|
||||||
|
os.makedirs(setting_dir, exist_ok=True)
|
||||||
|
|
||||||
|
label_list = {}
|
||||||
|
for line in dataset:
|
||||||
|
label = line['relation']
|
||||||
|
if label not in label_list:
|
||||||
|
label_list[label] = [line]
|
||||||
|
else:
|
||||||
|
label_list[label].append(line)
|
||||||
|
|
||||||
|
with open(os.path.join(setting_dir, "train.txt"), "w") as f:
|
||||||
|
file_list = []
|
||||||
|
for label in label_list:
|
||||||
|
for line in label_list[label][:k]: # train中每一类取前k个数据
|
||||||
|
f.writelines(json.dumps(line))
|
||||||
|
f.write('\n')
|
||||||
|
|
||||||
|
f.close()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
|
|
|
@ -0,0 +1,59 @@
|
||||||
|
from transformers import AutoTokenizer
|
||||||
|
from torch.nn.utils.rnn import pad_sequence
|
||||||
|
import re
|
||||||
|
import torch
|
||||||
|
import json
|
||||||
|
import argparse
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def split_label_words(tokenizer, label_list):
|
||||||
|
label_word_list = []
|
||||||
|
for label in label_list:
|
||||||
|
if label == 'no_relation':
|
||||||
|
label_word_id = tokenizer.encode('None', add_special_tokens=False)
|
||||||
|
label_word_list.append(torch.tensor(label_word_id))
|
||||||
|
else:
|
||||||
|
tmps = label
|
||||||
|
label = label.lower()
|
||||||
|
label = label.split("(")[0]
|
||||||
|
label = label.replace(":"," ").replace("_"," ").replace("per","person").replace("org","organization")
|
||||||
|
label_word_id = tokenizer(label, add_special_tokens=False)['input_ids']
|
||||||
|
print(label, label_word_id)
|
||||||
|
label_word_list.append(torch.tensor(label_word_id))
|
||||||
|
padded_label_word_list = pad_sequence([x for x in label_word_list], batch_first=True, padding_value=0)
|
||||||
|
return padded_label_word_list
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument("--k", type=int, default=16,
|
||||||
|
help="Training examples for each class.")
|
||||||
|
# parser.add_argument("--task", type=str, nargs="+",
|
||||||
|
# default=['SST-2', 'sst-5', 'mr', 'cr', 'mpqa', 'subj', 'trec', 'CoLA', 'MRPC', 'QQP', 'STS-B', 'MNLI', 'SNLI', 'QNLI', 'RTE'],
|
||||||
|
# help="Task names")
|
||||||
|
parser.add_argument("--seed", type=int, nargs="+",
|
||||||
|
default=[1, 2, 3],
|
||||||
|
help="Random seeds")
|
||||||
|
|
||||||
|
parser.add_argument("--model_name_or_path", type=str, default="bert-large-uncased")
|
||||||
|
parser.add_argument("--dataset_name", type=str, default="semeval")
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
model_name_or_path = args.model_name_or_path
|
||||||
|
dataset_name = args.dataset_name
|
||||||
|
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
|
||||||
|
with open(f"dataset/{dataset_name}/rel2id.json", "r") as file:
|
||||||
|
t = json.load(file)
|
||||||
|
label_list = list(t)
|
||||||
|
|
||||||
|
t = split_label_words(tokenizer, label_list)
|
||||||
|
|
||||||
|
with open(f"./dataset/{model_name_or_path}_{dataset_name}.pt", "wb") as file:
|
||||||
|
torch.save(t, file)
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
|
|
|
@ -0,0 +1,347 @@
|
||||||
|
"""Experiment-running framework."""
|
||||||
|
import argparse
|
||||||
|
import importlib
|
||||||
|
from logging import debug
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from torch.utils.data.dataloader import DataLoader
|
||||||
|
import yaml
|
||||||
|
import time
|
||||||
|
from lit_models import TransformerLitModelTwoSteps
|
||||||
|
from transformers import AutoConfig, AutoModel
|
||||||
|
from transformers.optimization import get_linear_schedule_with_warmup
|
||||||
|
import os
|
||||||
|
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||||
|
|
||||||
|
|
||||||
|
# In order to ensure reproducible experiments, we must set random seeds.
|
||||||
|
|
||||||
|
|
||||||
|
def _import_class(module_and_class_name: str) -> type:
|
||||||
|
"""Import class from a module, e.g. 'text_recognizer.models.MLP'"""
|
||||||
|
module_name, class_name = module_and_class_name.rsplit(".", 1)
|
||||||
|
module = importlib.import_module(module_name)
|
||||||
|
class_ = getattr(module, class_name)
|
||||||
|
return class_
|
||||||
|
|
||||||
|
|
||||||
|
def _setup_parser():
|
||||||
|
"""Set up Python's ArgumentParser with data, model, trainer, and other arguments."""
|
||||||
|
parser = argparse.ArgumentParser(add_help=False)
|
||||||
|
|
||||||
|
# Add Trainer specific arguments, such as --max_epochs, --gpus, --precision
|
||||||
|
trainer_group = parser.add_argument_group("Trainer Args")
|
||||||
|
trainer_group.add_argument("--accelerator", default=None)
|
||||||
|
trainer_group.add_argument("--accumulate_grad_batches", default=1)
|
||||||
|
trainer_group.add_argument("--amp_backend", default='native')
|
||||||
|
trainer_group.add_argument("--amp_level", default='O2')
|
||||||
|
trainer_group.add_argument("--auto_lr_find", default=False)
|
||||||
|
trainer_group.add_argument("--auto_scale_batch_size", default=False)
|
||||||
|
trainer_group.add_argument("--auto_select_gpus", default=False)
|
||||||
|
trainer_group.add_argument("--benchmark", default=False)
|
||||||
|
trainer_group.add_argument("--check_val_every_n_epoch", default=1)
|
||||||
|
trainer_group.add_argument("--checkpoint_callback", default=True)
|
||||||
|
trainer_group.add_argument("--default_root_dir", default=None)
|
||||||
|
trainer_group.add_argument("--deterministic", default=False)
|
||||||
|
trainer_group.add_argument("--devices", default=None)
|
||||||
|
trainer_group.add_argument("--distributed_backend", default=None)
|
||||||
|
trainer_group.add_argument("--fast_dev_run", default=False)
|
||||||
|
trainer_group.add_argument("--flush_logs_every_n_steps", default=100)
|
||||||
|
trainer_group.add_argument("--gpus", default=None)
|
||||||
|
trainer_group.add_argument("--gradient_clip_algorithm", default='norm')
|
||||||
|
trainer_group.add_argument("--gradient_clip_val", default=0.0)
|
||||||
|
trainer_group.add_argument("--ipus", default=None)
|
||||||
|
trainer_group.add_argument("--limit_predict_batches", default=1.0)
|
||||||
|
trainer_group.add_argument("--limit_test_batches", default=1.0)
|
||||||
|
trainer_group.add_argument("--limit_train_batches", default=1.0)
|
||||||
|
trainer_group.add_argument("--limit_val_batches", default=1.0)
|
||||||
|
trainer_group.add_argument("--log_every_n_steps", default=50)
|
||||||
|
trainer_group.add_argument("--log_gpu_memory", default=None)
|
||||||
|
trainer_group.add_argument("--logger", default=True)
|
||||||
|
trainer_group.add_argument("--max_epochs", default=None)
|
||||||
|
trainer_group.add_argument("--max_steps", default=None)
|
||||||
|
trainer_group.add_argument("--max_time", default=None)
|
||||||
|
trainer_group.add_argument("--min_epochs", default=None)
|
||||||
|
trainer_group.add_argument("--min_steps", default=None)
|
||||||
|
trainer_group.add_argument("--move_metrics_to_cpu", default=False)
|
||||||
|
trainer_group.add_argument("--multiple_trainloader_mode", default='max_size_cycle')
|
||||||
|
trainer_group.add_argument("--num_nodes", default=1)
|
||||||
|
trainer_group.add_argument("--num_processes", default=1)
|
||||||
|
trainer_group.add_argument("--num_sanity_val_steps", default=2)
|
||||||
|
trainer_group.add_argument("--overfit_batches", default=0.0)
|
||||||
|
trainer_group.add_argument("--plugins", default=None)
|
||||||
|
trainer_group.add_argument("--precision", default=32)
|
||||||
|
trainer_group.add_argument("--prepare_data_per_node", default=True)
|
||||||
|
trainer_group.add_argument("--process_position", default=0)
|
||||||
|
trainer_group.add_argument("--profiler", default=None)
|
||||||
|
trainer_group.add_argument("--progress_bar_refresh_rate", default=None)
|
||||||
|
trainer_group.add_argument("--reload_dataloaders_every_epoch", default=False)
|
||||||
|
trainer_group.add_argument("--reload_dataloaders_every_n_epochs", default=0)
|
||||||
|
trainer_group.add_argument("--replace_sampler_ddp", default=True)
|
||||||
|
trainer_group.add_argument("--resume_from_checkpoint", default=None)
|
||||||
|
trainer_group.add_argument("--stochastic_weight_avg", default=False)
|
||||||
|
trainer_group.add_argument("--sync_batchnorm", default=False)
|
||||||
|
trainer_group.add_argument("--terminate_on_nan", default=False)
|
||||||
|
trainer_group.add_argument("--tpu_cores", default=None)
|
||||||
|
trainer_group.add_argument("--track_grad_norm", default=-1)
|
||||||
|
trainer_group.add_argument("--truncated_bptt_steps", default=None)
|
||||||
|
trainer_group.add_argument("--val_check_interval", default=1.0)
|
||||||
|
trainer_group.add_argument("--weights_save_path", default=None)
|
||||||
|
trainer_group.add_argument("--weights_summary", default='top')
|
||||||
|
|
||||||
|
parser = argparse.ArgumentParser(add_help=False, parents=[parser])
|
||||||
|
|
||||||
|
# Basic arguments
|
||||||
|
parser.add_argument("--wandb", action="store_true", default=False)
|
||||||
|
parser.add_argument("--litmodel_class", type=str, default="TransformerLitModel")
|
||||||
|
parser.add_argument("--seed", type=int, default=666)
|
||||||
|
parser.add_argument("--data_class", type=str, default="DIALOGUE")
|
||||||
|
parser.add_argument("--lr_2", type=float, default=3e-5)
|
||||||
|
parser.add_argument("--model_class", type=str, default="bert.BertForSequenceClassification")
|
||||||
|
parser.add_argument("--two_steps", default=False, action="store_true")
|
||||||
|
parser.add_argument("--load_checkpoint", type=str, default=None)
|
||||||
|
parser.add_argument("--gradient_accumulation_steps", default=1, type=int)
|
||||||
|
parser.add_argument("--num_train_epochs", default=30, type=int)
|
||||||
|
parser.add_argument("--log_dir", default='', type=str)
|
||||||
|
parser.add_argument("--save_path", default='', type=str)
|
||||||
|
parser.add_argument("--train_from_saved_model", default='', type=str)
|
||||||
|
|
||||||
|
# Get the data and model classes, so that we can add their specific arguments
|
||||||
|
temp_args, _ = parser.parse_known_args()
|
||||||
|
data_class = _import_class(f"data.{temp_args.data_class}")
|
||||||
|
model_class = _import_class(f"models.{temp_args.model_class}")
|
||||||
|
litmodel_class = _import_class(f"lit_models.{temp_args.litmodel_class}")
|
||||||
|
|
||||||
|
# Get data, model, and LitModel specific arguments
|
||||||
|
data_group = parser.add_argument_group("Data Args")
|
||||||
|
data_class.add_to_argparse(data_group)
|
||||||
|
|
||||||
|
model_group = parser.add_argument_group("Model Args")
|
||||||
|
model_class.add_to_argparse(model_group)
|
||||||
|
|
||||||
|
lit_model_group = parser.add_argument_group("LitModel Args")
|
||||||
|
litmodel_class.add_to_argparse(lit_model_group)
|
||||||
|
|
||||||
|
parser.add_argument("--help", "-h", action="help")
|
||||||
|
return parser
|
||||||
|
|
||||||
|
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
|
||||||
|
|
||||||
|
from tqdm import tqdm
|
||||||
|
def _get_relation_embedding(data):
|
||||||
|
train_dataloader = data.train_dataloader()
|
||||||
|
#! hard coded
|
||||||
|
relation_embedding = [[] for _ in range(36)]
|
||||||
|
model = AutoModel.from_pretrained('bert-base-uncased')
|
||||||
|
model.eval()
|
||||||
|
model = model.to(device)
|
||||||
|
|
||||||
|
|
||||||
|
cnt = 0
|
||||||
|
for batch in tqdm(train_dataloader):
|
||||||
|
with torch.no_grad():
|
||||||
|
#! why the sample in this case will cause errors
|
||||||
|
if cnt == 416:
|
||||||
|
continue
|
||||||
|
cnt += 1
|
||||||
|
input_ids, attention_mask, token_type_ids , labels = batch
|
||||||
|
input_ids = input_ids.to(device)
|
||||||
|
attention_mask = attention_mask.to(device)
|
||||||
|
token_type_ids = token_type_ids.to(device)
|
||||||
|
|
||||||
|
logits = model(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids).last_hidden_state.detach().cpu()
|
||||||
|
_, mask_idx = (input_ids == 103).nonzero(as_tuple=True)
|
||||||
|
bs = input_ids.shape[0]
|
||||||
|
mask_output = logits[torch.arange(bs), mask_idx] # [batch_size, hidden_size]
|
||||||
|
|
||||||
|
|
||||||
|
labels = labels.detach().cpu()
|
||||||
|
mask_output = mask_output.detach().cpu()
|
||||||
|
assert len(labels[0]) == len(relation_embedding)
|
||||||
|
for batch_idx, label in enumerate(labels.tolist()):
|
||||||
|
for i, x in enumerate(label):
|
||||||
|
if x:
|
||||||
|
relation_embedding[i].append(mask_output[batch_idx])
|
||||||
|
|
||||||
|
# get the mean pooling
|
||||||
|
for i in range(36):
|
||||||
|
if len(relation_embedding[i]):
|
||||||
|
relation_embedding[i] = torch.mean(torch.stack(relation_embedding[i]), dim=0)
|
||||||
|
else:
|
||||||
|
relation_embedding[i] = torch.rand_like(relation_embedding[i-1])
|
||||||
|
|
||||||
|
del model
|
||||||
|
return relation_embedding
|
||||||
|
|
||||||
|
|
||||||
|
def logging(log_dir, s, print_=True, log_=True):
|
||||||
|
if print_:
|
||||||
|
print(s)
|
||||||
|
if log_dir != '' and log_:
|
||||||
|
with open(log_dir, 'a+') as f_log:
|
||||||
|
f_log.write(s + '\n')
|
||||||
|
|
||||||
|
def test(args, model, lit_model, data):
|
||||||
|
model.eval()
|
||||||
|
with torch.no_grad():
|
||||||
|
test_loss = []
|
||||||
|
for test_index, test_batch in enumerate(tqdm(data.test_dataloader())):
|
||||||
|
loss = lit_model.test_step(test_batch, test_index)
|
||||||
|
test_loss.append(loss)
|
||||||
|
f1 = lit_model.test_epoch_end(test_loss)
|
||||||
|
logging(args.log_dir,
|
||||||
|
'| test_result: {}'.format(f1))
|
||||||
|
logging(args.log_dir,'-' * 89)
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
parser = _setup_parser()
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
|
#pl.seed_everything(args.seed)
|
||||||
|
|
||||||
|
np.random.seed(args.seed)
|
||||||
|
torch.manual_seed(args.seed)
|
||||||
|
torch.cuda.manual_seed_all(args.seed)
|
||||||
|
|
||||||
|
|
||||||
|
data_class = _import_class(f"data.{args.data_class}")
|
||||||
|
model_class = _import_class(f"models.{args.model_class}")
|
||||||
|
litmodel_class = _import_class(f"lit_models.{args.litmodel_class}")
|
||||||
|
|
||||||
|
data = data_class(args)
|
||||||
|
data_config = data.get_data_config()
|
||||||
|
|
||||||
|
config = AutoConfig.from_pretrained(args.model_name_or_path)
|
||||||
|
config.num_labels = data_config["num_labels"]
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
# gpt no config?
|
||||||
|
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
|
||||||
|
model = model_class.from_pretrained(args.model_name_or_path, config=config)
|
||||||
|
|
||||||
|
if args.train_from_saved_model != '':
|
||||||
|
#model.load_state_dict(torch.load(args.train_from_saved_model)["checkpoint"])
|
||||||
|
print("load saved model from {}.".format(args.train_from_saved_model))
|
||||||
|
|
||||||
|
|
||||||
|
if torch.cuda.device_count() > 1:
|
||||||
|
print("Let's use", torch.cuda.device_count(), "GPUs!")
|
||||||
|
model = torch.nn.DataParallel(model, device_ids = list(range(torch.cuda.device_count())))
|
||||||
|
model.to(device)
|
||||||
|
|
||||||
|
cur_model = model.module if hasattr(model, 'module') else model
|
||||||
|
|
||||||
|
|
||||||
|
if "gpt" in args.model_name_or_path or "roberta" in args.model_name_or_path:
|
||||||
|
tokenizer = data.get_tokenizer()
|
||||||
|
cur_model.resize_token_embeddings(len(tokenizer))
|
||||||
|
cur_model.update_word_idx(len(tokenizer))
|
||||||
|
if "Use" in args.model_class:
|
||||||
|
continous_prompt = [a[0] for a in tokenizer([f"[T{i}]" for i in range(1,3)], add_special_tokens=False)['input_ids']]
|
||||||
|
continous_label_word = [a[0] for a in tokenizer([f"[class{i}]" for i in range(1, data.num_labels+1)], add_special_tokens=False)['input_ids']]
|
||||||
|
discrete_prompt = [a[0] for a in tokenizer(['It', 'was'], add_special_tokens=False)['input_ids']]
|
||||||
|
dataset_name = args.data_dir.split("/")[1]
|
||||||
|
model.init_unused_weights(continous_prompt, continous_label_word, discrete_prompt, label_path=f"{args.model_name_or_path}_{dataset_name}.pt")
|
||||||
|
# data.setup()
|
||||||
|
# relation_embedding = _get_relation_embedding(data)
|
||||||
|
lit_model = litmodel_class(args=args, model=model, tokenizer=data.tokenizer, device=device)
|
||||||
|
if args.train_from_saved_model != '':
|
||||||
|
lit_model.best_f1 = torch.load(args.train_from_saved_model)["best_f1"]
|
||||||
|
data.tokenizer.save_pretrained('test')
|
||||||
|
data.setup()
|
||||||
|
|
||||||
|
optimizer = lit_model.configure_optimizers()
|
||||||
|
if args.train_from_saved_model != '':
|
||||||
|
optimizer.load_state_dict(torch.load(args.train_from_saved_model)["optimizer"])
|
||||||
|
print("load saved optimizer from {}.".format(args.train_from_saved_model))
|
||||||
|
|
||||||
|
num_training_steps = len(data.train_dataloader()) // args.gradient_accumulation_steps * args.num_train_epochs
|
||||||
|
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=num_training_steps * 0.1, num_training_steps=num_training_steps)
|
||||||
|
log_step = 100
|
||||||
|
|
||||||
|
|
||||||
|
logging(args.log_dir,'-' * 89, print_=False)
|
||||||
|
logging(args.log_dir, time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()) + ' INFO : START TO TRAIN ', print_=False)
|
||||||
|
logging(args.log_dir,'-' * 89, print_=False)
|
||||||
|
|
||||||
|
for epoch in range(args.num_train_epochs):
|
||||||
|
model.train()
|
||||||
|
num_batch = len(data.train_dataloader())
|
||||||
|
total_loss = 0
|
||||||
|
log_loss = 0
|
||||||
|
for index, train_batch in enumerate(tqdm(data.train_dataloader())):
|
||||||
|
loss = lit_model.training_step(train_batch, index)
|
||||||
|
total_loss += loss.item()
|
||||||
|
log_loss += loss.item()
|
||||||
|
loss.backward()
|
||||||
|
|
||||||
|
optimizer.step()
|
||||||
|
scheduler.step()
|
||||||
|
optimizer.zero_grad()
|
||||||
|
|
||||||
|
if log_step > 0 and (index+1) % log_step == 0:
|
||||||
|
cur_loss = log_loss / log_step
|
||||||
|
logging(args.log_dir,
|
||||||
|
'| epoch {:2d} | step {:4d} | lr {} | train loss {:5.3f}'.format(
|
||||||
|
epoch, (index+1), scheduler.get_last_lr(), cur_loss * 1000)
|
||||||
|
, print_=False)
|
||||||
|
log_loss = 0
|
||||||
|
avrg_loss = total_loss / num_batch
|
||||||
|
logging(args.log_dir,
|
||||||
|
'| epoch {:2d} | train loss {:5.3f}'.format(
|
||||||
|
epoch, avrg_loss * 1000))
|
||||||
|
|
||||||
|
model.eval()
|
||||||
|
with torch.no_grad():
|
||||||
|
val_loss = []
|
||||||
|
for val_index, val_batch in enumerate(tqdm(data.val_dataloader())):
|
||||||
|
loss = lit_model.validation_step(val_batch, val_index)
|
||||||
|
val_loss.append(loss)
|
||||||
|
f1, best, best_f1 = lit_model.validation_epoch_end(val_loss)
|
||||||
|
logging(args.log_dir,'-' * 89)
|
||||||
|
logging(args.log_dir,
|
||||||
|
'| epoch {:2d} | dev_result: {}'.format(epoch, f1))
|
||||||
|
logging(args.log_dir,'-' * 89)
|
||||||
|
logging(args.log_dir,
|
||||||
|
'| best_f1: {}'.format(best_f1))
|
||||||
|
logging(args.log_dir,'-' * 89)
|
||||||
|
if args.save_path != "" and best != -1:
|
||||||
|
file_name = f"{epoch}-Eval_f1-{best_f1:.2f}.pt"
|
||||||
|
save_path = args.save_path + '/' + file_name
|
||||||
|
torch.save({
|
||||||
|
'epoch': epoch,
|
||||||
|
'checkpoint': cur_model.state_dict(),
|
||||||
|
'best_f1': best_f1,
|
||||||
|
'optimizer': optimizer.state_dict()
|
||||||
|
}, save_path
|
||||||
|
, _use_new_zipfile_serialization=False)
|
||||||
|
logging(args.log_dir,
|
||||||
|
'| successfully save model at: {}'.format(save_path))
|
||||||
|
logging(args.log_dir,'-' * 89)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
path = args.save_path + '/config'
|
||||||
|
|
||||||
|
if not os.path.exists(path):
|
||||||
|
os.mkdir(path)
|
||||||
|
config_file_name = time.strftime("%H:%M:%S", time.localtime()) + ".yaml"
|
||||||
|
day_name = time.strftime("%Y-%m-%d")
|
||||||
|
if not os.path.exists(os.path.join(path, day_name)):
|
||||||
|
os.mkdir(os.path.join(path, day_name))
|
||||||
|
config = vars(args)
|
||||||
|
config["path"] = path
|
||||||
|
with open(os.path.join(os.path.join(path, day_name), config_file_name), "w") as file:
|
||||||
|
file.write(yaml.dump(config))
|
||||||
|
|
||||||
|
test(args, model, lit_model, data)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
|
||||||
|
main()
|
|
@ -0,0 +1,11 @@
|
||||||
|
numpy==1.20.3
|
||||||
|
tokenizers==0.10.3
|
||||||
|
torch==1.8.0
|
||||||
|
regex==2021.4.4
|
||||||
|
transformers==4.7.0
|
||||||
|
tqdm==4.49.0
|
||||||
|
activations==0.1.0
|
||||||
|
dataclasses==0.6
|
||||||
|
file_utils==0.0.1
|
||||||
|
flax==0.3.4
|
||||||
|
utils==1.0.1
|
|
@ -0,0 +1 @@
|
||||||
|
# this is an empty file
|
|
@ -0,0 +1,16 @@
|
||||||
|
export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
|
||||||
|
|
||||||
|
python main.py --num_train_epochs=30 --num_workers=8 \
|
||||||
|
--model_name_or_path bert-base-uncased \
|
||||||
|
--accumulate_grad_batches 4 \
|
||||||
|
--batch_size 8 \
|
||||||
|
--data_dir dataset/dialogue \
|
||||||
|
--check_val_every_n_epoch 1 \
|
||||||
|
--data_class DIALOGUE \
|
||||||
|
--max_seq_length 512 \
|
||||||
|
--model_class BertForMaskedLM \
|
||||||
|
--litmodel_class DialogueLitModel \
|
||||||
|
--task_name normal \
|
||||||
|
--lr 3e-5 \
|
||||||
|
--log_dir ./logs/dialogue.log \
|
||||||
|
--save_path ./saved_models
|
|
@ -0,0 +1,66 @@
|
||||||
|
export CUDA_VISIBLE_DEVICES=0
|
||||||
|
python main.py --max_epochs=30 --num_workers=8 \
|
||||||
|
--model_name_or_path bert-large-uncased \
|
||||||
|
--accumulate_grad_batches 1 \
|
||||||
|
--batch_size 16 \
|
||||||
|
--data_dir dataset/semeval/k-shot/8-1 \
|
||||||
|
--check_val_every_n_epoch 3 \
|
||||||
|
--data_class REDataset \
|
||||||
|
--max_seq_length 256 \
|
||||||
|
--model_class BertForMaskedLM \
|
||||||
|
--t_lambda 0.001 \
|
||||||
|
--litmodel_class BertLitModel \
|
||||||
|
--lr 3e-5
|
||||||
|
--log_dir ./logs/semeval_k-shot_8-1.log \
|
||||||
|
--save_path ./saved_models
|
||||||
|
|
||||||
|
|
||||||
|
export CUDA_VISIBLE_DEVICES=0
|
||||||
|
python main.py --max_epochs=30 --num_workers=8 \
|
||||||
|
--model_name_or_path bert-large-uncased \
|
||||||
|
--accumulate_grad_batches 1 \
|
||||||
|
--batch_size 16 \
|
||||||
|
--data_dir dataset/semeval/k-shot/16-1 \
|
||||||
|
--check_val_every_n_epoch 3 \
|
||||||
|
--data_class REDataset \
|
||||||
|
--max_seq_length 256 \
|
||||||
|
--model_class BertForMaskedLM \
|
||||||
|
--t_lambda 0.001 \
|
||||||
|
--litmodel_class BertLitModel \
|
||||||
|
--lr 3e-5
|
||||||
|
--log_dir ./logs/semeval_k-shot_16-1.log \
|
||||||
|
--save_path ./saved_models
|
||||||
|
|
||||||
|
export CUDA_VISIBLE_DEVICES=0
|
||||||
|
python main.py --max_epochs=30 --num_workers=8 \
|
||||||
|
--model_name_or_path bert-large-uncased \
|
||||||
|
--accumulate_grad_batches 1 \
|
||||||
|
--batch_size 16 \
|
||||||
|
--data_dir dataset/semeval/k-shot/32-1 \
|
||||||
|
--check_val_every_n_epoch 3 \
|
||||||
|
--data_class REDataset \
|
||||||
|
--max_seq_length 256 \
|
||||||
|
--model_class BertForMaskedLM \
|
||||||
|
--t_lambda 0.001 \
|
||||||
|
--litmodel_class BertLitModel \
|
||||||
|
--lr 3e-5
|
||||||
|
--log_dir ./logs/semeval_k-shot_32-1.log \
|
||||||
|
--save_path ./saved_models
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
|
||||||
|
python main.py --max_epochs=5 --num_workers=8 \
|
||||||
|
--model_name_or_path bert-large-uncased \
|
||||||
|
--accumulate_grad_batches 1 \
|
||||||
|
--batch_size 16 \
|
||||||
|
--data_dir dataset/semeval \
|
||||||
|
--check_val_every_n_epoch 1 \
|
||||||
|
--data_class REDataset \
|
||||||
|
--max_seq_length 256 \
|
||||||
|
--model_class BertForMaskedLM \
|
||||||
|
--t_lambda 0.001 \
|
||||||
|
--litmodel_class BertLitModel \
|
||||||
|
--lr 3e-5
|
||||||
|
--log_dir ./logs/semeval.log \
|
||||||
|
--save_path ./saved_models
|
|
@ -0,0 +1,64 @@
|
||||||
|
export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
|
||||||
|
python main.py --max_epochs=30 --num_workers=8 \
|
||||||
|
--model_name_or_path bert-large-uncased \
|
||||||
|
--accumulate_grad_batches 1 \
|
||||||
|
--batch_size 8 \
|
||||||
|
--data_dir dataset/tacrev/k-shot/8-1 \
|
||||||
|
--check_val_every_n_epoch 2 \
|
||||||
|
--data_class REDataset \
|
||||||
|
--max_seq_length 512 \
|
||||||
|
--model_class BertForMaskedLM \
|
||||||
|
--t_lambda 0.001 \
|
||||||
|
--litmodel_class BertLitModel \
|
||||||
|
--lr 3e-5
|
||||||
|
--log_dir ./logs/tacrev_k-shot_8-1.log \
|
||||||
|
--save_path ./saved_models
|
||||||
|
|
||||||
|
|
||||||
|
export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
|
||||||
|
python main.py --max_epochs=30 --num_workers=8 \
|
||||||
|
--model_name_or_path bert-large-uncased \
|
||||||
|
--accumulate_grad_batches 1 \
|
||||||
|
--batch_size 8 \
|
||||||
|
--data_dir dataset/tacrev/k-shot/16-1 \
|
||||||
|
--check_val_every_n_epoch 2 \
|
||||||
|
--data_class REDataset \
|
||||||
|
--max_seq_length 512 \
|
||||||
|
--model_class BertForMaskedLM \
|
||||||
|
--t_lambda 0.001 \
|
||||||
|
--litmodel_class BertLitModel \
|
||||||
|
--lr 3e-5
|
||||||
|
--log_dir ./logs/tacrev_k-shot_16-1.log \
|
||||||
|
--save_path ./saved_models
|
||||||
|
|
||||||
|
export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
|
||||||
|
python main.py --max_epochs=30 --num_workers=8 \
|
||||||
|
--model_name_or_path bert-large-uncased \
|
||||||
|
--accumulate_grad_batches 1 \
|
||||||
|
--batch_size 8 \
|
||||||
|
--data_dir dataset/tacrev/k-shot/32-1 \
|
||||||
|
--check_val_every_n_epoch 2 \
|
||||||
|
--data_class REDataset \
|
||||||
|
--max_seq_length 512 \
|
||||||
|
--model_class BertForMaskedLM \
|
||||||
|
--t_lambda 0.001 \
|
||||||
|
--litmodel_class BertLitModel \
|
||||||
|
--lr 3e-5
|
||||||
|
--log_dir ./logs/tacrev_k-shot_32-1.log \
|
||||||
|
--save_path ./saved_models
|
||||||
|
|
||||||
|
export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
|
||||||
|
python main.py --max_epochs=5 --num_workers=8 \
|
||||||
|
--model_name_or_path bert-large-uncased \
|
||||||
|
--accumulate_grad_batches 1 \
|
||||||
|
--batch_size 8 \
|
||||||
|
--data_dir dataset/tacrev \
|
||||||
|
--check_val_every_n_epoch 1 \
|
||||||
|
--data_class REDataset \
|
||||||
|
--max_seq_length 512 \
|
||||||
|
--model_class BertForMaskedLM \
|
||||||
|
--t_lambda 0.001 \
|
||||||
|
--litmodel_class BertLitModel \
|
||||||
|
--lr 3e-5
|
||||||
|
--log_dir ./logs/tacrev.log \
|
||||||
|
--save_path ./saved_models
|
|
@ -0,0 +1 @@
|
||||||
|
{"[speaker24]": 30581, "[speaker3]": 30560, "[class4]": 30525, "[speaker15]": 30572, "[class10]": 30531, "[speaker31]": 30588, "[speaker16]": 30573, "[speaker38]": 30595, "[class36]": 30557, "[speaker37]": 30594, "[class5]": 30526, "[speaker30]": 30587, "[speaker39]": 30596, "[speaker22]": 30579, "[class7]": 30528, "[class24]": 30545, "[class28]": 30549, "[speaker36]": 30593, "[speaker35]": 30592, "[speaker26]": 30583, "[speaker4]": 30561, "[class22]": 30543, "[speaker11]": 30568, "[class14]": 30535, "[class32]": 30553, "[class16]": 30537, "[speaker6]": 30563, "[class35]": 30556, "[class21]": 30542, "[class30]": 30551, "[class31]": 30552, "[speaker2]": 30559, "[class23]": 30544, "[speaker12]": 30569, "[speaker40]": 30597, "[class20]": 30541, "[speaker23]": 30580, "[speaker14]": 30571, "[speaker21]": 30578, "[speaker43]": 30600, "[class19]": 30540, "[class3]": 30524, "[speaker29]": 30586, "[speaker19]": 30576, "[speaker8]": 30565, "[speaker42]": 30599, "[speaker28]": 30585, "[speaker7]": 30564, "[class33]": 30554, "[class15]": 30536, "[speaker44]": 30601, "[class26]": 30547, "[class13]": 30534, "[speaker25]": 30582, "[speaker33]": 30590, "[speaker48]": 30605, "[sub]": 30607, "[class27]": 30548, "[speaker47]": 30604, "[speaker5]": 30562, "[speaker27]": 30584, "[class6]": 30527, "[class2]": 30523, "[speaker18]": 30575, "[speaker13]": 30570, "[class12]": 30533, "[class18]": 30539, "[class11]": 30532, "[speaker46]": 30603, "[speaker32]": 30589, "[speaker10]": 30567, "[speaker20]": 30577, "[speaker1]": 30558, "[speaker34]": 30591, "[obj]": 30608, "[speaker17]": 30574, "[speaker49]": 30606, "[class8]": 30529, "[class34]": 30555, "[class17]": 30538, "[speaker41]": 30598, "[class29]": 30550, "[speaker9]": 30566, "[class9]": 30530, "[class25]": 30546, "[class1]": 30522, "[speaker45]": 30602}
|
|
@ -0,0 +1 @@
|
||||||
|
{"unk_token": "[UNK]", "sep_token": "[SEP]", "pad_token": "[PAD]", "cls_token": "[CLS]", "mask_token": "[MASK]", "additional_special_tokens": ["[sub]", "[obj]"]}
|
File diff suppressed because one or more lines are too long
|
@ -0,0 +1 @@
|
||||||
|
{"do_lower_case": true, "unk_token": "[UNK]", "sep_token": "[SEP]", "pad_token": "[PAD]", "cls_token": "[CLS]", "mask_token": "[MASK]", "tokenize_chinese_chars": true, "strip_accents": null, "model_max_length": 512, "special_tokens_map_file": null, "name_or_path": "bert-base-uncased", "tokenizer_class": "BertTokenizer"}
|
File diff suppressed because it is too large
Load Diff
Loading…
Reference in New Issue