revert prune

parent b79fee116a
commit 76f404690e
@@ -24,14 +24,6 @@ sys.path.append(__dir__)
 sys.path.append(os.path.join(__dir__, '..', '..', '..'))
 sys.path.append(os.path.join(__dir__, '..', '..', '..', 'tools'))
 
-import json
-import cv2
-import paddle
-from paddle import fluid
-import paddleslim as slim
-from copy import deepcopy
-from tools import program
-
 import paddle
 import paddle.distributed as dist
 from ppocr.data import build_dataloader
@@ -46,28 +38,14 @@ import tools.program as program
 dist.get_world_size()
 
 
-def get_pruned_params(parameters, mode="det"):
-    if mode == "det":
-        skip_prune_params = [
-            "conv2d_56.w_0", "conv2d_54.w_0", "conv2d_51.w_0",
-            "conv_last_weights", "conv14_linear_weights",
-            "conv13_expand_weights", "conv12_linear_weights",
-            "conv12_expand_weights", "conv7_expand_weights",
-            "conv8_expand_weights", "conv8_linear_weights",
-            "conv5_linear_weights", "conv5_expand_weights",
-            "conv3_linear_weights"
-        ]
-        skip_prune_params = skip_prune_params + ['conv2d_53.w_0']
-    else:
-        skip_prune_params = None
+def get_pruned_params(parameters):
     params = []
 
     for param in parameters:
         if len(
                 param.shape
         ) == 4 and 'depthwise' not in param.name and 'transpose' not in param.name and "conv2d_57" not in param.name and "conv2d_56" not in param.name:
-            if param.name not in skip_prune_params:
-                params.append(param.name)
+            params.append(param.name)
     return params
 
 
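After this hunk, get_pruned_params keeps every 4-D convolution kernel except depthwise and transposed convolutions and the two layers named in the condition; the det-specific skip_prune_params list is gone. A minimal sketch of the surviving selection rule, factored into a predicate (the helper name is_prunable is ours, not the script's; names like "conv2d_57.w_0" follow PaddleOCR's parameter naming):

# Hypothetical helper equivalent to the condition in the reverted loop.
def is_prunable(name, shape):
    return (len(shape) == 4              # 4-D: convolution kernels only
            and 'depthwise' not in name  # keep depthwise convs intact
            and 'transpose' not in name  # keep transposed convs intact
            and 'conv2d_57' not in name  # two layers excluded by name
            and 'conv2d_56' not in name)

# Same result as get_pruned_params(model.parameters()):
# [p.name for p in model.parameters() if is_prunable(p.name, p.shape)]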
@@ -118,6 +96,11 @@ def main(config, device, logger, vdl_writer):
     # load pretrain model
     pre_best_model_dict = init_model(config, model, logger, optimizer)
 
     logger.info('train dataloader has {} iters, valid dataloader has {} iters'.
                 format(len(train_dataloader), len(valid_dataloader)))
+    # build metric
+    eval_class = build_metric(config['Metric'])
+
+    logger.info('train dataloader has {} iters, valid dataloader has {} iters'.
+                format(len(train_dataloader), len(valid_dataloader)))
 
@@ -127,20 +110,22 @@ def main(config, device, logger, vdl_writer):
         logger.info(f"metric['hmean']: {metric['hmean']}")
         return metric['hmean']
 
-    pruner.sensitive(
+    params_sensitive = pruner.sensitive(
         eval_func=eval_fn,
         sen_file="./sen.pickle",
         skip_vars=[
             "conv2d_57.w_0", "conv2d_transpose_2.w_0", "conv2d_transpose_3.w_0"
         ])
 
-    params = get_pruned_params(model.parameters())
-    ratios = {}
-    # set the prune ratio is 0.2
-    for param in params:
-        ratios[param] = 0.2
+    logger.info(
+        "The sensitivity analysis results of model parameters saved in sen.pickle"
+    )
+    # calculate pruned params's ratio
+    params_sensitive = pruner._get_ratios_by_loss(params_sensitive, loss=0.02)
+    for key in params_sensitive.keys():
+        logger.info(f"{key}, {params_sensitive[key]}")
 
-    plan = pruner.prune_vars(ratios, [0])
+    plan = pruner.prune_vars(params_sensitive, [0])
     for param in model.parameters():
         if ("weights" in param.name and "conv" in param.name) or (
                 "w_0" in param.name and "conv2d" in param.name):
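This hunk replaces a flat 0.2 ratio per parameter with ratios derived from sensitivity analysis. A self-contained sketch of that flow on a toy model: FPGMFilterPruner, sensitive(), _get_ratios_by_loss() and prune_vars() are the same PaddleSlim calls the script uses, while the toy model and the constant eval_fn are stand-ins for illustration only.

import paddle
import paddle.nn as nn
from paddleslim.dygraph import FPGMFilterPruner

model = nn.Sequential(
    nn.Conv2D(3, 16, 3, padding=1), nn.ReLU(),
    nn.Conv2D(16, 32, 3, padding=1))
pruner = FPGMFilterPruner(model, [1, 3, 64, 64])

def eval_fn():
    # Stand-in for the script's eval_fn, which runs program.eval
    # and returns metric['hmean'].
    return 0.5

# Prune each parameter at several ratios, record the metric change,
# and cache the results in sen.pickle (reused on a later run).
sen = pruner.sensitive(eval_func=eval_fn, sen_file="./sen.pickle")
# Convert sensitivities into per-parameter ratios whose predicted
# metric loss stays under 2%.
ratios = pruner._get_ratios_by_loss(sen, loss=0.02)
plan = pruner.prune_vars(ratios, [0])  # axis 0: output channels
print(paddle.flops(model, [1, 3, 64, 64]))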
@@ -150,6 +135,7 @@ def main(config, device, logger, vdl_writer):
     logger.info(f"FLOPs after pruning: {flops}")
 
+    # start train
 
     program.train(config, train_dataloader, valid_dataloader, device, model,
                   loss_class, optimizer, lr_scheduler, post_process_class,
                   eval_class, pre_best_model_dict, logger, vdl_writer)
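The "FLOPs after pruning" line in the last hunk comes from paddle.flops. A standalone sketch, assuming the 1x3x640x640 detection input shape used elsewhere in the script (the tiny net here is a placeholder):

import paddle
import paddle.nn as nn

net = nn.Sequential(nn.Conv2D(3, 8, 3, padding=1), nn.ReLU())
# Count FLOPs for one forward pass at the given input size.
flops = paddle.flops(net, [1, 3, 640, 640], print_detail=False)
print(f"FLOPs after pruning: {flops}")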