# deepke/main.py

import os
import hydra
import torch
import logging
import torch.nn as nn
from torch import optim
from hydra import utils
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
# project-local modules
import models
from preprocess import preprocess
from dataset import CustomDataset, collate_fn
from trainer import train, validate
from utils import manual_seed, load_pkl
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
@hydra.main(config_path='conf/config.yaml')
def main(cfg):
    """Train the model selected by ``cfg.model_name`` and evaluate it.

    The function optionally runs preprocessing, loads the pickled
    train/valid/test splits from ``cfg.out_path``, trains for ``cfg.epoch``
    epochs (validating and saving a checkpoint after each one), optionally
    plots the loss curves, logs a summary and finally runs ``validate`` on
    the test split.

    Args:
        cfg: Hydra configuration object. ``cfg.cwd``, ``cfg.pos_size`` and
            ``cfg.vocab_size`` are written by this function before the
            model is built.

    Raises:
        KeyError: If ``cfg.model_name`` is not one of the registered models.
    """
    # Hydra changes the working directory, so remember the original one for
    # building data paths below.
    cwd = utils.get_original_cwd()
    cfg.cwd = cwd
    cfg.pos_size = 2 * cfg.pos_limit + 2
    logger.info(f'\n{cfg.pretty()}')

    # Registry mapping the configured model name to its class.
    __Model__ = {
        'cnn': models.PCNN,
        'rnn': models.BiLSTM,
        'transformer': models.Transformer,
        'gcn': models.GCN,
        'capsule': models.Capsule,
        'lm': models.LM,
    }

    # device
    if cfg.use_gpu and torch.cuda.is_available():
        device = torch.device('cuda', cfg.gpu_id)
    else:
        device = torch.device('cpu')
    logger.info(f'device: {device}')

    # If the preprocessing procedure has not changed, it is best to disable
    # this step so the data is not preprocessed again on every run.
    if cfg.preprocess:
        preprocess(cfg)

    train_data_path = os.path.join(cfg.cwd, cfg.out_path, 'train.pkl')
    valid_data_path = os.path.join(cfg.cwd, cfg.out_path, 'valid.pkl')
    test_data_path = os.path.join(cfg.cwd, cfg.out_path, 'test.pkl')
    vocab_path = os.path.join(cfg.cwd, cfg.out_path, 'vocab.pkl')

    # The 'lm' model does not load the pickled vocabulary.
    if cfg.model_name == 'lm':
        vocab_size = None
    else:
        vocab = load_pkl(vocab_path)
        vocab_size = vocab.count
    cfg.vocab_size = vocab_size

    train_dataset = CustomDataset(train_data_path)
    valid_dataset = CustomDataset(valid_data_path)
    test_dataset = CustomDataset(test_data_path)

    # Only the training split is shuffled. Evaluation splits keep a fixed
    # order so that batch composition (and therefore the validation loss
    # that drives the scheduler and early-stopping bookkeeping) does not
    # vary with the random state.
    train_dataloader = DataLoader(train_dataset, batch_size=cfg.batch_size, shuffle=True,
                                  collate_fn=collate_fn(cfg))
    valid_dataloader = DataLoader(valid_dataset, batch_size=cfg.batch_size, shuffle=False,
                                  collate_fn=collate_fn(cfg))
    test_dataloader = DataLoader(test_dataset, batch_size=cfg.batch_size, shuffle=False,
                                 collate_fn=collate_fn(cfg))

    model = __Model__[cfg.model_name](cfg)
    model.to(device)
    logger.info(f'\n {model}')

    optimizer = optim.Adam(model.parameters(), lr=cfg.learning_rate, weight_decay=cfg.weight_decay)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=cfg.lr_factor, patience=cfg.lr_patience)
    criterion = nn.CrossEntropyLoss()

    # Best epoch according to validation macro F1.
    best_f1, best_epoch = -1, 0
    # Running best according to validation loss (es_*) and the snapshot
    # taken whenever the patience budget is exhausted (best_es_*).
    es_loss, es_f1, es_epoch, es_patience = 1e8, -1, 0, 0
    best_es_epoch, best_es_f1 = 0, -1
    es_path, best_es_path = '', ''
    train_losses, valid_losses = [], []

    if cfg.show_plot and cfg.plot_utils == 'tensorboard':
        writer = SummaryWriter('tensorboard')
    else:
        writer = None

    logger.info('=' * 10 + ' Start training ' + '=' * 10)

    for epoch in range(1, cfg.epoch + 1):
        # Re-seed every epoch with an epoch-dependent seed.
        manual_seed(cfg.seed + epoch)
        train_loss = train(epoch, model, train_dataloader, optimizer, criterion, device, writer, cfg)
        valid_f1, valid_loss = validate(epoch, model, valid_dataloader, criterion, device, cfg)
        scheduler.step(valid_loss)

        model_path = model.save(epoch, cfg)
        # logger.info(model_path)

        train_losses.append(train_loss)
        valid_losses.append(valid_loss)

        if best_f1 < valid_f1:
            best_f1 = valid_f1
            best_epoch = epoch

        # Use the validation loss as the early-stopping criterion.
        if es_loss > valid_loss:
            es_loss = valid_loss
            es_f1 = valid_f1
            es_epoch = epoch
            es_patience = 0
            es_path = model_path
        else:
            es_patience += 1
            # NOTE: training is not interrupted here; the loop only records
            # which checkpoint early stopping would have selected.
            if es_patience >= cfg.early_stopping_patience:
                best_es_epoch = es_epoch
                best_es_f1 = es_f1
                best_es_path = es_path

    # Epochs start at 1, so 0 means the patience budget was never exhausted.
    # Report the lowest-validation-loss epoch instead of the placeholder
    # values (epoch 0, F1 -1, empty path).
    if best_es_epoch == 0:
        best_es_epoch, best_es_f1, best_es_path = es_epoch, es_f1, es_path

    if cfg.show_plot:
        if cfg.plot_utils == 'matplot':
            plt.plot(train_losses, 'x-')
            plt.plot(valid_losses, '+-')
            plt.legend(['train', 'valid'])
            plt.title('train/valid comparison loss')
            plt.show()

        if cfg.plot_utils == 'tensorboard':
            for step, (epoch_train_loss, epoch_valid_loss) in enumerate(zip(train_losses, valid_losses)):
                writer.add_scalars('train/valid_comparison_loss', {
                    'train': epoch_train_loss,
                    'valid': epoch_valid_loss
                }, step)
            writer.close()

    logger.info(f'best(valid loss quota) early stopping epoch: {best_es_epoch}, '
                f'this epoch macro f1: {best_es_f1:0.4f}')
    logger.info(f'this model save path: {best_es_path}')
    logger.info(f'total {cfg.epoch} epochs, best(valid macro f1) epoch: {best_epoch}, '
                f'this epoch macro f1: {best_f1:.4f}')

    # Final evaluation on the test split; epoch -1 marks it as a test run.
    validate(-1, model, test_dataloader, criterion, device, cfg)
# Script entry point: run main() only when this file is executed directly,
# not when it is imported.
if "__main__" == __name__:
    main()
# python predict.py --help  # show help for the available parameters
# python predict.py -c
# python predict.py chinese_split=0,1 replace_entity_with_type=0,1 -m