add tensorrt args

This commit is contained in:
LDOUBLEV 2020-12-18 15:27:44 +08:00
parent ec37732512
commit 9039cca26d
2 changed files with 9 additions and 1 deletions

View File

@ -35,6 +35,7 @@ logger = get_logger()
class TextDetector(object):
def __init__(self, args):
self.args = args
self.det_algorithm = args.det_algorithm
self.use_zero_copy_run = args.use_zero_copy_run
pre_process_list = [{

View File

@ -33,6 +33,8 @@ def parse_args():
parser.add_argument("--use_gpu", type=str2bool, default=True)
parser.add_argument("--ir_optim", type=str2bool, default=True)
parser.add_argument("--use_tensorrt", type=str2bool, default=False)
parser.add_argument("--use_fp16", type=str2bool, default=False)
parser.add_argument("--max_batch_size", type=int, default=10)
parser.add_argument("--gpu_mem", type=int, default=8000)
# params for text detector
@ -46,7 +48,7 @@ def parse_args():
parser.add_argument("--det_db_thresh", type=float, default=0.3)
parser.add_argument("--det_db_box_thresh", type=float, default=0.5)
parser.add_argument("--det_db_unclip_ratio", type=float, default=1.6)
parser.add_argument("--max_batch_size", type=int, default=10)
# EAST parmas
parser.add_argument("--det_east_score_thresh", type=float, default=0.8)
parser.add_argument("--det_east_cover_thresh", type=float, default=0.1)
@ -113,6 +115,11 @@ def create_predictor(args, mode, logger):
if args.use_gpu:
config.enable_use_gpu(args.gpu_mem, 0)
if args.use_tensorrt:
config.enable_tensorrt_engine(
precision_mode=AnalysisConfig.Precision.Half
if args.use_fp16 else AnalysisConfig.Precision.Float32,
max_batch_size=args.max_batch_size)
else:
config.disable_gpu()
config.set_cpu_math_library_num_threads(6)