Merge pull request #3250 from littletomatodonkey/dev/add_quant_dist
add support for quant distillation
This commit is contained in:
commit
0ed4aee5b5
|
@ -37,6 +37,17 @@ from paddleslim.dygraph.quant import QAT
|
||||||
from ppocr.data import build_dataloader
|
from ppocr.data import build_dataloader
|
||||||
|
|
||||||
|
|
||||||
|
def export_single_model(quanter, model, infer_shape, save_path, logger):
|
||||||
|
quanter.save_quantized_model(
|
||||||
|
model,
|
||||||
|
save_path,
|
||||||
|
input_spec=[
|
||||||
|
paddle.static.InputSpec(
|
||||||
|
shape=[None] + infer_shape, dtype='float32')
|
||||||
|
])
|
||||||
|
logger.info('inference QAT model is saved to {}'.format(save_path))
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
############################################################################################################
|
############################################################################################################
|
||||||
# 1. quantization configs
|
# 1. quantization configs
|
||||||
|
@ -76,7 +87,14 @@ def main():
|
||||||
# for rec algorithm
|
# for rec algorithm
|
||||||
if hasattr(post_process_class, 'character'):
|
if hasattr(post_process_class, 'character'):
|
||||||
char_num = len(getattr(post_process_class, 'character'))
|
char_num = len(getattr(post_process_class, 'character'))
|
||||||
|
if config['Architecture']["algorithm"] in ["Distillation",
|
||||||
|
]: # distillation model
|
||||||
|
for key in config['Architecture']["Models"]:
|
||||||
|
config['Architecture']["Models"][key]["Head"][
|
||||||
|
'out_channels'] = char_num
|
||||||
|
else: # base rec model
|
||||||
config['Architecture']["Head"]['out_channels'] = char_num
|
config['Architecture']["Head"]['out_channels'] = char_num
|
||||||
|
|
||||||
model = build_model(config['Architecture'])
|
model = build_model(config['Architecture'])
|
||||||
|
|
||||||
# get QAT model
|
# get QAT model
|
||||||
|
@ -93,24 +111,27 @@ def main():
|
||||||
valid_dataloader = build_dataloader(config, 'Eval', device, logger)
|
valid_dataloader = build_dataloader(config, 'Eval', device, logger)
|
||||||
|
|
||||||
# start eval
|
# start eval
|
||||||
metirc = program.eval(model, valid_dataloader, post_process_class,
|
model_type = config['Architecture']['model_type']
|
||||||
eval_class)
|
metric = program.eval(model, valid_dataloader, post_process_class,
|
||||||
|
eval_class, model_type)
|
||||||
logger.info('metric eval ***************')
|
logger.info('metric eval ***************')
|
||||||
for k, v in metirc.items():
|
for k, v in metric.items():
|
||||||
logger.info('{}:{}'.format(k, v))
|
logger.info('{}:{}'.format(k, v))
|
||||||
|
|
||||||
save_path = '{}/inference'.format(config['Global']['save_inference_dir'])
|
|
||||||
infer_shape = [3, 32, 100] if config['Architecture'][
|
infer_shape = [3, 32, 100] if config['Architecture'][
|
||||||
'model_type'] != "det" else [3, 640, 640]
|
'model_type'] != "det" else [3, 640, 640]
|
||||||
|
|
||||||
quanter.save_quantized_model(
|
save_path = config["Global"]["save_inference_dir"]
|
||||||
model,
|
|
||||||
save_path,
|
arch_config = config["Architecture"]
|
||||||
input_spec=[
|
if arch_config["algorithm"] in ["Distillation", ]: # distillation model
|
||||||
paddle.static.InputSpec(
|
for idx, name in enumerate(model.model_name_list):
|
||||||
shape=[None] + infer_shape, dtype='float32')
|
sub_model_save_path = os.path.join(save_path, name, "inference")
|
||||||
])
|
export_single_model(quanter, model.model_list[idx], infer_shape,
|
||||||
logger.info('inference QAT model is saved to {}'.format(save_path))
|
sub_model_save_path, logger)
|
||||||
|
else:
|
||||||
|
save_path = os.path.join(save_path, "inference")
|
||||||
|
export_single_model(quanter, model, infer_shape, save_path, logger)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
|
@ -109,9 +109,18 @@ def main(config, device, logger, vdl_writer):
|
||||||
# for rec algorithm
|
# for rec algorithm
|
||||||
if hasattr(post_process_class, 'character'):
|
if hasattr(post_process_class, 'character'):
|
||||||
char_num = len(getattr(post_process_class, 'character'))
|
char_num = len(getattr(post_process_class, 'character'))
|
||||||
|
if config['Architecture']["algorithm"] in ["Distillation",
|
||||||
|
]: # distillation model
|
||||||
|
for key in config['Architecture']["Models"]:
|
||||||
|
config['Architecture']["Models"][key]["Head"][
|
||||||
|
'out_channels'] = char_num
|
||||||
|
else: # base rec model
|
||||||
config['Architecture']["Head"]['out_channels'] = char_num
|
config['Architecture']["Head"]['out_channels'] = char_num
|
||||||
model = build_model(config['Architecture'])
|
model = build_model(config['Architecture'])
|
||||||
|
|
||||||
|
quanter = QAT(config=quant_config, act_preprocess=PACT)
|
||||||
|
quanter.quantize(model)
|
||||||
|
|
||||||
if config['Global']['distributed']:
|
if config['Global']['distributed']:
|
||||||
model = paddle.DataParallel(model)
|
model = paddle.DataParallel(model)
|
||||||
|
|
||||||
|
@ -132,8 +141,6 @@ def main(config, device, logger, vdl_writer):
|
||||||
|
|
||||||
logger.info('train dataloader has {} iters, valid dataloader has {} iters'.
|
logger.info('train dataloader has {} iters, valid dataloader has {} iters'.
|
||||||
format(len(train_dataloader), len(valid_dataloader)))
|
format(len(train_dataloader), len(valid_dataloader)))
|
||||||
quanter = QAT(config=quant_config, act_preprocess=PACT)
|
|
||||||
quanter.quantize(model)
|
|
||||||
|
|
||||||
# start train
|
# start train
|
||||||
program.train(config, train_dataloader, valid_dataloader, device, model,
|
program.train(config, train_dataloader, valid_dataloader, device, model,
|
||||||
|
|
Loading…
Reference in New Issue