Merge branch 'dygraph' into autolog

This commit is contained in:
Double_V 2021-06-28 20:34:37 +08:00 committed by GitHub
commit e4d49819e2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
13 changed files with 110 additions and 95 deletions

View File

@ -230,15 +230,8 @@ class GridGenerator(nn.Layer):
def build_inv_delta_C_paddle(self, C): def build_inv_delta_C_paddle(self, C):
""" Return inv_delta_C which is needed to calculate T """ """ Return inv_delta_C which is needed to calculate T """
F = self.F F = self.F
hat_C = paddle.zeros((F, F), dtype='float64') # F x F hat_eye = paddle.eye(F, dtype='float64') # F x F
for i in range(0, F): hat_C = paddle.norm(C.reshape([1, F, 2]) - C.reshape([F, 1, 2]), axis=2) + hat_eye
for j in range(i, F):
if i == j:
hat_C[i, j] = 1
else:
r = paddle.norm(C[i] - C[j])
hat_C[i, j] = r
hat_C[j, i] = r
hat_C = (hat_C**2) * paddle.log(hat_C) hat_C = (hat_C**2) * paddle.log(hat_C)
delta_C = paddle.concat( # F+3 x F+3 delta_C = paddle.concat( # F+3 x F+3
[ [

View File

@ -30,22 +30,32 @@ Types 1-4 follow the traditional OCR process, and 5 follow the Table OCR process
[doc](table/README.md) [doc](table/README.md)
## 4. PaddleStructure whl package introduction ## 4. Predictive by inference engine
### 4.1 Use Use the following commands to complete the inference
4.1.1 Use by code
```python ```python
python3 table/predict_system.py --det_model_dir=path/to/det_model_dir --rec_model_dir=path/to/rec_model_dir --table_model_dir=path/to/table_model_dir --image_dir=../doc/table/1.png --rec_char_dict_path=../ppocr/utils/dict/table_dict.txt --table_char_dict_path=../ppocr/utils/dict/table_structure_dict.txt --rec_char_type=EN --det_limit_side_len=736 --det_limit_type=min --output ../output/table
```
After running, each image will have a directory with the same name under the directory specified in the output field. Each table in the picture will be stored as an excel, and the excel file name will be the coordinates of the table in the image.
## 5. PaddleStructure whl package introduction
### 5.1 Use
5.1.1 Use by code
```python
import os
import cv2 import cv2
from paddlestructure import PaddleStructure,draw_result from paddlestructure import PaddleStructure,draw_result,save_res
table_engine = PaddleStructure( table_engine = PaddleStructure(show_log=True)
output='./output/table',
show_log=True)
save_folder = './output/table'
img_path = '../doc/table/1.png' img_path = '../doc/table/1.png'
img = cv2.imread(img_path) img = cv2.imread(img_path)
result = table_engine(img) result = table_engine(img)
save_res(result, save_folder,os.path.basename(img_path).split('.')[0])
for line in result: for line in result:
print(line) print(line)
@ -58,19 +68,19 @@ im_show = Image.fromarray(im_show)
im_show.save('result.jpg') im_show.save('result.jpg')
``` ```
4.1.2 Use by command line 5.1.2 Use by command line
```bash ```bash
paddlestructure --image_dir=../doc/table/1.png paddlestructure --image_dir=../doc/table/1.png
``` ```
### 参数说明 ### Parameter Description
大部分参数和paddleocr whl包保持一致见 [whl包文档](../doc/doc_ch/whl.md) Most of the parameters are consistent with the paddleocr whl package, see [whl package documentation](../doc/doc_ch/whl.md)
| 字段 | 说明 | 默认值 | | Parameter | Description | Default |
|------------------------|------------------------------------------------------|------------------| |------------------------|------------------------------------------------------|------------------|
| output | excel和识别结果保存的地址 | ./output/table | | output | The path where excel and recognition results are saved | ./output/table |
| structure_max_len | structure模型预测时图像的长边resize尺度 | 488 | | structure_max_len | When the table structure model predicts, the long side of the image is resized | 488 |
| structure_model_dir | structure inference 模型地址 | None | | structure_model_dir | Table structure inference model path | None |
| structure_char_type | structure 模型所用字典地址 | ../ppocr/utils/dict/table_structure_dict.tx | | structure_char_type | Dictionary path used by table structure model | ../ppocr/utils/dict/table_structure_dict.tx |

View File

@ -30,22 +30,32 @@ PaddleStructure 是一个用于复杂板式文字OCR的工具包流程如下
[文档](table/README_ch.md) [文档](table/README_ch.md)
## 4. PaddleStructure whl包介绍 ## 4. 预测引擎推理
### 4.1 使用 使用如下命令即可完成预测引擎的推理
4.1.1 代码使用
```python ```python
python3 table/predict_system.py --det_model_dir=path/to/det_model_dir --rec_model_dir=path/to/rec_model_dir --table_model_dir=path/to/table_model_dir --image_dir=../doc/table/1.png --rec_char_dict_path=../ppocr/utils/dict/table_dict.txt --table_char_dict_path=../ppocr/utils/dict/table_structure_dict.txt --rec_char_type=EN --det_limit_side_len=736 --det_limit_type=min --output ../output/table
```
运行完成后每张图片会output字段指定的目录下有一个同名目录图片里的每个表格会存储为一个excelexcel文件名为表格在图片里的坐标。
## 5. PaddleStructure whl包介绍
### 5.1 使用
5.1.1 代码使用
```python
import os
import cv2 import cv2
from paddlestructure import PaddleStructure,draw_result from paddlestructure import PaddleStructure,draw_result,save_res
table_engine = PaddleStructure( table_engine = PaddleStructure(show_log=True)
output='./output/table',
show_log=True)
save_folder = './output/table'
img_path = '../doc/table/1.png' img_path = '../doc/table/1.png'
img = cv2.imread(img_path) img = cv2.imread(img_path)
result = table_engine(img) result = table_engine(img)
save_res(result, save_folder,os.path.basename(img_path).split('.')[0])
for line in result: for line in result:
print(line) print(line)
@ -58,7 +68,7 @@ im_show = Image.fromarray(im_show)
im_show.save('result.jpg') im_show.save('result.jpg')
``` ```
4.1.2 命令行使用 5.1.2 命令行使用
```bash ```bash
paddlestructure --image_dir=../doc/table/1.png paddlestructure --image_dir=../doc/table/1.png
``` ```
@ -69,8 +79,8 @@ paddlestructure --image_dir=../doc/table/1.png
| 字段 | 说明 | 默认值 | | 字段 | 说明 | 默认值 |
|------------------------|------------------------------------------------------|------------------| |------------------------|------------------------------------------------------|------------------|
| output | excel和识别结果保存的地址 | ./output/table | | output | excel和识别结果保存的地址 | ./output/table |
| structure_max_len | structure模型预测时图像的长边resize尺度 | 488 | | table_max_len | 表格结构模型预测时图像的长边resize尺度 | 488 |
| structure_model_dir | structure inference 模型地址 | None | | table_model_dir | 表格结构模型 inference 模型地址 | None |
| structure_char_type | structure 模型所用字典地址 | ../ppocr/utils/dict/table_structure_dict.tx | | table_char_type | 表格结构模型所用字典地址 | ../ppocr/utils/dict/table_structure_dict.tx |

View File

@ -32,7 +32,7 @@ logger = get_logger()
from ppocr.utils.utility import check_and_read_gif, get_image_file_list from ppocr.utils.utility import check_and_read_gif, get_image_file_list
from ppocr.utils.network import maybe_download, download_with_progressbar, confirm_model_dir_url, is_link from ppocr.utils.network import maybe_download, download_with_progressbar, confirm_model_dir_url, is_link
__all__ = ['PaddleStructure', 'draw_result', 'to_excel'] __all__ = ['PaddleStructure', 'draw_result', 'save_res']
VERSION = '2.1' VERSION = '2.1'
BASE_DIR = os.path.expanduser("~/.paddlestructure/") BASE_DIR = os.path.expanduser("~/.paddlestructure/")
@ -40,7 +40,7 @@ BASE_DIR = os.path.expanduser("~/.paddlestructure/")
model_urls = { model_urls = {
'det': 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_det_infer.tar', 'det': 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_det_infer.tar',
'rec': 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_rec_infer.tar', 'rec': 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_rec_infer.tar',
'structure': 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_structure_infer.tar' 'table': 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_structure_infer.tar'
} }
@ -51,7 +51,7 @@ def parse_args(mMain=True):
parser.add_help = mMain parser.add_help = mMain
for action in parser._actions: for action in parser._actions:
if action.dest in ['rec_char_dict_path', 'structure_char_dict_path']: if action.dest in ['rec_char_dict_path', 'table_char_dict_path']:
action.default = None action.default = None
if mMain: if mMain:
return parser.parse_args() return parser.parse_args()
@ -76,13 +76,13 @@ class PaddleStructure(OCRSystem):
params.rec_model_dir, rec_url = confirm_model_dir_url(params.rec_model_dir, params.rec_model_dir, rec_url = confirm_model_dir_url(params.rec_model_dir,
os.path.join(BASE_DIR, VERSION, 'rec'), os.path.join(BASE_DIR, VERSION, 'rec'),
model_urls['rec']) model_urls['rec'])
params.structure_model_dir, structure_url = confirm_model_dir_url(params.structure_model_dir, params.table_model_dir, table_url = confirm_model_dir_url(params.table_model_dir,
os.path.join(BASE_DIR, VERSION, 'structure'), os.path.join(BASE_DIR, VERSION, 'table'),
model_urls['structure']) model_urls['table'])
# download model # download model
maybe_download(params.det_model_dir, det_url) maybe_download(params.det_model_dir, det_url)
maybe_download(params.rec_model_dir, rec_url) maybe_download(params.rec_model_dir, rec_url)
maybe_download(params.structure_model_dir, structure_url) maybe_download(params.table_model_dir, table_url)
if params.rec_char_dict_path is None: if params.rec_char_dict_path is None:
params.rec_char_type = 'EN' params.rec_char_type = 'EN'
@ -90,12 +90,12 @@ class PaddleStructure(OCRSystem):
params.rec_char_dict_path = str(Path(__file__).parent / 'ppocr/utils/dict/table_dict.txt') params.rec_char_dict_path = str(Path(__file__).parent / 'ppocr/utils/dict/table_dict.txt')
else: else:
params.rec_char_dict_path = str(Path(__file__).parent.parent / 'ppocr/utils/dict/table_dict.txt') params.rec_char_dict_path = str(Path(__file__).parent.parent / 'ppocr/utils/dict/table_dict.txt')
if params.structure_char_dict_path is None: if params.table_char_dict_path is None:
if os.path.exists(str(Path(__file__).parent / 'ppocr/utils/dict/table_structure_dict.txt')): if os.path.exists(str(Path(__file__).parent / 'ppocr/utils/dict/table_structure_dict.txt')):
params.structure_char_dict_path = str( params.table_char_dict_path = str(
Path(__file__).parent / 'ppocr/utils/dict/table_structure_dict.txt') Path(__file__).parent / 'ppocr/utils/dict/table_structure_dict.txt')
else: else:
params.structure_char_dict_path = str( params.table_char_dict_path = str(
Path(__file__).parent.parent / 'ppocr/utils/dict/table_structure_dict.txt') Path(__file__).parent.parent / 'ppocr/utils/dict/table_structure_dict.txt')
print(params) print(params)
@ -145,4 +145,24 @@ def main():
for item in result: for item in result:
logger.info(item['res']) logger.info(item['res'])
save_res(result, save_folder, img_name) save_res(result, save_folder, img_name)
logger.info('result save to {}'.format(os.path.join(save_folder, img_name))) logger.info('result save to {}'.format(os.path.join(save_folder, img_name)))
if __name__ == '__main__':
table_engine = PaddleStructure(show_log=True)
img_path = '../test/test_imgs/PMC1173095_006_00.png'
img = cv2.imread(img_path)
result = table_engine(img)
save_res(result, '/Users/zhoujun20/Desktop/工作相关/table/table_pr/PaddleOCR/output/table',
os.path.basename(img_path).split('.')[0])
for line in result:
print(line)
from PIL import Image
font_path = '../doc/fonts/simfang.ttf'
image = Image.open(img_path).convert('RGB')
im_show = draw_result(image, result, font_path=font_path)
im_show = Image.fromarray(im_show)
im_show.save('result.jpg')

View File

@ -36,7 +36,7 @@ In gt json, the key is the image name, the value is the corresponding gt, and gt
Use the following command to evaluate. After the evaluation is completed, the teds indicator will be output. Use the following command to evaluate. After the evaluation is completed, the teds indicator will be output.
```python ```python
python3 table/eval_table.py --det_model_dir=path/to/det_model_dir --rec_model_dir=path/to/rec_model_dir --structure_model_dir=path/to/structure_model_dir --image_dir=../doc/table/1.png --rec_char_dict_path=../ppocr/utils/dict/table_dict.txt --structure_char_dict_path=../ppocr/utils/dict/table_structure_dict.txt --rec_char_type=EN --det_limit_side_len=736 --det_limit_type=min --gt_path=path/to/gt.json python3 table/eval_table.py --det_model_dir=path/to/det_model_dir --rec_model_dir=path/to/rec_model_dir --table_model_dir=path/to/table_model_dir --image_dir=../doc/table/1.png --rec_char_dict_path=../ppocr/utils/dict/table_dict.txt --table_char_dict_path=../ppocr/utils/dict/table_structure_dict.txt --rec_char_type=EN --det_limit_side_len=736 --det_limit_type=min --gt_path=path/to/gt.json
``` ```
@ -44,6 +44,6 @@ python3 table/eval_table.py --det_model_dir=path/to/det_model_dir --rec_model_di
First cd to the PaddleOCR/ppstructure directory First cd to the PaddleOCR/ppstructure directory
```python ```python
python3 table/predict_table.py --det_model_dir=path/to/det_model_dir --rec_model_dir=path/to/rec_model_dir --structure_model_dir=path/to/structure_model_dir --image_dir=../doc/table/1.png --rec_char_dict_path=../ppocr/utils/dict/table_dict.txt --structure_char_dict_path=../ppocr/utils/dict/table_structure_dict.txt --rec_char_type=EN --det_limit_side_len=736 --det_limit_type=min --output ../output/table python3 table/predict_table.py --det_model_dir=path/to/det_model_dir --rec_model_dir=path/to/rec_model_dir --table_model_dir=path/to/table_model_dir --image_dir=../doc/table/1.png --rec_char_dict_path=../ppocr/utils/dict/table_dict.txt --table_char_dict_path=../ppocr/utils/dict/table_structure_dict.txt --rec_char_type=EN --det_limit_side_len=736 --det_limit_type=min --output ../output/table
``` ```
After running, the excel sheet of each picture will be saved in the directory specified by the table_output field After running, the excel sheet of each picture will be saved in the directory specified by the output field

View File

@ -36,7 +36,7 @@ json 中key为图片名value为对于的gtgt是一个由四个item组
准备完成后使用如下命令进行评估评估完成后会输出teds指标。 准备完成后使用如下命令进行评估评估完成后会输出teds指标。
```python ```python
python3 table/eval_table.py --det_model_dir=path/to/det_model_dir --rec_model_dir=path/to/rec_model_dir --structure_model_dir=path/to/structure_model_dir --image_dir=../doc/table/1.png --rec_char_dict_path=../ppocr/utils/dict/table_dict.txt --structure_char_dict_path=../ppocr/utils/dict/table_structure_dict.txt --rec_char_type=EN --det_limit_side_len=736 --det_limit_type=min --gt_path=path/to/gt.json python3 table/eval_table.py --det_model_dir=path/to/det_model_dir --rec_model_dir=path/to/rec_model_dir --table_model_dir=path/to/table_model_dir --image_dir=../doc/table/1.png --rec_char_dict_path=../ppocr/utils/dict/table_dict.txt --table_char_dict_path=../ppocr/utils/dict/table_structure_dict.txt --rec_char_type=EN --det_limit_side_len=736 --det_limit_type=min --gt_path=path/to/gt.json
``` ```
@ -44,6 +44,6 @@ python3 table/eval_table.py --det_model_dir=path/to/det_model_dir --rec_model_di
先cd到PaddleOCR/ppstructure目录下 先cd到PaddleOCR/ppstructure目录下
```python ```python
python3 table/predict_table.py --det_model_dir=path/to/det_model_dir --rec_model_dir=path/to/rec_model_dir --structure_model_dir=path/to/structure_model_dir --image_dir=../doc/table/1.png --rec_char_dict_path=../ppocr/utils/dict/table_dict.txt --structure_char_dict_path=../ppocr/utils/dict/table_structure_dict.txt --rec_char_type=EN --det_limit_side_len=736 --det_limit_type=min --output ../output/table python3 table/predict_table.py --det_model_dir=path/to/det_model_dir --rec_model_dir=path/to/rec_model_dir --table_model_dir=path/to/table_model_dir --image_dir=../doc/table/1.png --rec_char_dict_path=../ppocr/utils/dict/table_dict.txt --table_char_dict_path=../ppocr/utils/dict/table_structure_dict.txt --rec_char_type=EN --det_limit_side_len=736 --det_limit_type=min --output ../output/table
``` ```
运行完成后每张图片的excel表格会保存到table_output字段指定的目录下 运行完成后每张图片的excel表格会保存到output字段指定的目录下

View File

@ -41,7 +41,7 @@ class TableStructurer(object):
def __init__(self, args): def __init__(self, args):
pre_process_list = [{ pre_process_list = [{
'ResizeTableImage': { 'ResizeTableImage': {
'max_len': args.structure_max_len 'max_len': args.table_max_len
} }
}, { }, {
'NormalizeImage': { 'NormalizeImage': {
@ -61,14 +61,14 @@ class TableStructurer(object):
}] }]
postprocess_params = { postprocess_params = {
'name': 'TableLabelDecode', 'name': 'TableLabelDecode',
"character_type": args.structure_char_type, "character_type": args.table_char_type,
"character_dict_path": args.structure_char_dict_path, "character_dict_path": args.table_char_dict_path,
} }
self.preprocess_op = create_operators(pre_process_list) self.preprocess_op = create_operators(pre_process_list)
self.postprocess_op = build_post_process(postprocess_params) self.postprocess_op = build_post_process(postprocess_params)
self.predictor, self.input_tensor, self.output_tensors, self.config = \ self.predictor, self.input_tensor, self.output_tensors, self.config = \
utility.create_predictor(args, 'structure', logger) utility.create_predictor(args, 'table', logger)
def __call__(self, img): def __call__(self, img):
ori_im = img.copy() ori_im = img.copy()

View File

@ -23,10 +23,10 @@ def init_args():
# params for output # params for output
parser.add_argument("--output", type=str, default='./output/table') parser.add_argument("--output", type=str, default='./output/table')
# params for table structure # params for table structure
parser.add_argument("--structure_max_len", type=int, default=488) parser.add_argument("--table_max_len", type=int, default=488)
parser.add_argument("--structure_model_dir", type=str) parser.add_argument("--table_model_dir", type=str)
parser.add_argument("--structure_char_type", type=str, default='en') parser.add_argument("--table_char_type", type=str, default='en')
parser.add_argument("--structure_char_dict_path", type=str, default="../ppocr/utils/dict/table_structure_dict.txt") parser.add_argument("--table_char_dict_path", type=str, default="../ppocr/utils/dict/table_structure_dict.txt")
return parser return parser

View File

@ -257,7 +257,8 @@ if __name__ == "__main__":
img_name_pure = os.path.split(image_file)[-1] img_name_pure = os.path.split(image_file)[-1]
img_path = os.path.join(draw_img_save, img_path = os.path.join(draw_img_save,
"det_res_{}".format(img_name_pure)) "det_res_{}".format(img_name_pure))
cv2.imwrite(img_path, src_im)
logger.info("The visualized image saved in {}".format(img_path)) logger.info("The visualized image saved in {}".format(img_path))
text_detector.autolog.report() text_detector.autolog.report()

View File

@ -322,7 +322,8 @@ def main(args):
'total_time_s': rec_time_dict['total_time'] 'total_time_s': rec_time_dict['total_time']
} }
benchmark_log = benchmark_utils.PaddleInferBenchmark( benchmark_log = benchmark_utils.PaddleInferBenchmark(
text_recognizer.config, model_info, data_info, perf_info, mems) text_recognizer.config, model_info, data_info, perf_info, mems,
args.save_log_path)
benchmark_log("Rec") benchmark_log("Rec")

View File

@ -37,6 +37,7 @@ def init_args():
parser.add_argument("--use_gpu", type=str2bool, default=True) parser.add_argument("--use_gpu", type=str2bool, default=True)
parser.add_argument("--ir_optim", type=str2bool, default=True) parser.add_argument("--ir_optim", type=str2bool, default=True)
parser.add_argument("--use_tensorrt", type=str2bool, default=False) parser.add_argument("--use_tensorrt", type=str2bool, default=False)
parser.add_argument("--min_subgraph_size", type=int, default=3)
parser.add_argument("--precision", type=str, default="fp32") parser.add_argument("--precision", type=str, default="fp32")
parser.add_argument("--gpu_mem", type=int, default=500) parser.add_argument("--gpu_mem", type=int, default=500)
@ -201,8 +202,8 @@ def create_predictor(args, mode, logger):
model_dir = args.cls_model_dir model_dir = args.cls_model_dir
elif mode == 'rec': elif mode == 'rec':
model_dir = args.rec_model_dir model_dir = args.rec_model_dir
elif mode == 'structure': elif mode == 'table':
model_dir = args.structure_model_dir model_dir = args.table_model_dir
else: else:
model_dir = args.e2e_model_dir model_dir = args.e2e_model_dir
@ -236,12 +237,14 @@ def create_predictor(args, mode, logger):
config.enable_tensorrt_engine( config.enable_tensorrt_engine(
precision_mode=inference.PrecisionType.Float32, precision_mode=inference.PrecisionType.Float32,
max_batch_size=args.max_batch_size, max_batch_size=args.max_batch_size,
min_subgraph_size=3) # skip the minmum trt subgraph min_subgraph_size=args.min_subgraph_size)
if mode == "det" and "mobile" in model_file_path: # skip the minmum trt subgraph
if mode == "det":
min_input_shape = { min_input_shape = {
"x": [1, 3, 50, 50], "x": [1, 3, 50, 50],
"conv2d_92.tmp_0": [1, 96, 20, 20], "conv2d_92.tmp_0": [1, 96, 20, 20],
"conv2d_91.tmp_0": [1, 96, 10, 10], "conv2d_91.tmp_0": [1, 96, 10, 10],
"conv2d_59.tmp_0": [1, 96, 20, 20],
"nearest_interp_v2_1.tmp_0": [1, 96, 10, 10], "nearest_interp_v2_1.tmp_0": [1, 96, 10, 10],
"nearest_interp_v2_2.tmp_0": [1, 96, 20, 20], "nearest_interp_v2_2.tmp_0": [1, 96, 20, 20],
"nearest_interp_v2_3.tmp_0": [1, 24, 20, 20], "nearest_interp_v2_3.tmp_0": [1, 24, 20, 20],
@ -254,6 +257,7 @@ def create_predictor(args, mode, logger):
"x": [1, 3, 2000, 2000], "x": [1, 3, 2000, 2000],
"conv2d_92.tmp_0": [1, 96, 400, 400], "conv2d_92.tmp_0": [1, 96, 400, 400],
"conv2d_91.tmp_0": [1, 96, 200, 200], "conv2d_91.tmp_0": [1, 96, 200, 200],
"conv2d_59.tmp_0": [1, 96, 400, 400],
"nearest_interp_v2_1.tmp_0": [1, 96, 200, 200], "nearest_interp_v2_1.tmp_0": [1, 96, 200, 200],
"nearest_interp_v2_2.tmp_0": [1, 96, 400, 400], "nearest_interp_v2_2.tmp_0": [1, 96, 400, 400],
"nearest_interp_v2_3.tmp_0": [1, 24, 400, 400], "nearest_interp_v2_3.tmp_0": [1, 24, 400, 400],
@ -266,6 +270,7 @@ def create_predictor(args, mode, logger):
"x": [1, 3, 640, 640], "x": [1, 3, 640, 640],
"conv2d_92.tmp_0": [1, 96, 160, 160], "conv2d_92.tmp_0": [1, 96, 160, 160],
"conv2d_91.tmp_0": [1, 96, 80, 80], "conv2d_91.tmp_0": [1, 96, 80, 80],
"conv2d_59.tmp_0": [1, 96, 160, 160],
"nearest_interp_v2_1.tmp_0": [1, 96, 80, 80], "nearest_interp_v2_1.tmp_0": [1, 96, 80, 80],
"nearest_interp_v2_2.tmp_0": [1, 96, 160, 160], "nearest_interp_v2_2.tmp_0": [1, 96, 160, 160],
"nearest_interp_v2_3.tmp_0": [1, 24, 160, 160], "nearest_interp_v2_3.tmp_0": [1, 24, 160, 160],
@ -274,31 +279,6 @@ def create_predictor(args, mode, logger):
"elementwise_add_7": [1, 56, 40, 40], "elementwise_add_7": [1, 56, 40, 40],
"nearest_interp_v2_0.tmp_0": [1, 96, 40, 40] "nearest_interp_v2_0.tmp_0": [1, 96, 40, 40]
} }
if mode == "det" and "server" in model_file_path:
min_input_shape = {
"x": [1, 3, 50, 50],
"conv2d_59.tmp_0": [1, 96, 20, 20],
"nearest_interp_v2_2.tmp_0": [1, 96, 20, 20],
"nearest_interp_v2_3.tmp_0": [1, 24, 20, 20],
"nearest_interp_v2_4.tmp_0": [1, 24, 20, 20],
"nearest_interp_v2_5.tmp_0": [1, 24, 20, 20]
}
max_input_shape = {
"x": [1, 3, 2000, 2000],
"conv2d_59.tmp_0": [1, 96, 400, 400],
"nearest_interp_v2_2.tmp_0": [1, 96, 400, 400],
"nearest_interp_v2_3.tmp_0": [1, 24, 400, 400],
"nearest_interp_v2_4.tmp_0": [1, 24, 400, 400],
"nearest_interp_v2_5.tmp_0": [1, 24, 400, 400]
}
opt_input_shape = {
"x": [1, 3, 640, 640],
"conv2d_59.tmp_0": [1, 96, 160, 160],
"nearest_interp_v2_2.tmp_0": [1, 96, 160, 160],
"nearest_interp_v2_3.tmp_0": [1, 24, 160, 160],
"nearest_interp_v2_4.tmp_0": [1, 24, 160, 160],
"nearest_interp_v2_5.tmp_0": [1, 24, 160, 160]
}
elif mode == "rec": elif mode == "rec":
min_input_shape = {"x": [args.rec_batch_num, 3, 32, 10]} min_input_shape = {"x": [args.rec_batch_num, 3, 32, 10]}
max_input_shape = {"x": [args.rec_batch_num, 3, 32, 2000]} max_input_shape = {"x": [args.rec_batch_num, 3, 32, 2000]}
@ -331,7 +311,7 @@ def create_predictor(args, mode, logger):
config.disable_glog_info() config.disable_glog_info()
config.delete_pass("conv_transpose_eltwiseadd_bn_fuse_pass") config.delete_pass("conv_transpose_eltwiseadd_bn_fuse_pass")
if mode == 'structure': if mode == 'table':
config.delete_pass("fc_fuse_pass") # not supported for table config.delete_pass("fc_fuse_pass") # not supported for table
config.switch_use_feed_fetch_ops(False) config.switch_use_feed_fetch_ops(False)
config.switch_ir_optim(True) config.switch_ir_optim(True)

View File

@ -112,4 +112,4 @@ def main():
if __name__ == '__main__': if __name__ == '__main__':
config, device, logger, vdl_writer = program.preprocess() config, device, logger, vdl_writer = program.preprocess()
main() main()