fix
This commit is contained in:
parent
0e61eed58b
commit
94d23febef
200
tools/trainer.py
200
tools/trainer.py
|
@ -1,113 +1,117 @@
|
|||
import torch
|
||||
import logging
|
||||
from collections import OrderedDict
|
||||
from typing import Sequence, Optional
|
||||
import matplotlib.pyplot as plt
|
||||
from metrics import PRMetric
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
SPECIAL_TOKENS_KEYS = [
|
||||
"pad_token",
|
||||
"unk_token",
|
||||
"mask_token",
|
||||
"cls_token",
|
||||
"sep_token",
|
||||
"bos_token",
|
||||
"eos_token",
|
||||
"head_token",
|
||||
"tail_token",
|
||||
|
||||
]
|
||||
|
||||
SPECIAL_TOKENS_VALUES = [
|
||||
"[PAD]",
|
||||
"[UNK]",
|
||||
"[MASK]",
|
||||
"[CLS]",
|
||||
"[SEP]",
|
||||
"[BOS]",
|
||||
"[EOS]",
|
||||
"HEAD",
|
||||
"TAIL",
|
||||
]
|
||||
|
||||
SPECIAL_TOKENS = OrderedDict(zip(SPECIAL_TOKENS_KEYS, SPECIAL_TOKENS_VALUES))
|
||||
|
||||
|
||||
class Vocab(object):
|
||||
def train(epoch, model, dataloader, optimizer, criterion, device, writer, cfg):
|
||||
"""
|
||||
构建词汇表,增加词汇,删除低频词汇
|
||||
training the model.
|
||||
Args:
|
||||
epoch (int): number of training steps.
|
||||
model (class): model of training.
|
||||
dataloader (dict): dict of dataset iterator. Keys are tasknames, values are corresponding dataloaders.
|
||||
optimizer (Callable): optimizer of training.
|
||||
criterion (Callable): loss criterion of training.
|
||||
device (torch.device): device of training.
|
||||
writer (class): output to tensorboard.
|
||||
cfg: configutation of training.
|
||||
Return:
|
||||
losses[-1] : the loss of training
|
||||
"""
|
||||
def __init__(self, name: str = 'basic', init_tokens: Sequence = SPECIAL_TOKENS):
|
||||
self.name = name
|
||||
self.init_tokens = init_tokens
|
||||
self.trimed = False
|
||||
self.word2idx = {}
|
||||
self.word2count = {}
|
||||
self.idx2word = {}
|
||||
self.count = 0
|
||||
self._add_init_tokens()
|
||||
model.train()
|
||||
|
||||
def _add_init_tokens(self):
|
||||
"""
|
||||
添加初始tokens
|
||||
"""
|
||||
for token in self.init_tokens.values():
|
||||
self._add_word(token)
|
||||
metric = PRMetric()
|
||||
losses = []
|
||||
|
||||
def _add_word(self, word: str):
|
||||
"""
|
||||
增加单个词汇
|
||||
Arg :
|
||||
word (String) : 增加的词汇
|
||||
"""
|
||||
if word not in self.word2idx:
|
||||
self.word2idx[word] = self.count
|
||||
self.word2count[word] = 1
|
||||
self.idx2word[self.count] = word
|
||||
self.count += 1
|
||||
for batch_idx, (x, y) in enumerate(dataloader, 1):
|
||||
for key, value in x.items():
|
||||
x[key] = value.to(device)
|
||||
|
||||
y = y.to(device)
|
||||
|
||||
optimizer.zero_grad()
|
||||
y_pred = model(x)
|
||||
|
||||
if cfg.model_name == 'capsule':
|
||||
loss = model.loss(y_pred, y)
|
||||
else:
|
||||
self.word2count[word] += 1
|
||||
loss = criterion(y_pred, y)
|
||||
|
||||
def add_words(self, words: Sequence):
|
||||
"""
|
||||
通过数组增加词汇
|
||||
Arg :
|
||||
words (List) : 增加的词汇组
|
||||
"""
|
||||
for word in words:
|
||||
self._add_word(word)
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
def trim(self, min_freq=2, verbose: Optional[bool] = True):
|
||||
"""
|
||||
当 word 词频低于 min_freq 时,从词库中删除
|
||||
Args:
|
||||
min_freq (int): 最低词频
|
||||
verbose (bool) : 是否打印日志
|
||||
"""
|
||||
assert min_freq == int(min_freq), f'min_freq must be integer, can\'t be {min_freq}'
|
||||
min_freq = int(min_freq)
|
||||
if min_freq < 2:
|
||||
return
|
||||
if self.trimed:
|
||||
return
|
||||
self.trimed = True
|
||||
metric.update(y_true=y, y_pred=y_pred)
|
||||
losses.append(loss.item())
|
||||
|
||||
keep_words = []
|
||||
new_words = []
|
||||
data_total = len(dataloader.dataset)
|
||||
data_cal = data_total if batch_idx == len(dataloader) else batch_idx * len(y)
|
||||
if (cfg.train_log and batch_idx % cfg.log_interval == 0) or batch_idx == len(dataloader):
|
||||
# p r f1 皆为 macro,因为micro时三者相同,定义为acc
|
||||
acc, p, r, f1 = metric.compute()
|
||||
logger.info(f'Train Epoch {epoch}: [{data_cal}/{data_total} ({100. * data_cal / data_total:.0f}%)]\t'
|
||||
f'Loss: {loss.item():.6f}')
|
||||
logger.info(f'Train Epoch {epoch}: Acc: {100. * acc:.2f}%\t'
|
||||
f'macro metrics: [p: {p:.4f}, r:{r:.4f}, f1:{f1:.4f}]')
|
||||
|
||||
for k, v in self.word2count.items():
|
||||
if v >= min_freq:
|
||||
keep_words.append(k)
|
||||
new_words.extend([k] * v)
|
||||
if verbose:
|
||||
before_len = len(keep_words)
|
||||
after_len = len(self.word2idx) - len(self.init_tokens)
|
||||
logger.info('vocab after be trimmed, keep words [{} / {}] = {:.2f}%'.format(
|
||||
before_len, after_len, before_len / after_len * 100))
|
||||
if cfg.show_plot and not cfg.only_comparison_plot:
|
||||
if cfg.plot_utils == 'matplot':
|
||||
plt.plot(losses)
|
||||
plt.title(f'epoch {epoch} train loss')
|
||||
plt.show()
|
||||
|
||||
# Reinitialize dictionaries
|
||||
self.word2idx = {}
|
||||
self.word2count = {}
|
||||
self.idx2word = {}
|
||||
self.count = 0
|
||||
self._add_init_tokens()
|
||||
self.add_words(new_words)
|
||||
if cfg.plot_utils == 'tensorboard':
|
||||
for i in range(len(losses)):
|
||||
writer.add_scalar(f'epoch_{epoch}_training_loss', losses[i], i)
|
||||
|
||||
return losses[-1]
|
||||
|
||||
|
||||
def validate(epoch, model, dataloader, criterion, device, cfg):
|
||||
"""
|
||||
validating the model.
|
||||
Args:
|
||||
epoch (int): number of validating steps.
|
||||
model (class): model of validating.
|
||||
dataloader (dict): dict of dataset iterator. Keys are tasknames, values are corresponding dataloaders.
|
||||
criterion (Callable): loss criterion of validating.
|
||||
device (torch.device): device of validating.
|
||||
cfg: configutation of validating.
|
||||
Return:
|
||||
f1 : f1 score
|
||||
loss : the loss of validating
|
||||
"""
|
||||
model.eval()
|
||||
|
||||
metric = PRMetric()
|
||||
losses = []
|
||||
|
||||
for batch_idx, (x, y) in enumerate(dataloader, 1):
|
||||
for key, value in x.items():
|
||||
x[key] = value.to(device)
|
||||
y = y.to(device)
|
||||
with torch.no_grad():
|
||||
y_pred = model(x)
|
||||
|
||||
if cfg.model_name == 'capsule':
|
||||
loss = model.loss(y_pred, y)
|
||||
else:
|
||||
loss = criterion(y_pred, y)
|
||||
|
||||
metric.update(y_true=y, y_pred=y_pred)
|
||||
losses.append(loss.item())
|
||||
|
||||
loss = sum(losses) / len(losses)
|
||||
acc, p, r, f1 = metric.compute()
|
||||
data_total = len(dataloader.dataset)
|
||||
|
||||
if epoch >= 0:
|
||||
logger.info(f'Valid Epoch {epoch}: [{data_total}/{data_total}](100%)\t Loss: {loss:.6f}')
|
||||
logger.info(f'Valid Epoch {epoch}: Acc: {100. * acc:.2f}%\tmacro metrics: [p: {p:.4f}, r:{r:.4f}, f1:{f1:.4f}]')
|
||||
else:
|
||||
logger.info(f'Test Data: [{data_total}/{data_total}](100%)\t Loss: {loss:.6f}')
|
||||
logger.info(f'Test Data: Acc: {100. * acc:.2f}%\tmacro metrics: [p: {p:.4f}, r:{r:.4f}, f1:{f1:.4f}]')
|
||||
|
||||
return f1, loss
|
||||
|
|
Loading…
Reference in New Issue