From b566bcbbaa432278658f15e05172e362a0c84692 Mon Sep 17 00:00:00 2001 From: WenmuZhou Date: Wed, 28 Apr 2021 13:36:16 +0800 Subject: [PATCH 1/2] add max_text_length to export model --- tools/export_model.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/tools/export_model.py b/tools/export_model.py index f587b2bb..37e7ba61 100755 --- a/tools/export_model.py +++ b/tools/export_model.py @@ -47,23 +47,25 @@ def main(): char_num = len(getattr(post_process_class, 'character')) config['Architecture']["Head"]['out_channels'] = char_num model = build_model(config['Architecture']) - init_model(config, model, logger) + # init_model(config, model, logger) model.eval() save_path = '{}/inference'.format(config['Global']['save_inference_dir']) if config['Architecture']['algorithm'] == "SRN": + max_text_length = config['Architecture']['Head']['max_text_length'] other_shape = [ paddle.static.InputSpec( shape=[None, 1, 64, 256], dtype='float32'), [ paddle.static.InputSpec( shape=[None, 256, 1], dtype="int64"), paddle.static.InputSpec( - shape=[None, 25, 1], - dtype="int64"), paddle.static.InputSpec( - shape=[None, 8, 25, 25], dtype="int64"), + shape=[None, max_text_length, 1], dtype="int64"), paddle.static.InputSpec( - shape=[None, 8, 25, 25], dtype="int64") + shape=[None, 8, max_text_length, max_text_length], + dtype="int64"), paddle.static.InputSpec( + shape=[None, 8, max_text_length, max_text_length], + dtype="int64") ] ] model = to_static(model, input_spec=other_shape) From d36c19ce1a1b7431df1b65ab95ee0c7f2296bcde Mon Sep 17 00:00:00 2001 From: WenmuZhou Date: Wed, 28 Apr 2021 13:41:04 +0800 Subject: [PATCH 2/2] add max_text_length to export model --- tools/export_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tools/export_model.py b/tools/export_model.py index 37e7ba61..bdff89f7 100755 --- a/tools/export_model.py +++ b/tools/export_model.py @@ -47,7 +47,7 @@ def main(): char_num = len(getattr(post_process_class, 'character')) config['Architecture']["Head"]['out_channels'] = char_num model = build_model(config['Architecture']) - # init_model(config, model, logger) + init_model(config, model, logger) model.eval() save_path = '{}/inference'.format(config['Global']['save_inference_dir'])