set evaluation interval

This commit is contained in:
LDOUBLEV 2020-07-07 02:35:17 +00:00
parent 5067126e8c
commit e222bc9e64
3 changed files with 18 additions and 4 deletions

View File

@ -22,7 +22,7 @@
| print_batch_step | 设置打印log间隔 | 10 | \ | | print_batch_step | 设置打印log间隔 | 10 | \ |
| save_model_dir | 设置模型保存路径 | output/{算法名称} | \ | | save_model_dir | 设置模型保存路径 | output/{算法名称} | \ |
| save_epoch_step | 设置模型保存间隔 | 3 | \ | | save_epoch_step | 设置模型保存间隔 | 3 | \ |
| eval_batch_step | 设置模型评估间隔 | 2000 | \ | | eval_batch_step | 设置模型评估间隔 | 2000 或 [1000, 2000] | 2000 表示每2000次迭代评估一次；[1000, 2000] 表示从第1000次迭代开始，每2000次迭代评估一次 |
|train_batch_size_per_card | 设置训练时单卡batch size | 256 | \ | |train_batch_size_per_card | 设置训练时单卡batch size | 256 | \ |
| test_batch_size_per_card | 设置评估时单卡batch size | 256 | \ | | test_batch_size_per_card | 设置评估时单卡batch size | 256 | \ |
| image_shape | 设置输入图片尺寸 | [3, 32, 100] | \ | | image_shape | 设置输入图片尺寸 | [3, 32, 100] | \ |

View File

@ -22,7 +22,7 @@ Take `rec_chinese_lite_train.yml` as an example
| print_batch_step | Set print log interval | 10 | \ | | print_batch_step | Set print log interval | 10 | \ |
| save_model_dir | Set model save path | output/{model_name} | \ | | save_model_dir | Set model save path | output/{model_name} | \ |
| save_epoch_step | Set model save interval | 3 | \ | | save_epoch_step | Set model save interval | 3 | \ |
| eval_batch_step | Set the model evaluation interval | 2000 | \ | | eval_batch_step | Set the model evaluation interval | 2000 or [1000, 2000] | 2000: run evaluation every 2000 iterations; [1000, 2000]: run evaluation every 2000 iterations after the 1000th iteration |
|train_batch_size_per_card | Set the batch size during training | 256 | \ | |train_batch_size_per_card | Set the batch size during training | 256 | \ |
| test_batch_size_per_card | Set the batch size during testing | 256 | \ | | test_batch_size_per_card | Set the batch size during testing | 256 | \ |
| image_shape | Set input image size | [3, 32, 100] | \ | | image_shape | Set input image size | [3, 32, 100] | \ |

View File

@ -219,6 +219,13 @@ def train_eval_det_run(config, exe, train_info_dict, eval_info_dict):
epoch_num = config['Global']['epoch_num'] epoch_num = config['Global']['epoch_num']
print_batch_step = config['Global']['print_batch_step'] print_batch_step = config['Global']['print_batch_step']
eval_batch_step = config['Global']['eval_batch_step'] eval_batch_step = config['Global']['eval_batch_step']
start_eval_step = 0
if type(eval_batch_step) == list and len(eval_batch_step) >= 2:
start_eval_step = eval_batch_step[0]
eval_batch_step = eval_batch_step[1]
logger.info(
"During the training process, after the {}th iteration, an evaluation is run every {} iterations".
format(start_eval_step, eval_batch_step))
save_epoch_step = config['Global']['save_epoch_step'] save_epoch_step = config['Global']['save_epoch_step']
save_model_dir = config['Global']['save_model_dir'] save_model_dir = config['Global']['save_model_dir']
if not os.path.exists(save_model_dir): if not os.path.exists(save_model_dir):
@ -246,7 +253,7 @@ def train_eval_det_run(config, exe, train_info_dict, eval_info_dict):
t2 = time.time() t2 = time.time()
train_batch_elapse = t2 - t1 train_batch_elapse = t2 - t1
train_stats.update(stats) train_stats.update(stats)
if train_batch_id > 0 and train_batch_id \ if train_batch_id > start_eval_step and train_batch_id \
% print_batch_step == 0: % print_batch_step == 0:
logs = train_stats.log() logs = train_stats.log()
strs = 'epoch: {}, iter: {}, {}, time: {:.3f}'.format( strs = 'epoch: {}, iter: {}, {}, time: {:.3f}'.format(
@ -286,6 +293,13 @@ def train_eval_rec_run(config, exe, train_info_dict, eval_info_dict):
epoch_num = config['Global']['epoch_num'] epoch_num = config['Global']['epoch_num']
print_batch_step = config['Global']['print_batch_step'] print_batch_step = config['Global']['print_batch_step']
eval_batch_step = config['Global']['eval_batch_step'] eval_batch_step = config['Global']['eval_batch_step']
start_eval_step = 0
if type(eval_batch_step) == list and len(eval_batch_step) >= 2:
start_eval_step = eval_batch_step[0]
eval_batch_step = eval_batch_step[1]
logger.info(
"During the training process, after the {}th iteration, an evaluation is run every {} iterations".
format(start_eval_step, eval_batch_step))
save_epoch_step = config['Global']['save_epoch_step'] save_epoch_step = config['Global']['save_epoch_step']
save_model_dir = config['Global']['save_model_dir'] save_model_dir = config['Global']['save_model_dir']
if not os.path.exists(save_model_dir): if not os.path.exists(save_model_dir):
@ -324,7 +338,7 @@ def train_eval_rec_run(config, exe, train_info_dict, eval_info_dict):
train_batch_elapse = t2 - t1 train_batch_elapse = t2 - t1
stats = {'loss': loss, 'acc': acc} stats = {'loss': loss, 'acc': acc}
train_stats.update(stats) train_stats.update(stats)
if train_batch_id > 0 and train_batch_id \ if train_batch_id > start_eval_step and train_batch_id \
% print_batch_step == 0: % print_batch_step == 0:
logs = train_stats.log() logs = train_stats.log()
strs = 'epoch: {}, iter: {}, lr: {:.6f}, {}, time: {:.3f}'.format( strs = 'epoch: {}, iter: {}, lr: {:.6f}, {}, time: {:.3f}'.format(