From 94d23febefc3f83ecbfe15127150ae6999babcff Mon Sep 17 00:00:00 2001
From: tlk-dsg <467460833@qq.com>
Date: Thu, 17 Jun 2021 14:35:53 +0800
Subject: [PATCH] fix

---
 tools/trainer.py | 200 ++++++++++++++++++++++++-----------------
 1 file changed, 102 insertions(+), 98 deletions(-)

diff --git a/tools/trainer.py b/tools/trainer.py
index 43d0154..6484c24 100644
--- a/tools/trainer.py
+++ b/tools/trainer.py
@@ -1,113 +1,117 @@
+import torch
 import logging
-from collections import OrderedDict
-from typing import Sequence, Optional
+import matplotlib.pyplot as plt
+from metrics import PRMetric
 
 logger = logging.getLogger(__name__)
 
-SPECIAL_TOKENS_KEYS = [
-    "pad_token",
-    "unk_token",
-    "mask_token",
-    "cls_token",
-    "sep_token",
-    "bos_token",
-    "eos_token",
-    "head_token",
-    "tail_token",
-]
-
-SPECIAL_TOKENS_VALUES = [
-    "[PAD]",
-    "[UNK]",
-    "[MASK]",
-    "[CLS]",
-    "[SEP]",
-    "[BOS]",
-    "[EOS]",
-    "HEAD",
-    "TAIL",
-]
-
-SPECIAL_TOKENS = OrderedDict(zip(SPECIAL_TOKENS_KEYS, SPECIAL_TOKENS_VALUES))
-
-
-class Vocab(object):
+
+def train(epoch, model, dataloader, optimizer, criterion, device, writer, cfg):
     """
-    构建词汇表,增加词汇,删除低频词汇
+    Train the model for one epoch.
+    Args:
+        epoch (int): current epoch number.
+        model (nn.Module): model to train.
+        dataloader (DataLoader): iterator over the training set.
+        optimizer (Optimizer): optimizer used for training.
+        criterion (Callable): loss criterion.
+        device (torch.device): device to train on.
+        writer (SummaryWriter): TensorBoard writer.
+        cfg: training configuration.
+    Return:
+        losses[-1]: loss of the last training batch.
     """
-    def __init__(self, name: str = 'basic', init_tokens: Sequence = SPECIAL_TOKENS):
-        self.name = name
-        self.init_tokens = init_tokens
-        self.trimed = False
-        self.word2idx = {}
-        self.word2count = {}
-        self.idx2word = {}
-        self.count = 0
-        self._add_init_tokens()
+    model.train()
 
-    def _add_init_tokens(self):
-        """
-        添加初始tokens
-        """
-        for token in self.init_tokens.values():
-            self._add_word(token)
+    metric = PRMetric()
+    losses = []
 
-    def _add_word(self, word: str):
-        """
-        增加单个词汇
-        Arg :
-            word (String) : 增加的词汇
-        """
-        if word not in self.word2idx:
-            self.word2idx[word] = self.count
-            self.word2count[word] = 1
-            self.idx2word[self.count] = word
-            self.count += 1
+    for batch_idx, (x, y) in enumerate(dataloader, 1):
+        for key, value in x.items():
+            x[key] = value.to(device)
+
+        y = y.to(device)
+
+        optimizer.zero_grad()
+        y_pred = model(x)
+
+        if cfg.model_name == 'capsule':
+            loss = model.loss(y_pred, y)
         else:
-            self.word2count[word] += 1
+            loss = criterion(y_pred, y)
 
-    def add_words(self, words: Sequence):
-        """
-        通过数组增加词汇
-        Arg :
-            words (List) : 增加的词汇组
-        """
-        for word in words:
-            self._add_word(word)
+        loss.backward()
+        optimizer.step()
 
-    def trim(self, min_freq=2, verbose: Optional[bool] = True):
-        """
-        当 word 词频低于 min_freq 时,从词库中删除
-        Args:
-            min_freq (int): 最低词频
-            verbose (bool) : 是否打印日志
-        """
-        assert min_freq == int(min_freq), f'min_freq must be integer, can\'t be {min_freq}'
-        min_freq = int(min_freq)
-        if min_freq < 2:
-            return
-        if self.trimed:
-            return
-        self.trimed = True
+        metric.update(y_true=y, y_pred=y_pred)
+        losses.append(loss.item())
 
-        keep_words = []
-        new_words = []
+        data_total = len(dataloader.dataset)
+        # number of examples seen so far; the last batch may be smaller than the rest
+        data_cal = data_total if batch_idx == len(dataloader) else batch_idx * len(y)
+        if (cfg.train_log and batch_idx % cfg.log_interval == 0) or batch_idx == len(dataloader):
+            # p, r and f1 are macro-averaged: with micro-averaging all three equal the accuracy
+            acc, p, r, f1 = metric.compute()
+            logger.info(f'Train Epoch {epoch}: [{data_cal}/{data_total} ({100. * data_cal / data_total:.0f}%)]\t'
+                        f'Loss: {loss.item():.6f}')
+            logger.info(f'Train Epoch {epoch}: Acc: {100. * acc:.2f}%\t'
+                        f'macro metrics: [p: {p:.4f}, r:{r:.4f}, f1:{f1:.4f}]')
 
-        for k, v in self.word2count.items():
-            if v >= min_freq:
-                keep_words.append(k)
-                new_words.extend([k] * v)
-        if verbose:
-            before_len = len(keep_words)
-            after_len = len(self.word2idx) - len(self.init_tokens)
-            logger.info('vocab after be trimmed, keep words [{} / {}] = {:.2f}%'.format(
-                before_len, after_len, before_len / after_len * 100))
+    if cfg.show_plot and not cfg.only_comparison_plot:
+        if cfg.plot_utils == 'matplot':
+            plt.plot(losses)
+            plt.title(f'epoch {epoch} train loss')
+            plt.show()
 
-        # Reinitialize dictionaries
-        self.word2idx = {}
-        self.word2count = {}
-        self.idx2word = {}
-        self.count = 0
-        self._add_init_tokens()
-        self.add_words(new_words)
+        if cfg.plot_utils == 'tensorboard':
+            for i in range(len(losses)):
+                writer.add_scalar(f'epoch_{epoch}_training_loss', losses[i], i)
+
+    return losses[-1]
+
+
+def validate(epoch, model, dataloader, criterion, device, cfg):
+    """
+    Validate (or test) the model.
+    Args:
+        epoch (int): current epoch number; a negative value switches the log output to test mode.
+        model (nn.Module): model to evaluate.
+        dataloader (DataLoader): iterator over the validation (or test) set.
+        criterion (Callable): loss criterion.
+        device (torch.device): device to evaluate on.
+        cfg: validation configuration.
+    Return:
+        f1: macro F1 score.
+        loss: average validation loss.
+    """
+    model.eval()
+
+    metric = PRMetric()
+    losses = []
+
+    for batch_idx, (x, y) in enumerate(dataloader, 1):
+        for key, value in x.items():
+            x[key] = value.to(device)
+        y = y.to(device)
+        with torch.no_grad():
+            y_pred = model(x)
+
+            if cfg.model_name == 'capsule':
+                loss = model.loss(y_pred, y)
+            else:
+                loss = criterion(y_pred, y)
+
+            metric.update(y_true=y, y_pred=y_pred)
+            losses.append(loss.item())
+
+    loss = sum(losses) / len(losses)
+    acc, p, r, f1 = metric.compute()
+    data_total = len(dataloader.dataset)
+
+    if epoch >= 0:
+        logger.info(f'Valid Epoch {epoch}: [{data_total}/{data_total}](100%)\t Loss: {loss:.6f}')
+        logger.info(f'Valid Epoch {epoch}: Acc: {100. * acc:.2f}%\tmacro metrics: [p: {p:.4f}, r:{r:.4f}, f1:{f1:.4f}]')
+    else:
+        logger.info(f'Test Data: [{data_total}/{data_total}](100%)\t Loss: {loss:.6f}')
+        logger.info(f'Test Data: Acc: {100. * acc:.2f}%\tmacro metrics: [p: {p:.4f}, r:{r:.4f}, f1:{f1:.4f}]')
+
+    return f1, loss
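The patch imports PRMetric from metrics but does not include it. From the call
sites above, PRMetric needs an update(y_true, y_pred) method that accumulates one
batch and a compute() method that returns (acc, p, r, f1) with macro-averaged
p, r and f1. Below is a minimal sketch of such a class, assuming y_pred holds raw
logits of shape (batch, num_classes), y holds integer class labels, and
scikit-learn is available; it is an illustration, not the repository's actual
implementation.

import torch
from sklearn.metrics import accuracy_score, precision_recall_fscore_support


class PRMetric:
    """Accumulates batch predictions and computes accuracy plus macro P/R/F1."""

    def __init__(self):
        self.y_true = []
        self.y_pred = []

    def update(self, y_true: torch.Tensor, y_pred: torch.Tensor):
        # y_pred is assumed to be raw logits; take the argmax as the predicted class
        self.y_true.extend(y_true.cpu().tolist())
        self.y_pred.extend(y_pred.argmax(dim=-1).cpu().tolist())

    def compute(self):
        # accuracy plus macro-averaged precision/recall/F1 over everything seen so far
        acc = accuracy_score(self.y_true, self.y_pred)
        p, r, f1, _ = precision_recall_fscore_support(
            self.y_true, self.y_pred, average='macro', zero_division=0)
        return acc, p, r, f1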
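For completeness, here is a hedged sketch of how the two functions might be
driven from a main loop. Everything outside train() and validate() is an
assumption for illustration: model, train_loader, valid_loader and test_loader
stand in for objects built elsewhere, and cfg is assumed to expose the fields the
patch reads (model_name, train_log, log_interval, show_plot,
only_comparison_plot, plot_utils) plus hypothetical epochs and lr values.

import torch
from torch import nn, optim
from torch.utils.tensorboard import SummaryWriter

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)                               # model built elsewhere
optimizer = optim.Adam(model.parameters(), lr=cfg.lr)  # cfg.lr is hypothetical
criterion = nn.CrossEntropyLoss()
writer = SummaryWriter('runs/trainer')

best_f1 = 0.0
for epoch in range(1, cfg.epochs + 1):                 # cfg.epochs is hypothetical
    train_loss = train(epoch, model, train_loader, optimizer, criterion, device, writer, cfg)
    f1, valid_loss = validate(epoch, model, valid_loader, criterion, device, cfg)
    if f1 > best_f1:                                   # keep the best checkpoint by macro F1
        best_f1 = f1
        torch.save(model.state_dict(), 'best_model.pth')

# a negative epoch switches validate() to its "Test Data" log output
test_f1, test_loss = validate(-1, model, test_loader, criterion, device, cfg)
writer.close()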