update config

This commit is contained in:
leo 2019-12-05 21:42:15 +08:00
parent 1f92922c2e
commit ecb2ac6dc9
3 changed files with 16 additions and 7 deletions

View File

@ -23,7 +23,6 @@ def main(cfg):
cwd = utils.get_original_cwd()
cfg.cwd = cwd
cfg.pos_size = 2 * cfg.pos_limit + 2
logger.info(f'\n{cfg.pretty()}')
__Model__ = {
'cnn': models.PCNN,
@ -67,6 +66,8 @@ def main(cfg):
model = __Model__[cfg.model_name](cfg)
model.to(device)
logger.info(f'\n{cfg.pretty()}')
logger.info(f'\n {model}')
optimizer = optim.Adam(model.parameters(), lr=cfg.learning_rate, weight_decay=cfg.weight_decay)
@ -87,7 +88,7 @@ def main(cfg):
for epoch in range(1, cfg.epoch + 1):
manual_seed(cfg.seed + epoch)
train_loss = train(epoch, model, train_dataloader, optimizer, criterion, device, writer, cfg)
valid_f1, valid_loss = validate(epoch, model, valid_dataloader, criterion, device)
valid_f1, valid_loss = validate(epoch, model, valid_dataloader, criterion, device, cfg)
scheduler.step(valid_loss)
model_path = model.save(epoch, cfg)
# logger.info(model_path)
@ -133,7 +134,7 @@ def main(cfg):
logger.info(f'total {cfg.epoch} epochs, best(valid macro f1) epoch: {best_epoch}, '
f'this epoch macro f1: {best_f1:.4f}')
validate(-1, model, test_dataloader, criterion, device)
validate(-1, model, test_dataloader, criterion, device, cfg)
if __name__ == '__main__':

View File

@ -19,7 +19,11 @@ def train(epoch, model, dataloader, optimizer, criterion, device, writer, cfg):
optimizer.zero_grad()
y_pred = model(x)
loss = criterion(y_pred, y)
if cfg.model_name == 'capsule':
loss = model.loss(y_pred, y)
else:
loss = criterion(y_pred, y)
loss.backward()
optimizer.step()
@ -50,7 +54,7 @@ def train(epoch, model, dataloader, optimizer, criterion, device, writer, cfg):
return losses[-1]
def validate(epoch, model, dataloader, criterion, device):
def validate(epoch, model, dataloader, criterion, device, cfg):
model.eval()
metric = PRMetric()
@ -63,7 +67,11 @@ def validate(epoch, model, dataloader, criterion, device):
with torch.no_grad():
y_pred = model(x)
loss = criterion(y_pred, y)
if cfg.model_name == 'capsule':
loss = model.loss(y_pred, y)
else:
loss = criterion(y_pred, y)
metric.update(y_true=y, y_pred=y_pred)
losses.append(loss.item())

View File

@ -63,4 +63,4 @@ def to_one_hot(x: torch.Tensor, length: int) -> torch.Tensor:
for i in range(B):
x_one_hot[i, x[i]] = 1.0
return x_one_hot
return x_one_hot.to(device=x.device)