Merge pull request #1972 from tink2123/fix_eval_for_srn_2.0
polish code for srn eval
This commit is contained in:
commit
a3afc162fa
|
@ -182,6 +182,8 @@ def train(config,
|
||||||
model_average = False
|
model_average = False
|
||||||
model.train()
|
model.train()
|
||||||
|
|
||||||
|
use_srn = config['Architecture']['algorithm'] == "SRN"
|
||||||
|
|
||||||
if 'start_epoch' in best_model_dict:
|
if 'start_epoch' in best_model_dict:
|
||||||
start_epoch = best_model_dict['start_epoch']
|
start_epoch = best_model_dict['start_epoch']
|
||||||
else:
|
else:
|
||||||
|
@ -200,7 +202,7 @@ def train(config,
|
||||||
break
|
break
|
||||||
lr = optimizer.get_lr()
|
lr = optimizer.get_lr()
|
||||||
images = batch[0]
|
images = batch[0]
|
||||||
if config['Architecture']['algorithm'] == "SRN":
|
if use_srn:
|
||||||
others = batch[-4:]
|
others = batch[-4:]
|
||||||
preds = model(images, others)
|
preds = model(images, others)
|
||||||
model_average = True
|
model_average = True
|
||||||
|
@ -256,8 +258,12 @@ def train(config,
|
||||||
min_average_window=10000,
|
min_average_window=10000,
|
||||||
max_average_window=15625)
|
max_average_window=15625)
|
||||||
Model_Average.apply()
|
Model_Average.apply()
|
||||||
cur_metric = eval(model, valid_dataloader, post_process_class,
|
cur_metric = eval(
|
||||||
eval_class)
|
model,
|
||||||
|
valid_dataloader,
|
||||||
|
post_process_class,
|
||||||
|
eval_class,
|
||||||
|
use_srn=use_srn)
|
||||||
cur_metric_str = 'cur metric, {}'.format(', '.join(
|
cur_metric_str = 'cur metric, {}'.format(', '.join(
|
||||||
['{}: {}'.format(k, v) for k, v in cur_metric.items()]))
|
['{}: {}'.format(k, v) for k, v in cur_metric.items()]))
|
||||||
logger.info(cur_metric_str)
|
logger.info(cur_metric_str)
|
||||||
|
@ -321,7 +327,8 @@ def train(config,
|
||||||
return
|
return
|
||||||
|
|
||||||
|
|
||||||
def eval(model, valid_dataloader, post_process_class, eval_class):
|
def eval(model, valid_dataloader, post_process_class, eval_class,
|
||||||
|
use_srn=False):
|
||||||
model.eval()
|
model.eval()
|
||||||
with paddle.no_grad():
|
with paddle.no_grad():
|
||||||
total_frame = 0.0
|
total_frame = 0.0
|
||||||
|
@ -332,7 +339,8 @@ def eval(model, valid_dataloader, post_process_class, eval_class):
|
||||||
break
|
break
|
||||||
images = batch[0]
|
images = batch[0]
|
||||||
start = time.time()
|
start = time.time()
|
||||||
if "SRN" in str(model.head):
|
|
||||||
|
if use_srn:
|
||||||
others = batch[-4:]
|
others = batch[-4:]
|
||||||
preds = model(images, others)
|
preds = model(images, others)
|
||||||
else:
|
else:
|
||||||
|
|
Loading…
Reference in New Issue