This commit is contained in:
WenmuZhou 2021-01-28 13:21:10 +08:00
parent b30add8ae5
commit 8e697e349f
1 changed files with 4 additions and 1 deletions

View File

@ -54,10 +54,13 @@ def main():
infer_shape = [3, -1, -1]
if config['Architecture']['model_type'] == "rec":
infer_shape = [3, 32, -1]
infer_shape = [3, 32, -1] # for rec model, H must be 32
if 'Transform' in config['Architecture'] and config['Architecture'][
'Transform'] is not None and config['Architecture'][
'Transform']['name'] == 'TPS':
logger.info(
'When there is tps in the network, variable length input is not supported, and the input size needs to be the same as during training'
)
infer_shape[-1] = 100
model = to_static(