fix bug
This commit is contained in:
parent
9467b75436
commit
2aa92e6a64
|
@ -152,7 +152,6 @@ def train(config,
|
||||||
pre_best_model_dict,
|
pre_best_model_dict,
|
||||||
logger,
|
logger,
|
||||||
vdl_writer=None):
|
vdl_writer=None):
|
||||||
|
|
||||||
cal_metric_during_train = config['Global'].get('cal_metric_during_train',
|
cal_metric_during_train = config['Global'].get('cal_metric_during_train',
|
||||||
False)
|
False)
|
||||||
log_smooth_window = config['Global']['log_smooth_window']
|
log_smooth_window = config['Global']['log_smooth_window']
|
||||||
|
@ -185,14 +184,13 @@ def train(config,
|
||||||
|
|
||||||
for epoch in range(start_epoch, epoch_num):
|
for epoch in range(start_epoch, epoch_num):
|
||||||
if epoch > 0:
|
if epoch > 0:
|
||||||
train_loader = build_dataloader(config, 'Train', device)
|
train_dataloader = build_dataloader(config, 'Train', device, logger)
|
||||||
|
|
||||||
for idx, batch in enumerate(train_dataloader):
|
for idx, batch in enumerate(train_dataloader):
|
||||||
if idx >= len(train_dataloader):
|
if idx >= len(train_dataloader):
|
||||||
break
|
break
|
||||||
lr = optimizer.get_lr()
|
lr = optimizer.get_lr()
|
||||||
t1 = time.time()
|
t1 = time.time()
|
||||||
batch = [paddle.to_tensor(x) for x in batch]
|
|
||||||
images = batch[0]
|
images = batch[0]
|
||||||
preds = model(images)
|
preds = model(images)
|
||||||
loss = loss_class(preds, batch)
|
loss = loss_class(preds, batch)
|
||||||
|
@ -301,11 +299,11 @@ def eval(model, valid_dataloader, post_process_class, eval_class, logger,
|
||||||
with paddle.no_grad():
|
with paddle.no_grad():
|
||||||
total_frame = 0.0
|
total_frame = 0.0
|
||||||
total_time = 0.0
|
total_time = 0.0
|
||||||
# pbar = tqdm(total=len(valid_dataloader), desc='eval model:')
|
pbar = tqdm(total=len(valid_dataloader), desc='eval model:')
|
||||||
for idx, batch in enumerate(valid_dataloader):
|
for idx, batch in enumerate(valid_dataloader):
|
||||||
if idx >= len(valid_dataloader):
|
if idx >= len(valid_dataloader):
|
||||||
break
|
break
|
||||||
images = paddle.to_tensor(batch[0])
|
images = batch[0]
|
||||||
start = time.time()
|
start = time.time()
|
||||||
preds = model(images)
|
preds = model(images)
|
||||||
|
|
||||||
|
@ -315,15 +313,15 @@ def eval(model, valid_dataloader, post_process_class, eval_class, logger,
|
||||||
total_time += time.time() - start
|
total_time += time.time() - start
|
||||||
# Evaluate the results of the current batch
|
# Evaluate the results of the current batch
|
||||||
eval_class(post_result, batch)
|
eval_class(post_result, batch)
|
||||||
# pbar.update(1)
|
pbar.update(1)
|
||||||
total_frame += len(images)
|
total_frame += len(images)
|
||||||
if idx % print_batch_step == 0 and dist.get_rank() == 0:
|
# if idx % print_batch_step == 0 and dist.get_rank() == 0:
|
||||||
logger.info('tackling images for eval: {}/{}'.format(
|
# logger.info('tackling images for eval: {}/{}'.format(
|
||||||
idx, len(valid_dataloader)))
|
# idx, len(valid_dataloader)))
|
||||||
# Get final metirc,eg. acc or hmean
|
# Get final metirc,eg. acc or hmean
|
||||||
metirc = eval_class.get_metric()
|
metirc = eval_class.get_metric()
|
||||||
|
|
||||||
# pbar.close()
|
pbar.close()
|
||||||
model.train()
|
model.train()
|
||||||
metirc['fps'] = total_frame / total_time
|
metirc['fps'] = total_frame / total_time
|
||||||
return metirc
|
return metirc
|
||||||
|
|
Loading…
Reference in New Issue