fix export model for distillation model
parent ab4db2acce
commit 6361a38ff5
@@ -82,7 +82,7 @@ class DistillationDistanceLoss(DistanceLoss):
                  key=None,
                  name="loss_distance",
                  **kargs):
-        super().__init__(mode=mode, name=name)
+        super().__init__(mode=mode, name=name, **kargs)
         assert isinstance(model_name_pairs, list)
         self.key = key
         self.model_name_pairs = model_name_pairs
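The only functional change in this hunk is that `**kargs` is now forwarded to the parent constructor, so extra keyword options no longer vanish at the subclass boundary. A minimal standalone sketch of the failure mode (toy classes, not the repo's `DistanceLoss`):

class Base:
    def __init__(self, mode="l2", name="loss", reduction="mean"):
        # "reduction" stands in for any extra option the base class accepts.
        self.mode, self.name, self.reduction = mode, name, reduction


class Child(Base):
    def __init__(self, mode="l2", name="loss_distance", **kargs):
        # The pre-fix code accepted **kargs in the signature but never
        # passed them on, so an option like reduction="sum" was silently
        # ignored. Forwarding **kargs restores it.
        super().__init__(mode=mode, name=name, **kargs)


print(Child(reduction="sum").reduction)  # prints "sum"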
@@ -34,8 +34,8 @@ class DistillationModel(nn.Layer):
             config (dict): the super parameters for module.
         """
         super().__init__()
-        self.model_dict = dict()
-        index = 0
+        self.model_list = []
+        self.model_name_list = []
         for key in config["Models"]:
             model_config = config["Models"][key]
             freeze_params = False
@@ -46,15 +46,15 @@ class DistillationModel(nn.Layer):
             pretrained = model_config.pop("pretrained")
             model = BaseModel(model_config)
             if pretrained is not None:
-                load_dygraph_pretrain(model, path=pretrained[index])
+                load_dygraph_pretrain(model, path=pretrained)
             if freeze_params:
                 for param in model.parameters():
                     param.trainable = False
-            self.model_dict[key] = self.add_sublayer(key, model)
-            index += 1
+            self.model_list.append(self.add_sublayer(key, model))
+            self.model_name_list.append(key)

     def forward(self, x):
         result_dict = dict()
-        for key in self.model_dict:
-            result_dict[key] = self.model_dict[key](x)
+        for idx, model_name in enumerate(self.model_name_list):
+            result_dict[model_name] = self.model_list[idx](x)
         return result_dict
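These two hunks switch `DistillationModel` from a `model_dict` keyed by name to the parallel `model_list` / `model_name_list`, and drop the running `index` into `pretrained`, so the exporter below can pair each sub-model with its config by position. A minimal runnable sketch of the same container pattern (toy `nn.Linear` layers standing in for `BaseModel`; not the repo code):

import paddle
from paddle import nn


class TinyEnsemble(nn.Layer):
    """Toy version of the list-based container used above."""

    def __init__(self, configs):
        super().__init__()
        self.model_list = []
        self.model_name_list = []
        for name, width in configs.items():  # e.g. {"Teacher": 8, "Student": 8}
            model = nn.Linear(4, width)  # stand-in for BaseModel(model_config)
            # add_sublayer registers the parameters and returns the layer,
            # so the list entry and the registered sublayer are one object.
            self.model_list.append(self.add_sublayer(name, model))
            self.model_name_list.append(name)

    def forward(self, x):
        # Same input through every sub-model, results keyed by name.
        return {
            name: self.model_list[idx](x)
            for idx, name in enumerate(self.model_name_list)
        }


ensemble = TinyEnsemble({"Teacher": 8, "Student": 8})
out = ensemble(paddle.randn([2, 4]))  # {"Teacher": Tensor, "Student": Tensor}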
@@ -17,7 +17,7 @@ import sys

 __dir__ = os.path.dirname(os.path.abspath(__file__))
 sys.path.append(__dir__)
-sys.path.append(os.path.abspath(os.path.join(__dir__, '..')))
+sys.path.append(os.path.abspath(os.path.join(__dir__, "..")))

 import argparse

@@ -31,32 +31,12 @@ from ppocr.utils.logging import get_logger
 from tools.program import load_config, merge_config, ArgsParser


-def main():
-    FLAGS = ArgsParser().parse_args()
-    config = load_config(FLAGS.config)
-    merge_config(FLAGS.opt)
-    logger = get_logger()
-    # build post process
-
-    post_process_class = build_post_process(config['PostProcess'],
-                                            config['Global'])
-
-    # build model
-    # for rec algorithm
-    if hasattr(post_process_class, 'character'):
-        char_num = len(getattr(post_process_class, 'character'))
-        config['Architecture']["Head"]['out_channels'] = char_num
-    model = build_model(config['Architecture'])
-    init_model(config, model, logger)
-    model.eval()
-
-    save_path = '{}/inference'.format(config['Global']['save_inference_dir'])
-
-    if config['Architecture']['algorithm'] == "SRN":
-        max_text_length = config['Architecture']['Head']['max_text_length']
+def export_single_model(model, arch_config, save_path, logger):
+    if arch_config["algorithm"] == "SRN":
+        max_text_length = arch_config["Head"]["max_text_length"]
         other_shape = [
             paddle.static.InputSpec(
-                shape=[None, 1, 64, 256], dtype='float32'), [
+                shape=[None, 1, 64, 256], dtype="float32"), [
                 paddle.static.InputSpec(
                     shape=[None, 256, 1],
                     dtype="int64"), paddle.static.InputSpec(
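This hunk lifts the export logic out of `main()` into `export_single_model(model, arch_config, save_path, logger)` so it can be invoked once per sub-model. A minimal sketch of the underlying `to_static` + `paddle.jit.save` pattern it wraps (toy layer and shapes, not PaddleOCR's models):

import paddle
from paddle import nn
from paddle.jit import to_static

# Toy layer standing in for the built OCR model.
model = nn.Conv2D(3, 8, kernel_size=3, padding=1)
model.eval()

# None batch dim, fixed channels/height, variable width -- mirrors the
# [None] + [3, 32, -1] InputSpec built above for rec models.
model = to_static(
    model,
    input_spec=[
        paddle.static.InputSpec(
            shape=[None, 3, 32, -1], dtype="float32")
    ])

# Writes inference.pdmodel / inference.pdiparams under ./export_sketch/.
paddle.jit.save(model, "./export_sketch/inference")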
@@ -71,24 +51,66 @@ def main():
         model = to_static(model, input_spec=other_shape)
     else:
         infer_shape = [3, -1, -1]
-        if config['Architecture']['model_type'] == "rec":
+        if arch_config["model_type"] == "rec":
             infer_shape = [3, 32, -1]  # for rec model, H must be 32
-            if 'Transform' in config['Architecture'] and config['Architecture'][
-                    'Transform'] is not None and config['Architecture'][
-                        'Transform']['name'] == 'TPS':
+            if "Transform" in arch_config and arch_config[
+                    "Transform"] is not None and arch_config["Transform"][
+                        "name"] == "TPS":
                 logger.info(
-                    'When there is tps in the network, variable length input is not supported, and the input size needs to be the same as during training'
+                    "When there is tps in the network, variable length input is not supported, and the input size needs to be the same as during training"
                 )
                 infer_shape[-1] = 100
         model = to_static(
             model,
             input_spec=[
                 paddle.static.InputSpec(
-                    shape=[None] + infer_shape, dtype='float32')
+                    shape=[None] + infer_shape, dtype="float32")
             ])

     paddle.jit.save(model, save_path)
-    logger.info('inference model is saved to {}'.format(save_path))
+    logger.info("inference model is saved to {}".format(save_path))
     return


+def main():
+    FLAGS = ArgsParser().parse_args()
+    config = load_config(FLAGS.config)
+    merge_config(FLAGS.opt)
+    logger = get_logger()
+    # build post process
+
+    post_process_class = build_post_process(config["PostProcess"],
+                                            config["Global"])
+
+    # build model
+    # for rec algorithm
+    if hasattr(post_process_class, "character"):
+        char_num = len(getattr(post_process_class, "character"))
+        if config["Architecture"]["algorithm"] in ["Distillation",
+                                                   ]:  # distillation model
+            for key in config["Architecture"]["Models"]:
+                config["Architecture"]["Models"][key]["Head"][
+                    "out_channels"] = char_num
+        else:  # base rec model
+            config["Architecture"]["Head"]["out_channels"] = char_num
+    model = build_model(config["Architecture"])
+    init_model(config, model, logger)
+    model.eval()
+
+    save_path = config["Global"]["save_inference_dir"]
+
+    arch_config = config["Architecture"]
+
+    if arch_config["algorithm"] in ["Distillation", ]:  # distillation model
+        archs = list(arch_config["Models"].values())
+        for idx, name in enumerate(model.model_name_list):
+            sub_model_save_path = os.path.join(save_path, name, "inference")
+            export_single_model(model.model_list[idx], archs[idx],
+                                sub_model_save_path, logger)
+    else:
+        save_path = os.path.join(save_path, "inference")
+        export_single_model(model, arch_config, save_path, logger)
+
+
 if __name__ == "__main__":
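After this change, a `Distillation` architecture is exported as one inference model per entry in `Architecture.Models`, each under a subdirectory named after the sub-model, while a plain architecture keeps the old single `inference` prefix. A hypothetical invocation and resulting layout (the config path and the Teacher/Student names are illustrative; the script path is assumed from the `tools.program` import, and `-c`/`-o` are the flags `ArgsParser` already provides):

python tools/export_model.py -c configs/rec/my_distillation.yml \
    -o Global.save_inference_dir=./inference_out

./inference_out/Teacher/inference.pdmodel
./inference_out/Teacher/inference.pdiparams
./inference_out/Student/inference.pdmodel
./inference_out/Student/inference.pdiparams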