commit
e634afe76c
|
@ -135,7 +135,7 @@ def main():
|
|||
|
||||
if alg in ['EAST', 'DB']:
|
||||
program.train_eval_det_run(
|
||||
config, exe, train_info_dict, eval_info_dict, is_pruning=True)
|
||||
config, exe, train_info_dict, eval_info_dict, is_slim="prune")
|
||||
else:
|
||||
program.train_eval_rec_run(config, exe, train_info_dict, eval_info_dict)
|
||||
|
||||
|
|
|
@ -155,14 +155,13 @@ def main():
|
|||
act_preprocess_func=act_preprocess_func,
|
||||
optimizer_func=optimizer_func,
|
||||
executor=executor,
|
||||
for_test=False,
|
||||
return_program=True)
|
||||
for_test=False)
|
||||
|
||||
# compile program for multi-devices
|
||||
train_compile_program = program.create_multi_devices_program(
|
||||
quant_train_program, train_opt_loss_name, for_quant=True)
|
||||
|
||||
init_model(config, quant_train_program, exe)
|
||||
init_model(config, train_program, exe)
|
||||
|
||||
train_info_dict = {'compile_program':train_compile_program,\
|
||||
'train_program':quant_train_program,\
|
||||
|
@ -177,9 +176,11 @@ def main():
|
|||
'fetch_varname_list':eval_fetch_varname_list}
|
||||
|
||||
if train_alg_type == 'det':
|
||||
program.train_eval_det_run(config, exe, train_info_dict, eval_info_dict)
|
||||
program.train_eval_det_run(
|
||||
config, exe, train_info_dict, eval_info_dict, is_slim="quant")
|
||||
else:
|
||||
program.train_eval_rec_run(config, exe, train_info_dict, eval_info_dict)
|
||||
program.train_eval_rec_run(
|
||||
config, exe, train_info_dict, eval_info_dict, is_slim="quant")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
159
tools/program.py
159
tools/program.py
|
@ -241,9 +241,11 @@ def create_multi_devices_program(program, loss_var_name, for_quant=False):
|
|||
build_strategy.enable_inplace = True
|
||||
if for_quant:
|
||||
build_strategy.fuse_all_reduce_ops = False
|
||||
else:
|
||||
program = fluid.CompiledProgram(program)
|
||||
exec_strategy = fluid.ExecutionStrategy()
|
||||
exec_strategy.num_iteration_per_drop_scope = 1
|
||||
compile_program = fluid.CompiledProgram(program).with_data_parallel(
|
||||
compile_program = program.with_data_parallel(
|
||||
loss_name=loss_var_name,
|
||||
build_strategy=build_strategy,
|
||||
exec_strategy=exec_strategy)
|
||||
|
@ -254,7 +256,7 @@ def train_eval_det_run(config,
|
|||
exe,
|
||||
train_info_dict,
|
||||
eval_info_dict,
|
||||
is_pruning=False):
|
||||
is_slim=None):
|
||||
'''
|
||||
main program of evaluation for detection
|
||||
'''
|
||||
|
@ -313,14 +315,21 @@ def train_eval_det_run(config,
|
|||
best_batch_id = train_batch_id
|
||||
best_epoch = epoch
|
||||
save_path = save_model_dir + "/best_accuracy"
|
||||
if is_pruning:
|
||||
import paddleslim as slim
|
||||
slim.prune.save_model(
|
||||
exe, train_info_dict['train_program'],
|
||||
save_path)
|
||||
else:
|
||||
if is_slim is None:
|
||||
save_model(train_info_dict['train_program'],
|
||||
save_path)
|
||||
else:
|
||||
import paddleslim as slim
|
||||
if is_slim == "prune":
|
||||
slim.prune.save_model(
|
||||
exe, train_info_dict['train_program'],
|
||||
save_path)
|
||||
elif is_slim == "quant":
|
||||
save_model(eval_info_dict['program'], save_path)
|
||||
else:
|
||||
raise ValueError(
|
||||
"Only quant and prune are supported currently. But received {}".
|
||||
format(is_slim))
|
||||
strs = 'Test iter: {}, metrics:{}, best_hmean:{:.6f}, best_epoch:{}, best_batch_id:{}'.format(
|
||||
train_batch_id, metrics, best_eval_hmean, best_epoch,
|
||||
best_batch_id)
|
||||
|
@ -331,24 +340,42 @@ def train_eval_det_run(config,
|
|||
train_loader.reset()
|
||||
if epoch == 0 and save_epoch_step == 1:
|
||||
save_path = save_model_dir + "/iter_epoch_0"
|
||||
if is_pruning:
|
||||
import paddleslim as slim
|
||||
slim.prune.save_model(exe, train_info_dict['train_program'],
|
||||
save_path)
|
||||
else:
|
||||
if is_slim is None:
|
||||
save_model(train_info_dict['train_program'], save_path)
|
||||
else:
|
||||
import paddleslim as slim
|
||||
if is_slim == "prune":
|
||||
slim.prune.save_model(exe, train_info_dict['train_program'],
|
||||
save_path)
|
||||
elif is_slim == "quant":
|
||||
save_model(eval_info_dict['program'], save_path)
|
||||
else:
|
||||
raise ValueError(
|
||||
"Only quant and prune are supported currently. But received {}".
|
||||
format(is_slim))
|
||||
if epoch > 0 and epoch % save_epoch_step == 0:
|
||||
save_path = save_model_dir + "/iter_epoch_%d" % (epoch)
|
||||
if is_pruning:
|
||||
import paddleslim as slim
|
||||
slim.prune.save_model(exe, train_info_dict['train_program'],
|
||||
save_path)
|
||||
else:
|
||||
if is_slim is None:
|
||||
save_model(train_info_dict['train_program'], save_path)
|
||||
else:
|
||||
import paddleslim as slim
|
||||
if is_slim == "prune":
|
||||
slim.prune.save_model(exe, train_info_dict['train_program'],
|
||||
save_path)
|
||||
elif is_slim == "quant":
|
||||
save_model(eval_info_dict['program'], save_path)
|
||||
else:
|
||||
raise ValueError(
|
||||
"Only quant and prune are supported currently. But received {}".
|
||||
format(is_slim))
|
||||
return
|
||||
|
||||
|
||||
def train_eval_rec_run(config, exe, train_info_dict, eval_info_dict):
|
||||
def train_eval_rec_run(config,
|
||||
exe,
|
||||
train_info_dict,
|
||||
eval_info_dict,
|
||||
is_slim=None):
|
||||
'''
|
||||
main program of evaluation for recognition
|
||||
'''
|
||||
|
@ -428,7 +455,21 @@ def train_eval_rec_run(config, exe, train_info_dict, eval_info_dict):
|
|||
best_batch_id = train_batch_id
|
||||
best_epoch = epoch
|
||||
save_path = save_model_dir + "/best_accuracy"
|
||||
save_model(train_info_dict['train_program'], save_path)
|
||||
if is_slim is None:
|
||||
save_model(train_info_dict['train_program'],
|
||||
save_path)
|
||||
else:
|
||||
import paddleslim as slim
|
||||
if is_slim == "prune":
|
||||
slim.prune.save_model(
|
||||
exe, train_info_dict['train_program'],
|
||||
save_path)
|
||||
elif is_slim == "quant":
|
||||
save_model(eval_info_dict['program'], save_path)
|
||||
else:
|
||||
raise ValueError(
|
||||
"Only quant and prune are supported currently. But received {}".
|
||||
format(is_slim))
|
||||
strs = 'Test iter: {}, acc:{:.6f}, best_acc:{:.6f}, best_epoch:{}, best_batch_id:{}, eval_sample_num:{}'.format(
|
||||
train_batch_id, eval_acc, best_eval_acc, best_epoch,
|
||||
best_batch_id, eval_sample_num)
|
||||
|
@ -439,14 +480,42 @@ def train_eval_rec_run(config, exe, train_info_dict, eval_info_dict):
|
|||
train_loader.reset()
|
||||
if epoch == 0 and save_epoch_step == 1:
|
||||
save_path = save_model_dir + "/iter_epoch_0"
|
||||
save_model(train_info_dict['train_program'], save_path)
|
||||
if is_slim is None:
|
||||
save_model(train_info_dict['train_program'], save_path)
|
||||
else:
|
||||
import paddleslim as slim
|
||||
if is_slim == "prune":
|
||||
slim.prune.save_model(exe, train_info_dict['train_program'],
|
||||
save_path)
|
||||
elif is_slim == "quant":
|
||||
save_model(eval_info_dict['program'], save_path)
|
||||
else:
|
||||
raise ValueError(
|
||||
"Only quant and prune are supported currently. But received {}".
|
||||
format(is_slim))
|
||||
if epoch > 0 and epoch % save_epoch_step == 0:
|
||||
save_path = save_model_dir + "/iter_epoch_%d" % (epoch)
|
||||
save_model(train_info_dict['train_program'], save_path)
|
||||
if is_slim is None:
|
||||
save_model(train_info_dict['train_program'], save_path)
|
||||
else:
|
||||
import paddleslim as slim
|
||||
if is_slim == "prune":
|
||||
slim.prune.save_model(exe, train_info_dict['train_program'],
|
||||
save_path)
|
||||
elif is_slim == "quant":
|
||||
save_model(eval_info_dict['program'], save_path)
|
||||
else:
|
||||
raise ValueError(
|
||||
"Only quant and prune are supported currently. But received {}".
|
||||
format(is_slim))
|
||||
return
|
||||
|
||||
|
||||
def train_eval_cls_run(config, exe, train_info_dict, eval_info_dict):
|
||||
def train_eval_cls_run(config,
|
||||
exe,
|
||||
train_info_dict,
|
||||
eval_info_dict,
|
||||
is_slim=None):
|
||||
train_batch_id = 0
|
||||
log_smooth_window = config['Global']['log_smooth_window']
|
||||
epoch_num = config['Global']['epoch_num']
|
||||
|
@ -509,7 +578,21 @@ def train_eval_cls_run(config, exe, train_info_dict, eval_info_dict):
|
|||
best_batch_id = train_batch_id
|
||||
best_epoch = epoch
|
||||
save_path = save_model_dir + "/best_accuracy"
|
||||
save_model(train_info_dict['train_program'], save_path)
|
||||
if is_slim is None:
|
||||
save_model(train_info_dict['train_program'],
|
||||
save_path)
|
||||
else:
|
||||
import paddleslim as slim
|
||||
if is_slim == "prune":
|
||||
slim.prune.save_model(
|
||||
exe, train_info_dict['train_program'],
|
||||
save_path)
|
||||
elif is_slim == "quant":
|
||||
save_model(eval_info_dict['program'], save_path)
|
||||
else:
|
||||
raise ValueError(
|
||||
"Only quant and prune are supported currently. But received {}".
|
||||
format(is_slim))
|
||||
strs = 'Test iter: {}, acc:{:.6f}, best_acc:{:.6f}, best_epoch:{}, best_batch_id:{}, eval_sample_num:{}'.format(
|
||||
train_batch_id, eval_acc, best_eval_acc, best_epoch,
|
||||
best_batch_id, eval_sample_num)
|
||||
|
@ -520,10 +603,34 @@ def train_eval_cls_run(config, exe, train_info_dict, eval_info_dict):
|
|||
train_loader.reset()
|
||||
if epoch == 0 and save_epoch_step == 1:
|
||||
save_path = save_model_dir + "/iter_epoch_0"
|
||||
save_model(train_info_dict['train_program'], save_path)
|
||||
if is_slim is None:
|
||||
save_model(train_info_dict['train_program'], save_path)
|
||||
else:
|
||||
import paddleslim as slim
|
||||
if is_slim == "prune":
|
||||
slim.prune.save_model(exe, train_info_dict['train_program'],
|
||||
save_path)
|
||||
elif is_slim == "quant":
|
||||
save_model(eval_info_dict['program'], save_path)
|
||||
else:
|
||||
raise ValueError(
|
||||
"Only quant and prune are supported currently. But received {}".
|
||||
format(is_slim))
|
||||
if epoch > 0 and epoch % save_epoch_step == 0:
|
||||
save_path = save_model_dir + "/iter_epoch_%d" % (epoch)
|
||||
save_model(train_info_dict['train_program'], save_path)
|
||||
if is_slim is None:
|
||||
save_model(train_info_dict['train_program'], save_path)
|
||||
else:
|
||||
import paddleslim as slim
|
||||
if is_slim == "prune":
|
||||
slim.prune.save_model(exe, train_info_dict['train_program'],
|
||||
save_path)
|
||||
elif is_slim == "quant":
|
||||
save_model(eval_info_dict['program'], save_path)
|
||||
else:
|
||||
raise ValueError(
|
||||
"Only quant and prune are supported currently. But received {}".
|
||||
format(is_slim))
|
||||
return
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue