Add files via upload
This commit is contained in:
parent
c325dd7112
commit
a26174d092
|
@ -0,0 +1,57 @@
|
|||
## 快速上手
|
||||
|
||||
### 环境依赖
|
||||
|
||||
> python == 3.8
|
||||
|
||||
- torch == 1.8.1
|
||||
- transformers == 4.7.0
|
||||
- opt-einsum == 3.3.0
|
||||
- ujson
|
||||
- deepke
|
||||
|
||||
### 克隆代码
|
||||
```
|
||||
git clone git@github.com:zjunlp/DeepKE.git
|
||||
```
|
||||
### 使用pip安装
|
||||
|
||||
首先创建python虚拟环境,再进入虚拟环境
|
||||
|
||||
- 安装依赖: ```pip install -r requirements.txt```
|
||||
|
||||
### 使用数据进行训练预测
|
||||
|
||||
- 存放数据:在 `data` 文件夹下存放训练数据。模型采用的数据集是[DocRED](https://github.com/thunlp/DocRED/tree/master/),DocRED数据集来自于2010年的国际语义评测大会中Task 8:"Multi-Way Classification of Semantic Relations Between Pairs of Nominals"。
|
||||
|
||||
- DocRED包含以下数据:
|
||||
|
||||
- `dev.json`:验证集
|
||||
|
||||
- `rel_info.json`:关系集
|
||||
|
||||
- `rel2id.json`:关系标签到ID的映射
|
||||
|
||||
- `test.json`:测试集
|
||||
|
||||
- `train_annotated.json`:训练集
|
||||
|
||||
- `train_distant.json`: 由于文件太大,请自行从[Google Drive](https://drive.google.com/drive/folders/1c5-0YwnoJx8NS6CV2f-NoTHR__BdkNqw)上下载到`data/`目录下
|
||||
|
||||
- 开始训练:模型加载和保存位置以及配置可以在conf的`.yaml`文件中修改
|
||||
|
||||
- 在数据集DocRED中训练:`python run.py`
|
||||
|
||||
- 训练好的模型保存在根目录下
|
||||
|
||||
- 从上次训练的模型开始训练:设置`.yaml`中的train_from_saved_model为上次保存模型的路径
|
||||
|
||||
- 每次训练的日志保存路径默认保存在根目录,可以通过`.yaml`中的log_dir来配置
|
||||
|
||||
- 进行预测: `python predict.py`
|
||||
|
||||
- 预测生成的`result.json`保存在根目录
|
||||
|
||||
|
||||
## 模型内容
|
||||
DocuNet
|
|
@ -0,0 +1,3 @@
|
|||
defaults:
|
||||
- hydra/output: custom
|
||||
- train
|
|
@ -0,0 +1,11 @@
|
|||
hydra:
|
||||
|
||||
run:
|
||||
# Output directory for normal runs
|
||||
dir: logs/${now:%Y-%m-%d_%H-%M-%S}
|
||||
|
||||
sweep:
|
||||
# Output directory for sweep runs
|
||||
dir: logs/${now:%Y-%m-%d_%H-%M-%S}
|
||||
# Output sub directory for sweep runs.
|
||||
subdir: ${hydra.job.num}_${hydra.job.id}
|
|
@ -0,0 +1,32 @@
|
|||
adam_epsilon: 1e-06
|
||||
bert_lr: 3e-05
|
||||
channel_type: 'context-based'
|
||||
config_name: ''
|
||||
data_dir: 'data'
|
||||
dataset: 'docred'
|
||||
dev_file: 'dev.json'
|
||||
down_dim: 256
|
||||
evaluation_steps: -1
|
||||
gradient_accumulation_steps: 2
|
||||
learning_rate: 0.0004
|
||||
log_dir: './train_roberta.log'
|
||||
max_grad_norm: 1.0
|
||||
max_height: 42
|
||||
max_seq_length: 1024
|
||||
model_name_or_path: 'roberta-base'
|
||||
num_class: 97
|
||||
num_labels: 4
|
||||
num_train_epochs: 30
|
||||
save_path: './model_roberta.pt'
|
||||
seed: 111
|
||||
test_batch_size: 2
|
||||
test_file: 'test.json'
|
||||
tokenizer_name: ''
|
||||
train_batch_size: 2
|
||||
train_file: 'train_annotated.json'
|
||||
train_from_saved_model: ''
|
||||
transformer_type: 'roberta'
|
||||
unet_in_dim: 3
|
||||
unet_out_dim: 256
|
||||
warmup_ratio: 0.06
|
||||
load_path: './model_roberta.pt'
|
File diff suppressed because one or more lines are too long
|
@ -0,0 +1,99 @@
|
|||
{
|
||||
"P1376": 79,
|
||||
"P607": 27,
|
||||
"P136": 73,
|
||||
"P137": 63,
|
||||
"P131": 2,
|
||||
"P527": 11,
|
||||
"P1412": 38,
|
||||
"P206": 33,
|
||||
"P205": 77,
|
||||
"P449": 52,
|
||||
"P127": 34,
|
||||
"P123": 49,
|
||||
"P86": 66,
|
||||
"P840": 85,
|
||||
"P355": 72,
|
||||
"P737": 93,
|
||||
"P740": 84,
|
||||
"P190": 94,
|
||||
"P576": 71,
|
||||
"P749": 68,
|
||||
"P112": 65,
|
||||
"P118": 40,
|
||||
"P17": 1,
|
||||
"P19": 14,
|
||||
"P3373": 19,
|
||||
"P6": 42,
|
||||
"P276": 44,
|
||||
"P1001": 24,
|
||||
"P580": 62,
|
||||
"P582": 83,
|
||||
"P585": 64,
|
||||
"P463": 18,
|
||||
"P676": 87,
|
||||
"P674": 46,
|
||||
"P264": 10,
|
||||
"P108": 43,
|
||||
"P102": 17,
|
||||
"P25": 81,
|
||||
"P27": 3,
|
||||
"P26": 26,
|
||||
"P20": 37,
|
||||
"P22": 30,
|
||||
"Na": 0,
|
||||
"P807": 95,
|
||||
"P800": 51,
|
||||
"P279": 78,
|
||||
"P1336": 88,
|
||||
"P577": 5,
|
||||
"P570": 8,
|
||||
"P571": 15,
|
||||
"P178": 36,
|
||||
"P179": 55,
|
||||
"P272": 75,
|
||||
"P170": 35,
|
||||
"P171": 80,
|
||||
"P172": 76,
|
||||
"P175": 6,
|
||||
"P176": 67,
|
||||
"P39": 91,
|
||||
"P30": 21,
|
||||
"P31": 60,
|
||||
"P36": 70,
|
||||
"P37": 58,
|
||||
"P35": 54,
|
||||
"P400": 31,
|
||||
"P403": 61,
|
||||
"P361": 12,
|
||||
"P364": 74,
|
||||
"P569": 7,
|
||||
"P710": 41,
|
||||
"P1344": 32,
|
||||
"P488": 82,
|
||||
"P241": 59,
|
||||
"P162": 57,
|
||||
"P161": 9,
|
||||
"P166": 47,
|
||||
"P40": 20,
|
||||
"P1441": 23,
|
||||
"P156": 45,
|
||||
"P155": 39,
|
||||
"P150": 4,
|
||||
"P551": 90,
|
||||
"P706": 56,
|
||||
"P159": 29,
|
||||
"P495": 13,
|
||||
"P58": 53,
|
||||
"P194": 48,
|
||||
"P54": 16,
|
||||
"P57": 28,
|
||||
"P50": 22,
|
||||
"P1366": 86,
|
||||
"P1365": 92,
|
||||
"P937": 69,
|
||||
"P140": 50,
|
||||
"P69": 25,
|
||||
"P1198": 96,
|
||||
"P1056": 89
|
||||
}
|
|
@ -0,0 +1 @@
|
|||
{"P6": "head of government", "P17": "country", "P19": "place of birth", "P20": "place of death", "P22": "father", "P25": "mother", "P26": "spouse", "P27": "country of citizenship", "P30": "continent", "P31": "instance of", "P35": "head of state", "P36": "capital", "P37": "official language", "P39": "position held", "P40": "child", "P50": "author", "P54": "member of sports team", "P57": "director", "P58": "screenwriter", "P69": "educated at", "P86": "composer", "P102": "member of political party", "P108": "employer", "P112": "founded by", "P118": "league", "P123": "publisher", "P127": "owned by", "P131": "located in the administrative territorial entity", "P136": "genre", "P137": "operator", "P140": "religion", "P150": "contains administrative territorial entity", "P155": "follows", "P156": "followed by", "P159": "headquarters location", "P161": "cast member", "P162": "producer", "P166": "award received", "P170": "creator", "P171": "parent taxon", "P172": "ethnic group", "P175": "performer", "P176": "manufacturer", "P178": "developer", "P179": "series", "P190": "sister city", "P194": "legislative body", "P205": "basin country", "P206": "located in or next to body of water", "P241": "military branch", "P264": "record label", "P272": "production company", "P276": "location", "P279": "subclass of", "P355": "subsidiary", "P361": "part of", "P364": "original language of work", "P400": "platform", "P403": "mouth of the watercourse", "P449": "original network", "P463": "member of", "P488": "chairperson", "P495": "country of origin", "P527": "has part", "P551": "residence", "P569": "date of birth", "P570": "date of death", "P571": "inception", "P576": "dissolved, abolished or demolished", "P577": "publication date", "P580": "start time", "P582": "end time", "P585": "point in time", "P607": "conflict", "P674": "characters", "P676": "lyrics by", "P706": "located on terrain feature", "P710": "participant", "P737": "influenced by", "P740": "location of formation", "P749": "parent organization", "P800": "notable work", "P807": "separated from", "P840": "narrative location", "P937": "work location", "P1001": "applies to jurisdiction", "P1056": "product or material produced", "P1198": "unemployment rate", "P1336": "territory claimed by", "P1344": "participant of", "P1365": "replaces", "P1366": "replaced by", "P1376": "capital of", "P1412": "languages spoken, written or signed", "P1441": "present in work", "P3373": "sibling"}
|
File diff suppressed because it is too large
Load Diff
File diff suppressed because one or more lines are too long
|
@ -0,0 +1,88 @@
|
|||
import os
|
||||
import time
|
||||
import hydra
|
||||
from hydra.utils import get_original_cwd
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
import ujson as json
|
||||
from torch.utils.data import DataLoader
|
||||
from transformers import AutoConfig, AutoModel, AutoTokenizer
|
||||
from transformers.optimization import AdamW, get_linear_schedule_with_warmup
|
||||
|
||||
from deepkeredoc import *
|
||||
|
||||
|
||||
def report(args, model, features):
|
||||
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
||||
dataloader = DataLoader(features, batch_size=args.test_batch_size, shuffle=False, collate_fn=collate_fn, drop_last=False)
|
||||
preds = []
|
||||
for batch in dataloader:
|
||||
model.eval()
|
||||
|
||||
inputs = {'input_ids': batch[0].to(device),
|
||||
'attention_mask': batch[1].to(device),
|
||||
'entity_pos': batch[3],
|
||||
'hts': batch[4],
|
||||
}
|
||||
|
||||
with torch.no_grad():
|
||||
pred = model(**inputs)
|
||||
pred = pred.cpu().numpy()
|
||||
pred[np.isnan(pred)] = 0
|
||||
preds.append(pred)
|
||||
|
||||
preds = np.concatenate(preds, axis=0).astype(np.float32)
|
||||
preds = to_official(args, preds, features)
|
||||
return preds
|
||||
|
||||
|
||||
|
||||
|
||||
@hydra.main(config_path="conf/config.yaml")
|
||||
def main(cfg):
|
||||
cwd = get_original_cwd()
|
||||
os.chdir(cwd)
|
||||
|
||||
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
||||
|
||||
config = AutoConfig.from_pretrained(
|
||||
cfg.config_name if cfg.config_name else cfg.model_name_or_path,
|
||||
num_labels=cfg.num_class,
|
||||
)
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
cfg.tokenizer_name if cfg.tokenizer_name else cfg.model_name_or_path,
|
||||
)
|
||||
|
||||
Dataset = ReadDataset(cfg, cfg.dataset, tokenizer, cfg.max_seq_length)
|
||||
|
||||
|
||||
test_file = os.path.join(cfg.data_dir, cfg.test_file)
|
||||
|
||||
test_features = Dataset.read(test_file)
|
||||
|
||||
model = AutoModel.from_pretrained(
|
||||
cfg.model_name_or_path,
|
||||
from_tf=bool(".ckpt" in cfg.model_name_or_path),
|
||||
config=config,
|
||||
)
|
||||
|
||||
config.cls_token_id = tokenizer.cls_token_id
|
||||
config.sep_token_id = tokenizer.sep_token_id
|
||||
config.transformer_type = cfg.transformer_type
|
||||
|
||||
set_seed(cfg)
|
||||
model = DocREModel(config, cfg, model, num_labels=cfg.num_labels)
|
||||
|
||||
|
||||
model.load_state_dict(torch.load(cfg.load_path)['checkpoint'])
|
||||
model.to(device)
|
||||
T_features = test_features # Testing on the test set
|
||||
#T_score, T_output = evaluate(cfg, model, T_features, tag="test")
|
||||
pred = report(cfg, model, T_features)
|
||||
with open("./result.json", "w") as fh:
|
||||
json.dump(pred, fh)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
|
@ -0,0 +1,5 @@
|
|||
torch==1.8.1
|
||||
transformers==4.7.0
|
||||
opt-einsum==3.3.0
|
||||
hydra-core==1.0.6
|
||||
ujson
|
|
@ -0,0 +1,231 @@
|
|||
import os
|
||||
import time
|
||||
import hydra
|
||||
from hydra.utils import get_original_cwd
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
import ujson as json
|
||||
from torch.utils.data import DataLoader
|
||||
from transformers import AutoConfig, AutoModel, AutoTokenizer
|
||||
from transformers.optimization import AdamW, get_linear_schedule_with_warmup
|
||||
|
||||
from deepkeredoc import *
|
||||
|
||||
|
||||
def train(args, model, train_features, dev_features, test_features):
|
||||
def logging(s, print_=True, log_=True):
|
||||
if print_:
|
||||
print(s)
|
||||
if log_ and args.log_dir != '':
|
||||
with open(args.log_dir, 'a+') as f_log:
|
||||
f_log.write(s + '\n')
|
||||
def finetune(features, optimizer, num_epoch, num_steps, model):
|
||||
cur_model = model.module if hasattr(model, 'module') else model
|
||||
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
||||
if args.train_from_saved_model != '':
|
||||
best_score = torch.load(args.train_from_saved_model)["best_f1"]
|
||||
epoch_delta = torch.load(args.train_from_saved_model)["epoch"] + 1
|
||||
else:
|
||||
epoch_delta = 0
|
||||
best_score = -1
|
||||
train_dataloader = DataLoader(features, batch_size=args.train_batch_size, shuffle=True, collate_fn=collate_fn, drop_last=True)
|
||||
train_iterator = [epoch + epoch_delta for epoch in range(num_epoch)]
|
||||
total_steps = int(len(train_dataloader) * num_epoch // args.gradient_accumulation_steps)
|
||||
warmup_steps = int(total_steps * args.warmup_ratio)
|
||||
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=warmup_steps, num_training_steps=total_steps)
|
||||
print("Total steps: {}".format(total_steps))
|
||||
print("Warmup steps: {}".format(warmup_steps))
|
||||
global_step = 0
|
||||
log_step = 100
|
||||
total_loss = 0
|
||||
|
||||
|
||||
|
||||
#scaler = GradScaler()
|
||||
for epoch in train_iterator:
|
||||
start_time = time.time()
|
||||
optimizer.zero_grad()
|
||||
|
||||
for step, batch in enumerate(train_dataloader):
|
||||
model.train()
|
||||
|
||||
inputs = {'input_ids': batch[0].to(device),
|
||||
'attention_mask': batch[1].to(device),
|
||||
'labels': batch[2],
|
||||
'entity_pos': batch[3],
|
||||
'hts': batch[4],
|
||||
}
|
||||
#with autocast():
|
||||
outputs = model(**inputs)
|
||||
loss = outputs[0] / args.gradient_accumulation_steps
|
||||
total_loss += loss.item()
|
||||
# scaler.scale(loss).backward()
|
||||
|
||||
|
||||
loss.backward()
|
||||
|
||||
if step % args.gradient_accumulation_steps == 0:
|
||||
#scaler.unscale_(optimizer)
|
||||
if args.max_grad_norm > 0:
|
||||
# torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)
|
||||
torch.nn.utils.clip_grad_norm_(cur_model.parameters(), args.max_grad_norm)
|
||||
#scaler.step(optimizer)
|
||||
#scaler.update()
|
||||
#scheduler.step()
|
||||
optimizer.step()
|
||||
scheduler.step()
|
||||
optimizer.zero_grad()
|
||||
global_step += 1
|
||||
num_steps += 1
|
||||
if global_step % log_step == 0:
|
||||
cur_loss = total_loss / log_step
|
||||
elapsed = time.time() - start_time
|
||||
logging(
|
||||
'| epoch {:2d} | step {:4d} | min/b {:5.2f} | lr {} | train loss {:5.3f}'.format(
|
||||
epoch, global_step, elapsed / 60, scheduler.get_last_lr(), cur_loss * 1000))
|
||||
total_loss = 0
|
||||
start_time = time.time()
|
||||
|
||||
if (step + 1) == len(train_dataloader) - 1 or (args.evaluation_steps > 0 and num_steps % args.evaluation_steps == 0 and step % args.gradient_accumulation_steps == 0):
|
||||
# if step ==0:
|
||||
logging('-' * 89)
|
||||
eval_start_time = time.time()
|
||||
dev_score, dev_output = evaluate(args, model, dev_features, tag="dev")
|
||||
|
||||
logging(
|
||||
'| epoch {:3d} | time: {:5.2f}s | dev_result:{}'.format(epoch, time.time() - eval_start_time,
|
||||
dev_output))
|
||||
logging('-' * 89)
|
||||
if dev_score > best_score:
|
||||
best_score = dev_score
|
||||
logging(
|
||||
'| epoch {:3d} | best_f1:{}'.format(epoch, best_score))
|
||||
if args.save_path != "":
|
||||
torch.save({
|
||||
'epoch': epoch,
|
||||
'checkpoint': cur_model.state_dict(),
|
||||
'best_f1': best_score,
|
||||
'optimizer': optimizer.state_dict()
|
||||
}, args.save_path
|
||||
, _use_new_zipfile_serialization=False)
|
||||
logging(
|
||||
'| successfully save model at: {}'.format(args.save_path))
|
||||
logging('-' * 89)
|
||||
return num_steps
|
||||
|
||||
cur_model = model.module if hasattr(model, 'module') else model
|
||||
extract_layer = ["extractor", "bilinear"]
|
||||
bert_layer = ['bert_model']
|
||||
optimizer_grouped_parameters = [
|
||||
{"params": [p for n, p in cur_model.named_parameters() if any(nd in n for nd in bert_layer)], "lr": args.bert_lr},
|
||||
{"params": [p for n, p in cur_model.named_parameters() if any(nd in n for nd in extract_layer)], "lr": 1e-4},
|
||||
{"params": [p for n, p in cur_model.named_parameters() if not any(nd in n for nd in extract_layer + bert_layer)]},
|
||||
]
|
||||
|
||||
optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
|
||||
if args.train_from_saved_model != '':
|
||||
optimizer.load_state_dict(torch.load(args.train_from_saved_model)["optimizer"])
|
||||
print("load saved optimizer from {}.".format(args.train_from_saved_model))
|
||||
|
||||
|
||||
num_steps = 0
|
||||
set_seed(args)
|
||||
model.zero_grad()
|
||||
finetune(train_features, optimizer, args.num_train_epochs, num_steps, model)
|
||||
|
||||
|
||||
def evaluate(args, model, features, tag="dev"):
|
||||
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
||||
dataloader = DataLoader(features, batch_size=args.test_batch_size, shuffle=False, collate_fn=collate_fn, drop_last=False)
|
||||
preds = []
|
||||
total_loss = 0
|
||||
for i, batch in enumerate(dataloader):
|
||||
model.eval()
|
||||
|
||||
inputs = {'input_ids': batch[0].to(device),
|
||||
'attention_mask': batch[1].to(device),
|
||||
'labels': batch[2],
|
||||
'entity_pos': batch[3],
|
||||
'hts': batch[4],
|
||||
}
|
||||
|
||||
with torch.no_grad():
|
||||
output = model(**inputs)
|
||||
loss = output[0]
|
||||
pred = output[1].cpu().numpy()
|
||||
pred[np.isnan(pred)] = 0
|
||||
preds.append(pred)
|
||||
total_loss += loss.item()
|
||||
|
||||
average_loss = total_loss / (i + 1)
|
||||
preds = np.concatenate(preds, axis=0).astype(np.float32)
|
||||
ans = to_official(args, preds, features)
|
||||
if len(ans) > 0:
|
||||
best_f1, _, best_f1_ign, _, re_p, re_r = official_evaluate(ans, args.data_dir)
|
||||
output = {
|
||||
tag + "_F1": best_f1 * 100,
|
||||
tag + "_F1_ign": best_f1_ign * 100,
|
||||
tag + "_re_p": re_p * 100,
|
||||
tag + "_re_r": re_r * 100,
|
||||
tag + "_average_loss": average_loss
|
||||
}
|
||||
return best_f1, output
|
||||
|
||||
|
||||
|
||||
@hydra.main(config_path="conf/config.yaml")
|
||||
def main(cfg):
|
||||
cwd = get_original_cwd()
|
||||
os.chdir(cwd)
|
||||
|
||||
if not os.path.exists(os.path.join(cfg.data_dir, "train_distant.json")):
|
||||
raise FileNotFoundError("Sorry, the file: 'train_annotated.json' is too big to upload to github, \
|
||||
please manually download to 'data/' from DocRED GoogleDrive https://drive.google.com/drive/folders/1c5-0YwnoJx8NS6CV2f-NoTHR__BdkNqw")
|
||||
|
||||
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
||||
|
||||
config = AutoConfig.from_pretrained(
|
||||
cfg.config_name if cfg.config_name else cfg.model_name_or_path,
|
||||
num_labels=cfg.num_class,
|
||||
)
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
cfg.tokenizer_name if cfg.tokenizer_name else cfg.model_name_or_path,
|
||||
)
|
||||
|
||||
Dataset = ReadDataset(cfg, cfg.dataset, tokenizer, cfg.max_seq_length)
|
||||
|
||||
train_file = os.path.join(cfg.data_dir, cfg.train_file)
|
||||
dev_file = os.path.join(cfg.data_dir, cfg.dev_file)
|
||||
test_file = os.path.join(cfg.data_dir, cfg.test_file)
|
||||
train_features = Dataset.read(train_file)
|
||||
dev_features = Dataset.read(dev_file)
|
||||
test_features = Dataset.read(test_file)
|
||||
|
||||
model = AutoModel.from_pretrained(
|
||||
cfg.model_name_or_path,
|
||||
from_tf=bool(".ckpt" in cfg.model_name_or_path),
|
||||
config=config,
|
||||
)
|
||||
|
||||
config.cls_token_id = tokenizer.cls_token_id
|
||||
config.sep_token_id = tokenizer.sep_token_id
|
||||
config.transformer_type = cfg.transformer_type
|
||||
|
||||
set_seed(cfg)
|
||||
model = DocREModel(config, cfg, model, num_labels=cfg.num_labels)
|
||||
if cfg.train_from_saved_model != '':
|
||||
model.load_state_dict(torch.load(cfg.train_from_saved_model)["checkpoint"])
|
||||
print("load saved model from {}.".format(cfg.train_from_saved_model))
|
||||
|
||||
#if torch.cuda.device_count() > 1:
|
||||
# print("Let's use", torch.cuda.device_count(), "GPUs!")
|
||||
# model = torch.nn.DataParallel(model, device_ids = list(range(torch.cuda.device_count())))
|
||||
model.to(device)
|
||||
|
||||
train(cfg, model, train_features, dev_features, test_features)
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
Loading…
Reference in New Issue