update config
This commit is contained in:
parent
1f92922c2e
commit
ecb2ac6dc9
7
main.py
7
main.py
|
@ -23,7 +23,6 @@ def main(cfg):
|
|||
cwd = utils.get_original_cwd()
|
||||
cfg.cwd = cwd
|
||||
cfg.pos_size = 2 * cfg.pos_limit + 2
|
||||
logger.info(f'\n{cfg.pretty()}')
|
||||
|
||||
__Model__ = {
|
||||
'cnn': models.PCNN,
|
||||
|
@ -67,6 +66,8 @@ def main(cfg):
|
|||
|
||||
model = __Model__[cfg.model_name](cfg)
|
||||
model.to(device)
|
||||
|
||||
logger.info(f'\n{cfg.pretty()}')
|
||||
logger.info(f'\n {model}')
|
||||
|
||||
optimizer = optim.Adam(model.parameters(), lr=cfg.learning_rate, weight_decay=cfg.weight_decay)
|
||||
|
@ -87,7 +88,7 @@ def main(cfg):
|
|||
for epoch in range(1, cfg.epoch + 1):
|
||||
manual_seed(cfg.seed + epoch)
|
||||
train_loss = train(epoch, model, train_dataloader, optimizer, criterion, device, writer, cfg)
|
||||
valid_f1, valid_loss = validate(epoch, model, valid_dataloader, criterion, device)
|
||||
valid_f1, valid_loss = validate(epoch, model, valid_dataloader, criterion, device, cfg)
|
||||
scheduler.step(valid_loss)
|
||||
model_path = model.save(epoch, cfg)
|
||||
# logger.info(model_path)
|
||||
|
@ -133,7 +134,7 @@ def main(cfg):
|
|||
logger.info(f'total {cfg.epoch} epochs, best(valid macro f1) epoch: {best_epoch}, '
|
||||
f'this epoch macro f1: {best_f1:.4f}')
|
||||
|
||||
validate(-1, model, test_dataloader, criterion, device)
|
||||
validate(-1, model, test_dataloader, criterion, device, cfg)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
14
trainer.py
14
trainer.py
|
@ -19,7 +19,11 @@ def train(epoch, model, dataloader, optimizer, criterion, device, writer, cfg):
|
|||
|
||||
optimizer.zero_grad()
|
||||
y_pred = model(x)
|
||||
loss = criterion(y_pred, y)
|
||||
|
||||
if cfg.model_name == 'capsule':
|
||||
loss = model.loss(y_pred, y)
|
||||
else:
|
||||
loss = criterion(y_pred, y)
|
||||
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
@ -50,7 +54,7 @@ def train(epoch, model, dataloader, optimizer, criterion, device, writer, cfg):
|
|||
return losses[-1]
|
||||
|
||||
|
||||
def validate(epoch, model, dataloader, criterion, device):
|
||||
def validate(epoch, model, dataloader, criterion, device, cfg):
|
||||
model.eval()
|
||||
|
||||
metric = PRMetric()
|
||||
|
@ -63,7 +67,11 @@ def validate(epoch, model, dataloader, criterion, device):
|
|||
|
||||
with torch.no_grad():
|
||||
y_pred = model(x)
|
||||
loss = criterion(y_pred, y)
|
||||
|
||||
if cfg.model_name == 'capsule':
|
||||
loss = model.loss(y_pred, y)
|
||||
else:
|
||||
loss = criterion(y_pred, y)
|
||||
|
||||
metric.update(y_true=y, y_pred=y_pred)
|
||||
losses.append(loss.item())
|
||||
|
|
|
@ -63,4 +63,4 @@ def to_one_hot(x: torch.Tensor, length: int) -> torch.Tensor:
|
|||
for i in range(B):
|
||||
x_one_hot[i, x[i]] = 1.0
|
||||
|
||||
return x_one_hot
|
||||
return x_one_hot.to(device=x.device)
|
Loading…
Reference in New Issue