optimize the prune
This commit is contained in:
parent
e2ed89fa79
commit
9dba4a1214
|
@ -110,25 +110,42 @@ def main(config, device, logger, vdl_writer):
|
|||
logger.info("metric['hmean']: {}".format(metric['hmean']))
|
||||
return metric['hmean']
|
||||
|
||||
run_sensitive_analysis = False
|
||||
"""
|
||||
run_sensitive_analysis=True:
|
||||
Automatically compute the sensitivities of convolutions in a model.
|
||||
The sensitivity of a convolution is the losses of accuracy on test dataset in
|
||||
differenct pruned ratios. The sensitivities can be used to get a group of best
|
||||
ratios with some condition.
|
||||
|
||||
run_sensitive_analysis=False:
|
||||
Set prune trim ratio to a fixed value, such as 10%. The larger the value,
|
||||
the more convolution weights will be cropped.
|
||||
|
||||
"""
|
||||
|
||||
if run_sensitive_analysis:
|
||||
params_sensitive = pruner.sensitive(
|
||||
eval_func=eval_fn,
|
||||
sen_file="./sen.pickle",
|
||||
sen_file="./deploy/slim/prune/sen.pickle",
|
||||
skip_vars=[
|
||||
"conv2d_57.w_0", "conv2d_transpose_2.w_0", "conv2d_transpose_3.w_0"
|
||||
"conv2d_57.w_0", "conv2d_transpose_2.w_0",
|
||||
"conv2d_transpose_3.w_0"
|
||||
])
|
||||
|
||||
logger.info(
|
||||
"The sensitivity analysis results of model parameters saved in sen.pickle"
|
||||
)
|
||||
# calculate pruned params's ratio
|
||||
params_sensitive = pruner._get_ratios_by_loss(params_sensitive, loss=0.02)
|
||||
params_sensitive = pruner._get_ratios_by_loss(
|
||||
params_sensitive, loss=0.02)
|
||||
for key in params_sensitive.keys():
|
||||
logger.info("{}, {}".format(key, params_sensitive[key]))
|
||||
|
||||
#params_sensitive = {}
|
||||
#for param in model.parameters():
|
||||
# if 'transpose' not in param.name and 'linear' not in param.name:
|
||||
# params_sensitive[param.name] = 0.1
|
||||
else:
|
||||
params_sensitive = {}
|
||||
for param in model.parameters():
|
||||
if 'transpose' not in param.name and 'linear' not in param.name:
|
||||
# set prune ratio as 10%. The larger the value, the more convolution weights will be cropped
|
||||
params_sensitive[param.name] = 0.1
|
||||
|
||||
plan = pruner.prune_vars(params_sensitive, [0])
|
||||
|
||||
|
|
|
@ -351,7 +351,7 @@ def eval(model,
|
|||
valid_dataloader,
|
||||
post_process_class,
|
||||
eval_class,
|
||||
model_type,
|
||||
model_type=None,
|
||||
use_srn=False,
|
||||
use_sar=False):
|
||||
model.eval()
|
||||
|
|
Loading…
Reference in New Issue