Merge pull request #1972 from tink2123/fix_eval_for_srn_2.0

polish code for srn eval
This commit is contained in:
xiaoting 2021-02-07 15:34:51 +08:00 committed by GitHub
commit a3afc162fa
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 13 additions and 5 deletions

View File

@ -182,6 +182,8 @@ def train(config,
model_average = False model_average = False
model.train() model.train()
use_srn = config['Architecture']['algorithm'] == "SRN"
if 'start_epoch' in best_model_dict: if 'start_epoch' in best_model_dict:
start_epoch = best_model_dict['start_epoch'] start_epoch = best_model_dict['start_epoch']
else: else:
@ -200,7 +202,7 @@ def train(config,
break break
lr = optimizer.get_lr() lr = optimizer.get_lr()
images = batch[0] images = batch[0]
if config['Architecture']['algorithm'] == "SRN": if use_srn:
others = batch[-4:] others = batch[-4:]
preds = model(images, others) preds = model(images, others)
model_average = True model_average = True
@ -256,8 +258,12 @@ def train(config,
min_average_window=10000, min_average_window=10000,
max_average_window=15625) max_average_window=15625)
Model_Average.apply() Model_Average.apply()
cur_metric = eval(model, valid_dataloader, post_process_class, cur_metric = eval(
eval_class) model,
valid_dataloader,
post_process_class,
eval_class,
use_srn=use_srn)
cur_metric_str = 'cur metric, {}'.format(', '.join( cur_metric_str = 'cur metric, {}'.format(', '.join(
['{}: {}'.format(k, v) for k, v in cur_metric.items()])) ['{}: {}'.format(k, v) for k, v in cur_metric.items()]))
logger.info(cur_metric_str) logger.info(cur_metric_str)
@ -321,7 +327,8 @@ def train(config,
return return
def eval(model, valid_dataloader, post_process_class, eval_class): def eval(model, valid_dataloader, post_process_class, eval_class,
use_srn=False):
model.eval() model.eval()
with paddle.no_grad(): with paddle.no_grad():
total_frame = 0.0 total_frame = 0.0
@ -332,7 +339,8 @@ def eval(model, valid_dataloader, post_process_class, eval_class):
break break
images = batch[0] images = batch[0]
start = time.time() start = time.time()
if "SRN" in str(model.head):
if use_srn:
others = batch[-4:] others = batch[-4:]
preds = model(images, others) preds = model(images, others)
else: else: