add params min_subgraph_size

This commit is contained in:
LDOUBLEV 2021-06-10 14:29:06 +08:00
parent 616ad6a179
commit 96a53fb4be
1 changed files with 3 additions and 1 deletions

View File

@ -37,6 +37,7 @@ def init_args():
parser.add_argument("--use_gpu", type=str2bool, default=True) parser.add_argument("--use_gpu", type=str2bool, default=True)
parser.add_argument("--ir_optim", type=str2bool, default=True) parser.add_argument("--ir_optim", type=str2bool, default=True)
parser.add_argument("--use_tensorrt", type=str2bool, default=False) parser.add_argument("--use_tensorrt", type=str2bool, default=False)
parser.add_argument("--min_subgraph_size", type=int, default=3)
parser.add_argument("--precision", type=str, default="fp32") parser.add_argument("--precision", type=str, default="fp32")
parser.add_argument("--gpu_mem", type=int, default=500) parser.add_argument("--gpu_mem", type=int, default=500)
@ -235,7 +236,8 @@ def create_predictor(args, mode, logger):
config.enable_tensorrt_engine( config.enable_tensorrt_engine(
precision_mode=inference.PrecisionType.Float32, precision_mode=inference.PrecisionType.Float32,
max_batch_size=args.max_batch_size, max_batch_size=args.max_batch_size,
min_subgraph_size=3) # skip the minmum trt subgraph min_subgraph_size=args.
min_subgraph_size) # skip the minmum trt subgraph
if mode == "det": if mode == "det":
min_input_shape = { min_input_shape = {
"x": [1, 3, 50, 50], "x": [1, 3, 50, 50],