LDOUBLEV 2021-01-26 15:16:02 +08:00
parent 5a5d627deb
commit 09fd94e781
1 changed file with 15 additions and 15 deletions


@@ -212,7 +212,7 @@ def train(config,
             stats['lr'] = lr
             train_stats.update(stats)
 
-            if cal_metric_during_train:  # onlt rec and cls need
+            if cal_metric_during_train:  # only rec and cls need
                 batch = [item.numpy() for item in batch]
                 post_result = post_process_class(preds, batch[1])
                 eval_class(post_result, batch)
@@ -238,21 +238,21 @@ def train(config,
             # eval
             if global_step > start_eval_step and \
                     (global_step - start_eval_step) % eval_batch_step == 0 and dist.get_rank() == 0:
-                cur_metirc = eval(model, valid_dataloader, post_process_class,
+                cur_metric = eval(model, valid_dataloader, post_process_class,
                                   eval_class)
-                cur_metirc_str = 'cur metirc, {}'.format(', '.join(
-                    ['{}: {}'.format(k, v) for k, v in cur_metirc.items()]))
-                logger.info(cur_metirc_str)
+                cur_metric_str = 'cur metric, {}'.format(', '.join(
+                    ['{}: {}'.format(k, v) for k, v in cur_metric.items()]))
+                logger.info(cur_metric_str)
 
                 # logger metric
                 if vdl_writer is not None:
-                    for k, v in cur_metirc.items():
+                    for k, v in cur_metric.items():
                         if isinstance(v, (float, int)):
                             vdl_writer.add_scalar('EVAL/{}'.format(k),
-                                                  cur_metirc[k], global_step)
-                if cur_metirc[main_indicator] >= best_model_dict[
+                                                  cur_metric[k], global_step)
+                if cur_metric[main_indicator] >= best_model_dict[
                         main_indicator]:
-                    best_model_dict.update(cur_metirc)
+                    best_model_dict.update(cur_metric)
                     best_model_dict['best_epoch'] = epoch
                     save_model(
                         model,
@@ -263,7 +263,7 @@ def train(config,
                         prefix='best_accuracy',
                         best_model_dict=best_model_dict,
                         epoch=epoch)
-                best_str = 'best metirc, {}'.format(', '.join([
+                best_str = 'best metric, {}'.format(', '.join([
                     '{}: {}'.format(k, v) for k, v in best_model_dict.items()
                 ]))
                 logger.info(best_str)
@@ -294,7 +294,7 @@ def train(config,
                 prefix='iter_epoch_{}'.format(epoch),
                 best_model_dict=best_model_dict,
                 epoch=epoch)
-    best_str = 'best metirc, {}'.format(', '.join(
+    best_str = 'best metric, {}'.format(', '.join(
         ['{}: {}'.format(k, v) for k, v in best_model_dict.items()]))
     logger.info(best_str)
     if dist.get_rank() == 0 and vdl_writer is not None:
@@ -323,13 +323,13 @@ def eval(model, valid_dataloader, post_process_class, eval_class):
             eval_class(post_result, batch)
             pbar.update(1)
             total_frame += len(images)
-        # Get final metirc, eg. acc or hmean
-        metirc = eval_class.get_metric()
+        # Get final metric, eg. acc or hmean
+        metric = eval_class.get_metric()
     pbar.close()
     model.train()
-    metirc['fps'] = total_frame / total_time
-    return metirc
+    metric['fps'] = total_frame / total_time
+    return metric
 
 
 def preprocess(is_train=False):
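
Taken together, the renamed variables implement a standard best-checkpoint tracker: each evaluation returns a metric dict, the run's best values accumulate in best_model_dict, and the checkpoint is saved whenever the main indicator improves. Below is a minimal runnable sketch of that pattern, assuming a metric dict shaped like the one eval() returns; run_eval() is a hypothetical stand-in for the real eval(model, valid_dataloader, post_process_class, eval_class) call.

def run_eval(epoch):
    # Hypothetical stand-in for the real evaluation call;
    # returns a metric dict like the one eval() produces.
    return {'acc': 0.80 + 0.01 * epoch, 'fps': 120.0}


def train_loop(epochs=3, main_indicator='acc'):
    best_model_dict = {main_indicator: 0.0}
    for epoch in range(1, epochs + 1):
        cur_metric = run_eval(epoch)
        print('cur metric, ' + ', '.join(
            '{}: {}'.format(k, v) for k, v in cur_metric.items()))
        # Keep the snapshot whenever the main indicator ties or improves,
        # mirroring the >= comparison in the diff.
        if cur_metric[main_indicator] >= best_model_dict[main_indicator]:
            best_model_dict.update(cur_metric)
            best_model_dict['best_epoch'] = epoch
    print('best metric, ' + ', '.join(
        '{}: {}'.format(k, v) for k, v in best_model_dict.items()))


train_loop()

Note that using >= rather than >, as the diff does, means a tie on the main indicator refreshes best_epoch to the most recent epoch with the top score.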