fix export model for distillation model

This commit is contained in:
littletomatodonkey 2021-06-03 06:53:24 +00:00
parent ab4db2acce
commit 6361a38ff5
3 changed files with 62 additions and 40 deletions

View File

@ -82,7 +82,7 @@ class DistillationDistanceLoss(DistanceLoss):
key=None,
name="loss_distance",
**kargs):
super().__init__(mode=mode, name=name)
super().__init__(mode=mode, name=name, **kargs)
assert isinstance(model_name_pairs, list)
self.key = key
self.model_name_pairs = model_name_pairs

View File

@ -34,8 +34,8 @@ class DistillationModel(nn.Layer):
config (dict): the super parameters for module.
"""
super().__init__()
self.model_dict = dict()
index = 0
self.model_list = []
self.model_name_list = []
for key in config["Models"]:
model_config = config["Models"][key]
freeze_params = False
@ -46,15 +46,15 @@ class DistillationModel(nn.Layer):
pretrained = model_config.pop("pretrained")
model = BaseModel(model_config)
if pretrained is not None:
load_dygraph_pretrain(model, path=pretrained[index])
load_dygraph_pretrain(model, path=pretrained)
if freeze_params:
for param in model.parameters():
param.trainable = False
self.model_dict[key] = self.add_sublayer(key, model)
index += 1
self.model_list.append(self.add_sublayer(key, model))
self.model_name_list.append(key)
def forward(self, x):
result_dict = dict()
for key in self.model_dict:
result_dict[key] = self.model_dict[key](x)
for idx, model_name in enumerate(self.model_name_list):
result_dict[model_name] = self.model_list[idx](x)
return result_dict

View File

@ -17,7 +17,7 @@ import sys
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
sys.path.append(os.path.abspath(os.path.join(__dir__, '..')))
sys.path.append(os.path.abspath(os.path.join(__dir__, "..")))
import argparse
@ -31,32 +31,12 @@ from ppocr.utils.logging import get_logger
from tools.program import load_config, merge_config, ArgsParser
def main():
FLAGS = ArgsParser().parse_args()
config = load_config(FLAGS.config)
merge_config(FLAGS.opt)
logger = get_logger()
# build post process
post_process_class = build_post_process(config['PostProcess'],
config['Global'])
# build model
# for rec algorithm
if hasattr(post_process_class, 'character'):
char_num = len(getattr(post_process_class, 'character'))
config['Architecture']["Head"]['out_channels'] = char_num
model = build_model(config['Architecture'])
init_model(config, model, logger)
model.eval()
save_path = '{}/inference'.format(config['Global']['save_inference_dir'])
if config['Architecture']['algorithm'] == "SRN":
max_text_length = config['Architecture']['Head']['max_text_length']
def export_single_model(model, arch_config, save_path, logger):
if arch_config["algorithm"] == "SRN":
max_text_length = arch_config["Head"]["max_text_length"]
other_shape = [
paddle.static.InputSpec(
shape=[None, 1, 64, 256], dtype='float32'), [
shape=[None, 1, 64, 256], dtype="float32"), [
paddle.static.InputSpec(
shape=[None, 256, 1],
dtype="int64"), paddle.static.InputSpec(
@ -71,24 +51,66 @@ def main():
model = to_static(model, input_spec=other_shape)
else:
infer_shape = [3, -1, -1]
if config['Architecture']['model_type'] == "rec":
if arch_config["model_type"] == "rec":
infer_shape = [3, 32, -1] # for rec model, H must be 32
if 'Transform' in config['Architecture'] and config['Architecture'][
'Transform'] is not None and config['Architecture'][
'Transform']['name'] == 'TPS':
if "Transform" in arch_config and arch_config[
"Transform"] is not None and arch_config["Transform"][
"name"] == "TPS":
logger.info(
'When there is tps in the network, variable length input is not supported, and the input size needs to be the same as during training'
"When there is tps in the network, variable length input is not supported, and the input size needs to be the same as during training"
)
infer_shape[-1] = 100
model = to_static(
model,
input_spec=[
paddle.static.InputSpec(
shape=[None] + infer_shape, dtype='float32')
shape=[None] + infer_shape, dtype="float32")
])
paddle.jit.save(model, save_path)
logger.info('inference model is saved to {}'.format(save_path))
logger.info("inference model is saved to {}".format(save_path))
return
def main():
FLAGS = ArgsParser().parse_args()
config = load_config(FLAGS.config)
merge_config(FLAGS.opt)
logger = get_logger()
# build post process
post_process_class = build_post_process(config["PostProcess"],
config["Global"])
# build model
# for rec algorithm
if hasattr(post_process_class, "character"):
char_num = len(getattr(post_process_class, "character"))
if config["Architecture"]["algorithm"] in ["Distillation",
]: # distillation model
for key in config["Architecture"]["Models"]:
config["Architecture"]["Models"][key]["Head"][
"out_channels"] = char_num
else: # base rec model
config["Architecture"]["Head"]["out_channels"] = char_num
model = build_model(config["Architecture"])
init_model(config, model, logger)
model.eval()
save_path = config["Global"]["save_inference_dir"]
arch_config = config["Architecture"]
if arch_config["algorithm"] in ["Distillation", ]: # distillation model
archs = list(arch_config["Models"].values())
for idx, name in enumerate(model.model_name_list):
sub_model_save_path = os.path.join(save_path, name, "inference")
export_single_model(model.model_list[idx], archs[idx],
sub_model_save_path, logger)
else:
save_path = os.path.join(save_path, "inference")
export_single_model(model, arch_config, save_path, logger)
if __name__ == "__main__":