commit
fc7b5d225b
|
@ -65,5 +65,8 @@ class TrainingStats(object):
|
|||
|
||||
def log(self, extras=None):
    """Render the current statistics as a single log-friendly string.

    Args:
        extras: Passed through unchanged to ``self.get``.

    Returns:
        str: ``"key: value"`` pairs joined by ``", "``, each value
        formatted with six decimal places. Empty string when
        ``self.get`` returns no entries.
    """
    stats = self.get(extras)
    # Explicit precision instead of the former 'x<6f' spec: that spec asked
    # for an 'x' fill up to width 6, which never affects finite numbers
    # (they are already >= 8 chars) but turned nan/inf into 'nanxxx'/'infxxx'.
    return ', '.join('{}: {:.6f}'.format(k, v) for k, v in stats.items())
|
||||
|
|
|
@ -185,12 +185,15 @@ def train(config,
|
|||
for epoch in range(start_epoch, epoch_num):
|
||||
if epoch > 0:
|
||||
train_dataloader = build_dataloader(config, 'Train', device, logger)
|
||||
|
||||
train_batch_cost = 0.0
|
||||
train_reader_cost = 0.0
|
||||
batch_sum = 0
|
||||
batch_start = time.time()
|
||||
for idx, batch in enumerate(train_dataloader):
|
||||
train_reader_cost += time.time() - batch_start
|
||||
if idx >= len(train_dataloader):
|
||||
break
|
||||
lr = optimizer.get_lr()
|
||||
t1 = time.time()
|
||||
images = batch[0]
|
||||
preds = model(images)
|
||||
loss = loss_class(preds, batch)
|
||||
|
@ -198,6 +201,10 @@ def train(config,
|
|||
avg_loss.backward()
|
||||
optimizer.step()
|
||||
optimizer.clear_grad()
|
||||
|
||||
train_batch_cost += time.time() - batch_start
|
||||
batch_sum += len(images)
|
||||
|
||||
if not isinstance(lr_scheduler, float):
|
||||
lr_scheduler.step()
|
||||
|
||||
|
@ -213,9 +220,6 @@ def train(config,
|
|||
metirc = eval_class.get_metric()
|
||||
train_stats.update(metirc)
|
||||
|
||||
t2 = time.time()
|
||||
train_batch_elapse = t2 - t1
|
||||
|
||||
if vdl_writer is not None and dist.get_rank() == 0:
|
||||
for k, v in train_stats.get().items():
|
||||
vdl_writer.add_scalar('TRAIN/{}'.format(k), v, global_step)
|
||||
|
@ -224,9 +228,15 @@ def train(config,
|
|||
if dist.get_rank(
|
||||
) == 0 and global_step > 0 and global_step % print_batch_step == 0:
|
||||
logs = train_stats.log()
|
||||
strs = 'epoch: [{}/{}], iter: {}, {}, time: {:.3f}'.format(
|
||||
epoch, epoch_num, global_step, logs, train_batch_elapse)
|
||||
strs = 'epoch: [{}/{}], iter: {}, {}, reader_cost: {:.5f} s, batch_cost: {:.5f} s, samples: {}, ips: {:.5f}'.format(
|
||||
epoch, epoch_num, global_step, logs, train_reader_cost /
|
||||
print_batch_step, train_batch_cost / print_batch_step,
|
||||
batch_sum, batch_sum / train_batch_cost)
|
||||
logger.info(strs)
|
||||
train_batch_cost = 0.0
|
||||
train_reader_cost = 0.0
|
||||
batch_sum = 0
|
||||
batch_start = time.time()
|
||||
# eval
|
||||
if global_step > start_eval_step and \
|
||||
(global_step - start_eval_step) % eval_batch_step == 0 and dist.get_rank() == 0:
|
||||
|
|
Loading…
Reference in New Issue