删除eval多余的参数
This commit is contained in:
parent
4eba6c0dce
commit
672318256c
|
@ -231,7 +231,7 @@ def train(config,
|
||||||
if global_step > start_eval_step and \
|
if global_step > start_eval_step and \
|
||||||
(global_step - start_eval_step) % eval_batch_step == 0 and dist.get_rank() == 0:
|
(global_step - start_eval_step) % eval_batch_step == 0 and dist.get_rank() == 0:
|
||||||
cur_metirc = eval(model, valid_dataloader, post_process_class,
|
cur_metirc = eval(model, valid_dataloader, post_process_class,
|
||||||
eval_class, logger, print_batch_step)
|
eval_class)
|
||||||
cur_metirc_str = 'cur metirc, {}'.format(', '.join(
|
cur_metirc_str = 'cur metirc, {}'.format(', '.join(
|
||||||
['{}: {}'.format(k, v) for k, v in cur_metirc.items()]))
|
['{}: {}'.format(k, v) for k, v in cur_metirc.items()]))
|
||||||
logger.info(cur_metirc_str)
|
logger.info(cur_metirc_str)
|
||||||
|
@ -293,8 +293,7 @@ def train(config,
|
||||||
return
|
return
|
||||||
|
|
||||||
|
|
||||||
def eval(model, valid_dataloader, post_process_class, eval_class, logger,
|
def eval(model, valid_dataloader, post_process_class, eval_class):
|
||||||
print_batch_step):
|
|
||||||
model.eval()
|
model.eval()
|
||||||
with paddle.no_grad():
|
with paddle.no_grad():
|
||||||
total_frame = 0.0
|
total_frame = 0.0
|
||||||
|
@ -315,9 +314,6 @@ def eval(model, valid_dataloader, post_process_class, eval_class, logger,
|
||||||
eval_class(post_result, batch)
|
eval_class(post_result, batch)
|
||||||
pbar.update(1)
|
pbar.update(1)
|
||||||
total_frame += len(images)
|
total_frame += len(images)
|
||||||
# if idx % print_batch_step == 0 and dist.get_rank() == 0:
|
|
||||||
# logger.info('tackling images for eval: {}/{}'.format(
|
|
||||||
# idx, len(valid_dataloader)))
|
|
||||||
# Get final metirc,eg. acc or hmean
|
# Get final metirc,eg. acc or hmean
|
||||||
metirc = eval_class.get_metric()
|
metirc = eval_class.get_metric()
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue