optimize the prune

This commit is contained in:
LDOUBLEV 2021-09-26 15:09:48 +08:00
parent e2ed89fa79
commit 9dba4a1214
2 changed files with 35 additions and 18 deletions

View File

@ -110,25 +110,42 @@ def main(config, device, logger, vdl_writer):
logger.info("metric['hmean']: {}".format(metric['hmean']))
return metric['hmean']
run_sensitive_analysis = False
"""
run_sensitive_analysis=True:
Automatically compute the sensitivities of convolutions in a model.
The sensitivity of a convolution is the losses of accuracy on test dataset in
differenct pruned ratios. The sensitivities can be used to get a group of best
ratios with some condition.
run_sensitive_analysis=False:
Set prune trim ratio to a fixed value, such as 10%. The larger the value,
the more convolution weights will be cropped.
"""
if run_sensitive_analysis:
params_sensitive = pruner.sensitive(
eval_func=eval_fn,
sen_file="./sen.pickle",
sen_file="./deploy/slim/prune/sen.pickle",
skip_vars=[
"conv2d_57.w_0", "conv2d_transpose_2.w_0", "conv2d_transpose_3.w_0"
"conv2d_57.w_0", "conv2d_transpose_2.w_0",
"conv2d_transpose_3.w_0"
])
logger.info(
"The sensitivity analysis results of model parameters saved in sen.pickle"
)
# calculate pruned params's ratio
params_sensitive = pruner._get_ratios_by_loss(params_sensitive, loss=0.02)
params_sensitive = pruner._get_ratios_by_loss(
params_sensitive, loss=0.02)
for key in params_sensitive.keys():
logger.info("{}, {}".format(key, params_sensitive[key]))
#params_sensitive = {}
#for param in model.parameters():
# if 'transpose' not in param.name and 'linear' not in param.name:
# params_sensitive[param.name] = 0.1
else:
params_sensitive = {}
for param in model.parameters():
if 'transpose' not in param.name and 'linear' not in param.name:
# set prune ratio as 10%. The larger the value, the more convolution weights will be cropped
params_sensitive[param.name] = 0.1
plan = pruner.prune_vars(params_sensitive, [0])

View File

@ -351,7 +351,7 @@ def eval(model,
valid_dataloader,
post_process_class,
eval_class,
model_type,
model_type=None,
use_srn=False,
use_sar=False):
model.eval()