Merge pull request #2678 from WenmuZhou/fix_srn_post_process

add max_text_length to export model
This commit is contained in:
MissPenguin 2021-04-28 14:36:51 +08:00 committed by GitHub
commit 3b19311dc1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 6 additions and 4 deletions

View File

@ -53,17 +53,19 @@ def main():
save_path = '{}/inference'.format(config['Global']['save_inference_dir'])
if config['Architecture']['algorithm'] == "SRN":
max_text_length = config['Architecture']['Head']['max_text_length']
other_shape = [
paddle.static.InputSpec(
shape=[None, 1, 64, 256], dtype='float32'), [
paddle.static.InputSpec(
shape=[None, 256, 1],
dtype="int64"), paddle.static.InputSpec(
shape=[None, 25, 1],
dtype="int64"), paddle.static.InputSpec(
shape=[None, 8, 25, 25], dtype="int64"),
shape=[None, max_text_length, 1], dtype="int64"),
paddle.static.InputSpec(
shape=[None, 8, 25, 25], dtype="int64")
shape=[None, 8, max_text_length, max_text_length],
dtype="int64"), paddle.static.InputSpec(
shape=[None, 8, max_text_length, max_text_length],
dtype="int64")
]
]
model = to_static(model, input_spec=other_shape)