refine
This commit is contained in:
parent
053cc43d82
commit
97a668747c
|
@ -187,7 +187,7 @@ def train(config,
|
|||
|
||||
use_srn = config['Architecture']['algorithm'] == "SRN"
|
||||
model_type = config['Architecture']['model_type']
|
||||
|
||||
|
||||
if 'start_epoch' in best_model_dict:
|
||||
start_epoch = best_model_dict['start_epoch']
|
||||
else:
|
||||
|
@ -338,8 +338,12 @@ def train(config,
|
|||
return
|
||||
|
||||
|
||||
def eval(model, valid_dataloader, post_process_class, eval_class,
|
||||
model_type, use_srn=False):
|
||||
def eval(model,
|
||||
valid_dataloader,
|
||||
post_process_class,
|
||||
eval_class,
|
||||
model_type,
|
||||
use_srn=False):
|
||||
model.eval()
|
||||
with paddle.no_grad():
|
||||
total_frame = 0.0
|
||||
|
@ -352,7 +356,7 @@ def eval(model, valid_dataloader, post_process_class, eval_class,
|
|||
break
|
||||
images = batch[0]
|
||||
start = time.time()
|
||||
preds = model(images, data=batch[1:])
|
||||
preds = model(images, data=batch[1:])
|
||||
batch = [item.numpy() for item in batch]
|
||||
# Obtain usable results from post-processing methods
|
||||
total_time += time.time() - start
|
||||
|
|
Loading…
Reference in New Issue