deepke/main.py

import os
import argparse
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from deepke.config import config
from deepke import model
from deepke.utils import make_seed, load_pkl
from deepke.trainer import train, validate
from deepke.preprocess import process
from deepke.dataset import CustomDataset, collate_fn
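
# Registry mapping model names (from the CLI flag or config) to their classes in deepke.model.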
__Models__ = {
    "CNN": model.CNN,
    "RNN": model.BiLSTM,
    "GCN": model.GCN,
    "Transformer": model.Transformer,
    "Capsule": model.Capsule,
    "LM": model.LM,
}
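
# The --model_name flag, when given, overrides the default model set in config.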
parser = argparse.ArgumentParser(description='choose your model')
parser.add_argument('--model_name', type=str, help='model name')
args = parser.parse_args()
model_name = args.model_name if args.model_name else config.model_name
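
# Fix the random seeds so runs are reproducible.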
make_seed(config.training.seed)
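
# Pick the compute device: the configured GPU if requested and available, otherwise the CPU.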
if config.training.use_gpu and torch.cuda.is_available():
    device = torch.device('cuda', config.gpu_id)
else:
    device = torch.device('cpu')
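
# Preprocess the raw data; the pickled train/test sets (and vocab) are read from config.out_path below.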
# if not os.path.exists(config.out_path):
process(config.data_path, config.out_path)
train_data_path = os.path.join(config.out_path, 'train.pkl')
test_data_path = os.path.join(config.out_path, 'test.pkl')
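
# The LM model presumably brings its own (pretrained) tokenizer, so no vocab_size is passed;
# every other model receives the size of the preprocessed vocabulary.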
if model_name == 'LM':
    vocab_size = None
else:
    vocab_path = os.path.join(config.out_path, 'vocab.pkl')
    vocab = load_pkl(vocab_path)
    vocab_size = len(vocab.word2idx)
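
# Wrap the pickled datasets in DataLoaders; the custom collate_fn assembles the batches.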
train_dataset = CustomDataset(train_data_path)
train_dataloader = DataLoader(train_dataset,
                              batch_size=config.training.batch_size,
                              shuffle=True,
                              collate_fn=collate_fn)
test_dataset = CustomDataset(test_data_path)
test_dataloader = DataLoader(
    test_dataset,
    batch_size=config.training.batch_size,
    shuffle=False,
    collate_fn=collate_fn,
)
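
# Instantiate the chosen model and move it to the selected device.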
model = __Models__[model_name](vocab_size, config)
model.to(device)
# print(model)
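
# Adam optimizer plus a plateau scheduler: the learning rate is multiplied by
# config.training.decay_rate when the validation macro F1 stops improving for decay_patience epochs.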
optimizer = optim.Adam(model.parameters(), lr=config.training.learning_rate)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, 'max', factor=config.training.decay_rate, patience=config.training.decay_patience)
criterion = nn.CrossEntropyLoss()
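
# Track the best macro/micro F1, the epoch each was reached, and the corresponding saved checkpoint.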
best_macro_f1, best_macro_epoch = 0, 1
best_micro_f1, best_micro_epoch = 0, 1
best_macro_model, best_micro_model = '', ''
print('=' * 10, ' Start training ', '=' * 10)
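
# One training pass and one validation pass per epoch; model.save is assumed to return
# the path of the checkpoint written for that epoch, which is recorded whenever a new best F1 appears.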
for epoch in range(1, config.training.epoch + 1):
    train(epoch, device, train_dataloader, model, optimizer, criterion, config)
    macro_f1, micro_f1 = validate(test_dataloader, model, device, config)
    model_name = model.save(epoch=epoch)
    scheduler.step(macro_f1)

    if macro_f1 > best_macro_f1:
        best_macro_f1 = macro_f1
        best_macro_epoch = epoch
        best_macro_model = model_name

    if micro_f1 > best_micro_f1:
        best_micro_f1 = micro_f1
        best_micro_epoch = epoch
        best_micro_model = model_name
print('=' * 10, ' End training ', '=' * 10)
print(f'best macro f1: {best_macro_f1:.4f},',
      f'in epoch: {best_macro_epoch}, saved in: {best_macro_model}')
print(f'best micro f1: {best_micro_f1:.4f},',
      f'in epoch: {best_micro_epoch}, saved in: {best_micro_model}')