revert prune

This commit is contained in:
LDOUBLEV 2021-05-12 10:29:11 +08:00
parent b79fee116a
commit 76f404690e
1 changed files with 17 additions and 31 deletions

View File

@ -24,14 +24,6 @@ sys.path.append(__dir__)
sys.path.append(os.path.join(__dir__, '..', '..', '..'))
sys.path.append(os.path.join(__dir__, '..', '..', '..', 'tools'))
import json
import cv2
import paddle
from paddle import fluid
import paddleslim as slim
from copy import deepcopy
from tools import program
import paddle
import paddle.distributed as dist
from ppocr.data import build_dataloader
@ -46,28 +38,14 @@ import tools.program as program
dist.get_world_size()
def get_pruned_params(parameters, mode="det"):
if mode == "det":
skip_prune_params = [
"conv2d_56.w_0", "conv2d_54.w_0", "conv2d_51.w_0",
"conv_last_weights", "conv14_linear_weights",
"conv13_expand_weights", "conv12_linear_weights",
"conv12_expand_weights", "conv7_expand_weights",
"conv8_expand_weights", "conv8_linear_weights",
"conv5_linear_weights", "conv5_expand_weights",
"conv3_linear_weights"
]
skip_prune_params = skip_prune_params + ['conv2d_53.w_0']
else:
skip_prune_params = None
def get_pruned_params(parameters):
params = []
for param in parameters:
if len(
param.shape
) == 4 and 'depthwise' not in param.name and 'transpose' not in param.name and "conv2d_57" not in param.name and "conv2d_56" not in param.name:
if param.name not in skip_prune_params:
params.append(param.name)
params.append(param.name)
return params
@ -118,6 +96,11 @@ def main(config, device, logger, vdl_writer):
# load pretrain model
pre_best_model_dict = init_model(config, model, logger, optimizer)
logger.info('train dataloader has {} iters, valid dataloader has {} iters'.
format(len(train_dataloader), len(valid_dataloader)))
# build metric
eval_class = build_metric(config['Metric'])
logger.info('train dataloader has {} iters, valid dataloader has {} iters'.
format(len(train_dataloader), len(valid_dataloader)))
@ -127,20 +110,22 @@ def main(config, device, logger, vdl_writer):
logger.info(f"metric['hmean']: {metric['hmean']}")
return metric['hmean']
pruner.sensitive(
params_sensitive = pruner.sensitive(
eval_func=eval_fn,
sen_file="./sen.pickle",
skip_vars=[
"conv2d_57.w_0", "conv2d_transpose_2.w_0", "conv2d_transpose_3.w_0"
])
params = get_pruned_params(model.parameters())
ratios = {}
# set the prune ratio is 0.2
for param in params:
ratios[param] = 0.2
logger.info(
"The sensitivity analysis results of model parameters saved in sen.pickle"
)
# calculate pruned params's ratio
params_sensitive = pruner._get_ratios_by_loss(params_sensitive, loss=0.02)
for key in params_sensitive.keys():
logger.info(f"{key}, {params_sensitive[key]}")
plan = pruner.prune_vars(ratios, [0])
plan = pruner.prune_vars(params_sensitive, [0])
for param in model.parameters():
if ("weights" in param.name and "conv" in param.name) or (
"w_0" in param.name and "conv2d" in param.name):
@ -150,6 +135,7 @@ def main(config, device, logger, vdl_writer):
logger.info(f"FLOPs after pruning: {flops}")
# start train
program.train(config, train_dataloader, valid_dataloader, device, model,
loss_class, optimizer, lr_scheduler, post_process_class,
eval_class, pre_best_model_dict, logger, vdl_writer)