fix typo
This commit is contained in:
parent
5a5d627deb
commit
09fd94e781
|
@ -212,7 +212,7 @@ def train(config,
|
||||||
stats['lr'] = lr
|
stats['lr'] = lr
|
||||||
train_stats.update(stats)
|
train_stats.update(stats)
|
||||||
|
|
||||||
if cal_metric_during_train: # onlt rec and cls need
|
if cal_metric_during_train: # only rec and cls need
|
||||||
batch = [item.numpy() for item in batch]
|
batch = [item.numpy() for item in batch]
|
||||||
post_result = post_process_class(preds, batch[1])
|
post_result = post_process_class(preds, batch[1])
|
||||||
eval_class(post_result, batch)
|
eval_class(post_result, batch)
|
||||||
|
@ -238,21 +238,21 @@ def train(config,
|
||||||
# eval
|
# eval
|
||||||
if global_step > start_eval_step and \
|
if global_step > start_eval_step and \
|
||||||
(global_step - start_eval_step) % eval_batch_step == 0 and dist.get_rank() == 0:
|
(global_step - start_eval_step) % eval_batch_step == 0 and dist.get_rank() == 0:
|
||||||
cur_metirc = eval(model, valid_dataloader, post_process_class,
|
cur_metric = eval(model, valid_dataloader, post_process_class,
|
||||||
eval_class)
|
eval_class)
|
||||||
cur_metirc_str = 'cur metirc, {}'.format(', '.join(
|
cur_metric_str = 'cur metric, {}'.format(', '.join(
|
||||||
['{}: {}'.format(k, v) for k, v in cur_metirc.items()]))
|
['{}: {}'.format(k, v) for k, v in cur_metric.items()]))
|
||||||
logger.info(cur_metirc_str)
|
logger.info(cur_metric_str)
|
||||||
|
|
||||||
# logger metric
|
# logger metric
|
||||||
if vdl_writer is not None:
|
if vdl_writer is not None:
|
||||||
for k, v in cur_metirc.items():
|
for k, v in cur_metric.items():
|
||||||
if isinstance(v, (float, int)):
|
if isinstance(v, (float, int)):
|
||||||
vdl_writer.add_scalar('EVAL/{}'.format(k),
|
vdl_writer.add_scalar('EVAL/{}'.format(k),
|
||||||
cur_metirc[k], global_step)
|
cur_metric[k], global_step)
|
||||||
if cur_metirc[main_indicator] >= best_model_dict[
|
if cur_metric[main_indicator] >= best_model_dict[
|
||||||
main_indicator]:
|
main_indicator]:
|
||||||
best_model_dict.update(cur_metirc)
|
best_model_dict.update(cur_metric)
|
||||||
best_model_dict['best_epoch'] = epoch
|
best_model_dict['best_epoch'] = epoch
|
||||||
save_model(
|
save_model(
|
||||||
model,
|
model,
|
||||||
|
@ -263,7 +263,7 @@ def train(config,
|
||||||
prefix='best_accuracy',
|
prefix='best_accuracy',
|
||||||
best_model_dict=best_model_dict,
|
best_model_dict=best_model_dict,
|
||||||
epoch=epoch)
|
epoch=epoch)
|
||||||
best_str = 'best metirc, {}'.format(', '.join([
|
best_str = 'best metric, {}'.format(', '.join([
|
||||||
'{}: {}'.format(k, v) for k, v in best_model_dict.items()
|
'{}: {}'.format(k, v) for k, v in best_model_dict.items()
|
||||||
]))
|
]))
|
||||||
logger.info(best_str)
|
logger.info(best_str)
|
||||||
|
@ -294,7 +294,7 @@ def train(config,
|
||||||
prefix='iter_epoch_{}'.format(epoch),
|
prefix='iter_epoch_{}'.format(epoch),
|
||||||
best_model_dict=best_model_dict,
|
best_model_dict=best_model_dict,
|
||||||
epoch=epoch)
|
epoch=epoch)
|
||||||
best_str = 'best metirc, {}'.format(', '.join(
|
best_str = 'best metric, {}'.format(', '.join(
|
||||||
['{}: {}'.format(k, v) for k, v in best_model_dict.items()]))
|
['{}: {}'.format(k, v) for k, v in best_model_dict.items()]))
|
||||||
logger.info(best_str)
|
logger.info(best_str)
|
||||||
if dist.get_rank() == 0 and vdl_writer is not None:
|
if dist.get_rank() == 0 and vdl_writer is not None:
|
||||||
|
@ -323,13 +323,13 @@ def eval(model, valid_dataloader, post_process_class, eval_class):
|
||||||
eval_class(post_result, batch)
|
eval_class(post_result, batch)
|
||||||
pbar.update(1)
|
pbar.update(1)
|
||||||
total_frame += len(images)
|
total_frame += len(images)
|
||||||
# Get final metirc,eg. acc or hmean
|
# Get final metric,eg. acc or hmean
|
||||||
metirc = eval_class.get_metric()
|
metric = eval_class.get_metric()
|
||||||
|
|
||||||
pbar.close()
|
pbar.close()
|
||||||
model.train()
|
model.train()
|
||||||
metirc['fps'] = total_frame / total_time
|
metric['fps'] = total_frame / total_time
|
||||||
return metirc
|
return metric
|
||||||
|
|
||||||
|
|
||||||
def preprocess(is_train=False):
|
def preprocess(is_train=False):
|
||||||
|
|
Loading…
Reference in New Issue