add global_step to .states files (#2566)
Co-authored-by: littletomatodonkey <2120160898@bit.edu.cn>
This commit is contained in:
parent
8e4b213877
commit
c27022294e
|
@ -121,7 +121,7 @@ def init_model(config, model, logger, optimizer=None, lr_scheduler=None):
|
||||||
return best_model_dict
|
return best_model_dict
|
||||||
|
|
||||||
|
|
||||||
def save_model(net,
|
def save_model(model,
|
||||||
optimizer,
|
optimizer,
|
||||||
model_path,
|
model_path,
|
||||||
logger,
|
logger,
|
||||||
|
@ -133,7 +133,7 @@ def save_model(net,
|
||||||
"""
|
"""
|
||||||
_mkdir_if_not_exist(model_path, logger)
|
_mkdir_if_not_exist(model_path, logger)
|
||||||
model_prefix = os.path.join(model_path, prefix)
|
model_prefix = os.path.join(model_path, prefix)
|
||||||
paddle.save(net.state_dict(), model_prefix + '.pdparams')
|
paddle.save(model.state_dict(), model_prefix + '.pdparams')
|
||||||
paddle.save(optimizer.state_dict(), model_prefix + '.pdopt')
|
paddle.save(optimizer.state_dict(), model_prefix + '.pdopt')
|
||||||
|
|
||||||
# save metric and config
|
# save metric and config
|
||||||
|
|
|
@ -159,6 +159,8 @@ def train(config,
|
||||||
eval_batch_step = config['Global']['eval_batch_step']
|
eval_batch_step = config['Global']['eval_batch_step']
|
||||||
|
|
||||||
global_step = 0
|
global_step = 0
|
||||||
|
if 'global_step' in pre_best_model_dict:
|
||||||
|
global_step = pre_best_model_dict['global_step']
|
||||||
start_eval_step = 0
|
start_eval_step = 0
|
||||||
if type(eval_batch_step) == list and len(eval_batch_step) >= 2:
|
if type(eval_batch_step) == list and len(eval_batch_step) >= 2:
|
||||||
start_eval_step = eval_batch_step[0]
|
start_eval_step = eval_batch_step[0]
|
||||||
|
@ -285,7 +287,8 @@ def train(config,
|
||||||
is_best=True,
|
is_best=True,
|
||||||
prefix='best_accuracy',
|
prefix='best_accuracy',
|
||||||
best_model_dict=best_model_dict,
|
best_model_dict=best_model_dict,
|
||||||
epoch=epoch)
|
epoch=epoch,
|
||||||
|
global_step=global_step)
|
||||||
best_str = 'best metric, {}'.format(', '.join([
|
best_str = 'best metric, {}'.format(', '.join([
|
||||||
'{}: {}'.format(k, v) for k, v in best_model_dict.items()
|
'{}: {}'.format(k, v) for k, v in best_model_dict.items()
|
||||||
]))
|
]))
|
||||||
|
@ -307,7 +310,8 @@ def train(config,
|
||||||
is_best=False,
|
is_best=False,
|
||||||
prefix='latest',
|
prefix='latest',
|
||||||
best_model_dict=best_model_dict,
|
best_model_dict=best_model_dict,
|
||||||
epoch=epoch)
|
epoch=epoch,
|
||||||
|
global_step=global_step)
|
||||||
if dist.get_rank() == 0 and epoch > 0 and epoch % save_epoch_step == 0:
|
if dist.get_rank() == 0 and epoch > 0 and epoch % save_epoch_step == 0:
|
||||||
save_model(
|
save_model(
|
||||||
model,
|
model,
|
||||||
|
@ -317,7 +321,8 @@ def train(config,
|
||||||
is_best=False,
|
is_best=False,
|
||||||
prefix='iter_epoch_{}'.format(epoch),
|
prefix='iter_epoch_{}'.format(epoch),
|
||||||
best_model_dict=best_model_dict,
|
best_model_dict=best_model_dict,
|
||||||
epoch=epoch)
|
epoch=epoch,
|
||||||
|
global_step=global_step)
|
||||||
best_str = 'best metric, {}'.format(', '.join(
|
best_str = 'best metric, {}'.format(', '.join(
|
||||||
['{}: {}'.format(k, v) for k, v in best_model_dict.items()]))
|
['{}: {}'.format(k, v) for k, v in best_model_dict.items()]))
|
||||||
logger.info(best_str)
|
logger.info(best_str)
|
||||||
|
|
Loading…
Reference in New Issue