add max_text_length to export model
This commit is contained in:
parent
38fc1fae63
commit
b566bcbbaa
|
@ -47,23 +47,25 @@ def main():
|
|||
char_num = len(getattr(post_process_class, 'character'))
|
||||
config['Architecture']["Head"]['out_channels'] = char_num
|
||||
model = build_model(config['Architecture'])
|
||||
init_model(config, model, logger)
|
||||
# init_model(config, model, logger)
|
||||
model.eval()
|
||||
|
||||
save_path = '{}/inference'.format(config['Global']['save_inference_dir'])
|
||||
|
||||
if config['Architecture']['algorithm'] == "SRN":
|
||||
max_text_length = config['Architecture']['Head']['max_text_length']
|
||||
other_shape = [
|
||||
paddle.static.InputSpec(
|
||||
shape=[None, 1, 64, 256], dtype='float32'), [
|
||||
paddle.static.InputSpec(
|
||||
shape=[None, 256, 1],
|
||||
dtype="int64"), paddle.static.InputSpec(
|
||||
shape=[None, 25, 1],
|
||||
dtype="int64"), paddle.static.InputSpec(
|
||||
shape=[None, 8, 25, 25], dtype="int64"),
|
||||
shape=[None, max_text_length, 1], dtype="int64"),
|
||||
paddle.static.InputSpec(
|
||||
shape=[None, 8, 25, 25], dtype="int64")
|
||||
shape=[None, 8, max_text_length, max_text_length],
|
||||
dtype="int64"), paddle.static.InputSpec(
|
||||
shape=[None, 8, max_text_length, max_text_length],
|
||||
dtype="int64")
|
||||
]
|
||||
]
|
||||
model = to_static(model, input_spec=other_shape)
|
||||
|
|
Loading…
Reference in New Issue