add tensorrt args
This commit is contained in:
parent
ec37732512
commit
9039cca26d
|
@ -35,6 +35,7 @@ logger = get_logger()
|
|||
|
||||
class TextDetector(object):
|
||||
def __init__(self, args):
|
||||
self.args = args
|
||||
self.det_algorithm = args.det_algorithm
|
||||
self.use_zero_copy_run = args.use_zero_copy_run
|
||||
pre_process_list = [{
|
||||
|
|
|
@ -33,6 +33,8 @@ def parse_args():
|
|||
parser.add_argument("--use_gpu", type=str2bool, default=True)
|
||||
parser.add_argument("--ir_optim", type=str2bool, default=True)
|
||||
parser.add_argument("--use_tensorrt", type=str2bool, default=False)
|
||||
parser.add_argument("--use_fp16", type=str2bool, default=False)
|
||||
parser.add_argument("--max_batch_size", type=int, default=10)
|
||||
parser.add_argument("--gpu_mem", type=int, default=8000)
|
||||
|
||||
# params for text detector
|
||||
|
@ -46,7 +48,7 @@ def parse_args():
|
|||
parser.add_argument("--det_db_thresh", type=float, default=0.3)
|
||||
parser.add_argument("--det_db_box_thresh", type=float, default=0.5)
|
||||
parser.add_argument("--det_db_unclip_ratio", type=float, default=1.6)
|
||||
|
||||
parser.add_argument("--max_batch_size", type=int, default=10)
|
||||
# EAST parmas
|
||||
parser.add_argument("--det_east_score_thresh", type=float, default=0.8)
|
||||
parser.add_argument("--det_east_cover_thresh", type=float, default=0.1)
|
||||
|
@ -113,6 +115,11 @@ def create_predictor(args, mode, logger):
|
|||
|
||||
if args.use_gpu:
|
||||
config.enable_use_gpu(args.gpu_mem, 0)
|
||||
if args.use_tensorrt:
|
||||
config.enable_tensorrt_engine(
|
||||
precision_mode=AnalysisConfig.Precision.Half
|
||||
if args.use_fp16 else AnalysisConfig.Precision.Float32,
|
||||
max_batch_size=args.max_batch_size)
|
||||
else:
|
||||
config.disable_gpu()
|
||||
config.set_cpu_math_library_num_threads(6)
|
||||
|
|
Loading…
Reference in New Issue