Merge pull request #1971 from tink2123/fix_eval_for_srn
polish code for srn eval
This commit is contained in:
commit
895d44bc39
|
@ -182,6 +182,8 @@ def train(config,
|
|||
model_average = False
|
||||
model.train()
|
||||
|
||||
use_srn = config['Architecture']['algorithm'] == "SRN"
|
||||
|
||||
if 'start_epoch' in best_model_dict:
|
||||
start_epoch = best_model_dict['start_epoch']
|
||||
else:
|
||||
|
@ -200,7 +202,7 @@ def train(config,
|
|||
break
|
||||
lr = optimizer.get_lr()
|
||||
images = batch[0]
|
||||
if config['Architecture']['algorithm'] == "SRN":
|
||||
if use_srn:
|
||||
others = batch[-4:]
|
||||
preds = model(images, others)
|
||||
model_average = True
|
||||
|
@ -256,8 +258,12 @@ def train(config,
|
|||
min_average_window=10000,
|
||||
max_average_window=15625)
|
||||
Model_Average.apply()
|
||||
cur_metric = eval(model, valid_dataloader, post_process_class,
|
||||
eval_class)
|
||||
cur_metric = eval(
|
||||
model,
|
||||
valid_dataloader,
|
||||
post_process_class,
|
||||
eval_class,
|
||||
use_srn=use_srn)
|
||||
cur_metric_str = 'cur metric, {}'.format(', '.join(
|
||||
['{}: {}'.format(k, v) for k, v in cur_metric.items()]))
|
||||
logger.info(cur_metric_str)
|
||||
|
@ -321,7 +327,8 @@ def train(config,
|
|||
return
|
||||
|
||||
|
||||
def eval(model, valid_dataloader, post_process_class, eval_class):
|
||||
def eval(model, valid_dataloader, post_process_class, eval_class,
|
||||
use_srn=False):
|
||||
model.eval()
|
||||
with paddle.no_grad():
|
||||
total_frame = 0.0
|
||||
|
@ -332,7 +339,8 @@ def eval(model, valid_dataloader, post_process_class, eval_class):
|
|||
break
|
||||
images = batch[0]
|
||||
start = time.time()
|
||||
if "SRN" in str(model.head):
|
||||
|
||||
if use_srn:
|
||||
others = batch[-4:]
|
||||
preds = model(images, others)
|
||||
else:
|
||||
|
|
Loading…
Reference in New Issue