Merge branch 'dygraph' of https://github.com/PaddlePaddle/PaddleOCR into test
This commit is contained in:
commit 7c6309db7a
@@ -1,7 +1,7 @@
-include LICENSE.txt
+include LICENSE
 include README.md
 
-recursive-include ppocr/utils *.txt utility.py logging.py
+recursive-include ppocr/utils *.txt utility.py logging.py network.py
 recursive-include ppocr/data/ *.py
 recursive-include ppocr/postprocess *.py
 recursive-include tools/infer *.py
@@ -465,8 +465,12 @@ public class MainActivity extends AppCompatActivity {
     }
 
     public void btn_load_model_click(View view) {
-        tvStatus.setText("STATUS: load model ......");
-        loadModel();
+        if (predictor.isLoaded()){
+            tvStatus.setText("STATUS: model has been loaded");
+        }else{
+            tvStatus.setText("STATUS: load model ......");
+            loadModel();
+        }
     }
 
     public void btn_run_model_click(View view) {
@@ -194,26 +194,25 @@ public class Predictor {
                     "supported!");
             return false;
         }
         int[] channelStride = new int[]{width * height, width * height * 2};
-        int p = scaleImage.getPixel(scaleImage.getWidth() - 1, scaleImage.getHeight() - 1);
-        for (int y = 0; y < height; y++) {
-            for (int x = 0; x < width; x++) {
-                int color = scaleImage.getPixel(x, y);
-                float[] rgb = new float[]{(float) red(color) / 255.0f, (float) green(color) / 255.0f,
-                        (float) blue(color) / 255.0f};
-                inputData[y * width + x] = (rgb[channelIdx[0]] - inputMean[0]) / inputStd[0];
-                inputData[y * width + x + channelStride[0]] = (rgb[channelIdx[1]] - inputMean[1]) / inputStd[1];
-                inputData[y * width + x + channelStride[1]] = (rgb[channelIdx[2]] - inputMean[2]) / inputStd[2];
-            }
+        int[] pixels = new int[width * height];
+        scaleImage.getPixels(pixels, 0, scaleImage.getWidth(), 0, 0, scaleImage.getWidth(), scaleImage.getHeight());
+        for (int i = 0; i < pixels.length; i++) {
+            int color = pixels[i];
+            float[] rgb = new float[]{(float) red(color) / 255.0f, (float) green(color) / 255.0f,
+                    (float) blue(color) / 255.0f};
+            inputData[i] = (rgb[channelIdx[0]] - inputMean[0]) / inputStd[0];
+            inputData[i + channelStride[0]] = (rgb[channelIdx[1]] - inputMean[1]) / inputStd[1];
+            inputData[i + channelStride[1]] = (rgb[channelIdx[2]] - inputMean[2]) / inputStd[2];
         }
     } else if (channels == 1) {
-        for (int y = 0; y < height; y++) {
-            for (int x = 0; x < width; x++) {
-                int color = inputImage.getPixel(x, y);
-                float gray = (float) (red(color) + green(color) + blue(color)) / 3.0f / 255.0f;
-                inputData[y * width + x] = (gray - inputMean[0]) / inputStd[0];
-            }
+        int[] pixels = new int[width * height];
+        scaleImage.getPixels(pixels, 0, scaleImage.getWidth(), 0, 0, scaleImage.getWidth(), scaleImage.getHeight());
+        for (int i = 0; i < pixels.length; i++) {
+            int color = pixels[i];
+            float gray = (float) (red(color) + green(color) + blue(color)) / 3.0f / 255.0f;
+            inputData[i] = (gray - inputMean[0]) / inputStd[0];
         }
     } else {
         Log.i(TAG, "Unsupported channel size " + Integer.toString(channels) + ", only channel 1 and 3 is " +
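The change above replaces per-pixel `getPixel` calls (one JNI call per pixel) with a single bulk `getPixels` read followed by a flat loop, while keeping the CHW normalization math identical. A minimal NumPy sketch of the same layout and arithmetic (the mean/std values are illustrative, not the app's configured ones):

```python
import numpy as np

# pixels: flattened ARGB ints, as Bitmap.getPixels would return (sample values)
pixels = np.array([0xFF336699, 0xFF112233], dtype=np.uint32)
rgb = np.stack([(pixels >> 16) & 0xFF, (pixels >> 8) & 0xFF, pixels & 0xFF], axis=1) / 255.0
input_mean = np.array([0.485, 0.456, 0.406])  # illustrative
input_std = np.array([0.229, 0.224, 0.225])   # illustrative
# CHW layout: channel c of pixel i lands at i + c * (width * height)
input_data = ((rgb - input_mean) / input_std).T.reshape(-1)
```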
BIN doc/joinus.PNG
Binary file not shown. Before: 102 KiB | After: 78 KiB
Binary file not shown. After: 263 KiB
paddleocr.py
@@ -19,17 +19,16 @@ __dir__ = os.path.dirname(__file__)
 sys.path.append(os.path.join(__dir__, ''))
 
 import cv2
+import logging
 import numpy as np
 from pathlib import Path
-import tarfile
-import requests
-from tqdm import tqdm
 
 from tools.infer import predict_system
 from ppocr.utils.logging import get_logger
 
 logger = get_logger()
 from ppocr.utils.utility import check_and_read_gif, get_image_file_list
+from ppocr.utils.network import maybe_download, download_with_progressbar
 from tools.infer.utility import draw_ocr, init_args, str2bool
 
 __all__ = ['PaddleOCR']
@@ -37,84 +36,84 @@ __all__ = ['PaddleOCR']
 model_urls = {
     'det': {
         'ch':
         'https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_det_infer.tar',
         'en':
         'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/en_ppocr_mobile_v2.0_det_infer.tar'
     },
     'rec': {
         'ch': {
             'url':
             'https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_rec_infer.tar',
             'dict_path': './ppocr/utils/ppocr_keys_v1.txt'
         },
         'en': {
             'url':
             'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/en_number_mobile_v2.0_rec_infer.tar',
             'dict_path': './ppocr/utils/en_dict.txt'
         },
         'french': {
             'url':
             'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/french_mobile_v2.0_rec_infer.tar',
             'dict_path': './ppocr/utils/dict/french_dict.txt'
         },
         'german': {
             'url':
             'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/german_mobile_v2.0_rec_infer.tar',
             'dict_path': './ppocr/utils/dict/german_dict.txt'
         },
         'korean': {
             'url':
             'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/korean_mobile_v2.0_rec_infer.tar',
             'dict_path': './ppocr/utils/dict/korean_dict.txt'
         },
         'japan': {
             'url':
             'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/japan_mobile_v2.0_rec_infer.tar',
             'dict_path': './ppocr/utils/dict/japan_dict.txt'
         },
         'chinese_cht': {
             'url':
             'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/chinese_cht_mobile_v2.0_rec_infer.tar',
             'dict_path': './ppocr/utils/dict/chinese_cht_dict.txt'
         },
         'ta': {
             'url':
             'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/ta_mobile_v2.0_rec_infer.tar',
             'dict_path': './ppocr/utils/dict/ta_dict.txt'
         },
         'te': {
             'url':
             'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/te_mobile_v2.0_rec_infer.tar',
             'dict_path': './ppocr/utils/dict/te_dict.txt'
         },
         'ka': {
             'url':
             'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/ka_mobile_v2.0_rec_infer.tar',
             'dict_path': './ppocr/utils/dict/ka_dict.txt'
         },
         'latin': {
             'url':
             'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/latin_ppocr_mobile_v2.0_rec_infer.tar',
             'dict_path': './ppocr/utils/dict/latin_dict.txt'
         },
         'arabic': {
             'url':
             'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/arabic_ppocr_mobile_v2.0_rec_infer.tar',
             'dict_path': './ppocr/utils/dict/arabic_dict.txt'
         },
         'cyrillic': {
             'url':
             'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/cyrillic_ppocr_mobile_v2.0_rec_infer.tar',
             'dict_path': './ppocr/utils/dict/cyrillic_dict.txt'
         },
         'devanagari': {
             'url':
             'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/devanagari_ppocr_mobile_v2.0_rec_infer.tar',
             'dict_path': './ppocr/utils/dict/devanagari_dict.txt'
         }
     },
     'cls':
     'https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_cls_infer.tar'
 }
 
 SUPPORT_DET_MODEL = ['DB']
@@ -123,50 +122,6 @@ SUPPORT_REC_MODEL = ['CRNN']
 BASE_DIR = os.path.expanduser("~/.paddleocr/")
 
 
-def download_with_progressbar(url, save_path):
-    response = requests.get(url, stream=True)
-    total_size_in_bytes = int(response.headers.get('content-length', 0))
-    block_size = 1024  # 1 Kibibyte
-    progress_bar = tqdm(total=total_size_in_bytes, unit='iB', unit_scale=True)
-    with open(save_path, 'wb') as file:
-        for data in response.iter_content(block_size):
-            progress_bar.update(len(data))
-            file.write(data)
-    progress_bar.close()
-    if total_size_in_bytes == 0 or progress_bar.n != total_size_in_bytes:
-        logger.error("Something went wrong while downloading models")
-        sys.exit(0)
-
-
-def maybe_download(model_storage_directory, url):
-    # using custom model
-    tar_file_name_list = [
-        'inference.pdiparams', 'inference.pdiparams.info', 'inference.pdmodel'
-    ]
-    if not os.path.exists(
-            os.path.join(model_storage_directory, 'inference.pdiparams')
-    ) or not os.path.exists(
-            os.path.join(model_storage_directory, 'inference.pdmodel')):
-        tmp_path = os.path.join(model_storage_directory, url.split('/')[-1])
-        print('download {} to {}'.format(url, tmp_path))
-        os.makedirs(model_storage_directory, exist_ok=True)
-        download_with_progressbar(url, tmp_path)
-        with tarfile.open(tmp_path, 'r') as tarObj:
-            for member in tarObj.getmembers():
-                filename = None
-                for tar_file_name in tar_file_name_list:
-                    if tar_file_name in member.name:
-                        filename = tar_file_name
-                if filename is None:
-                    continue
-                file = tarObj.extractfile(member)
-                with open(
-                        os.path.join(model_storage_directory, filename),
-                        'wb') as f:
-                    f.write(file.read())
-        os.remove(tmp_path)
-
-
 def parse_args(mMain=True):
     import argparse
     parser = init_args()
@@ -194,10 +149,12 @@ class PaddleOCR(predict_system.TextSystem):
         args:
             **kwargs: other params show in paddleocr --help
         """
-        postprocess_params = parse_args(mMain=False)
-        postprocess_params.__dict__.update(**kwargs)
-        self.use_angle_cls = postprocess_params.use_angle_cls
-        lang = postprocess_params.lang
+        params = parse_args(mMain=False)
+        params.__dict__.update(**kwargs)
+        if params.show_log:
+            logger.setLevel(logging.DEBUG)
+        self.use_angle_cls = params.use_angle_cls
+        lang = params.lang
         latin_lang = [
             'af', 'az', 'bs', 'cs', 'cy', 'da', 'de', 'es', 'et', 'fr', 'ga',
             'hr', 'hu', 'id', 'is', 'it', 'ku', 'la', 'lt', 'lv', 'mi', 'ms',
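The only behavioral addition in this hunk is the `show_log` flag; the rest is the `postprocess_params` to `params` rename. A minimal usage sketch (the image path is a placeholder):

```python
from paddleocr import PaddleOCR

# show_log=True raises the shared logger to DEBUG via logger.setLevel(logging.DEBUG)
ocr = PaddleOCR(lang='en', use_angle_cls=True, show_log=True)
result = ocr.ocr('path/to/image.jpg')  # placeholder path
```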
@@ -223,46 +180,46 @@ class PaddleOCR(predict_system.TextSystem):
             lang = "devanagari"
         assert lang in model_urls[
             'rec'], 'param lang must in {}, but got {}'.format(
                 model_urls['rec'].keys(), lang)
         if lang == "ch":
             det_lang = "ch"
         else:
             det_lang = "en"
         use_inner_dict = False
-        if postprocess_params.rec_char_dict_path is None:
+        if params.rec_char_dict_path is None:
             use_inner_dict = True
-            postprocess_params.rec_char_dict_path = model_urls['rec'][lang][
+            params.rec_char_dict_path = model_urls['rec'][lang][
                 'dict_path']
 
         # init model dir
-        if postprocess_params.det_model_dir is None:
-            postprocess_params.det_model_dir = os.path.join(BASE_DIR, VERSION,
-                                                            'det', det_lang)
-        if postprocess_params.rec_model_dir is None:
-            postprocess_params.rec_model_dir = os.path.join(BASE_DIR, VERSION,
-                                                            'rec', lang)
-        if postprocess_params.cls_model_dir is None:
-            postprocess_params.cls_model_dir = os.path.join(BASE_DIR, 'cls')
-        print(postprocess_params)
+        if params.det_model_dir is None:
+            params.det_model_dir = os.path.join(BASE_DIR, VERSION,
+                                                'det', det_lang)
+        if params.rec_model_dir is None:
+            params.rec_model_dir = os.path.join(BASE_DIR, VERSION,
+                                                'rec', lang)
+        if params.cls_model_dir is None:
+            params.cls_model_dir = os.path.join(BASE_DIR, 'cls')
         # download model
-        maybe_download(postprocess_params.det_model_dir,
-                       model_urls['det'][det_lang])
-        maybe_download(postprocess_params.rec_model_dir,
-                       model_urls['rec'][lang]['url'])
-        maybe_download(postprocess_params.cls_model_dir, model_urls['cls'])
+        maybe_download(params.det_model_dir,
+                       model_urls['det'][det_lang])
+        maybe_download(params.rec_model_dir,
+                       model_urls['rec'][lang]['url'])
+        maybe_download(params.cls_model_dir, model_urls['cls'])
 
-        if postprocess_params.det_algorithm not in SUPPORT_DET_MODEL:
+        if params.det_algorithm not in SUPPORT_DET_MODEL:
             logger.error('det_algorithm must in {}'.format(SUPPORT_DET_MODEL))
             sys.exit(0)
-        if postprocess_params.rec_algorithm not in SUPPORT_REC_MODEL:
+        if params.rec_algorithm not in SUPPORT_REC_MODEL:
             logger.error('rec_algorithm must in {}'.format(SUPPORT_REC_MODEL))
             sys.exit(0)
         if use_inner_dict:
-            postprocess_params.rec_char_dict_path = str(
-                Path(__file__).parent / postprocess_params.rec_char_dict_path)
+            params.rec_char_dict_path = str(
+                Path(__file__).parent / params.rec_char_dict_path)
 
+        print(params)
         # init det_model and rec_model
-        super().__init__(postprocess_params)
+        super().__init__(params)
 
     def ocr(self, img, det=True, rec=True, cls=True):
         """
@@ -81,7 +81,7 @@ class NormalizeImage(object):
         assert isinstance(img,
                           np.ndarray), "invalid input 'img' in NormalizeImage"
         data['image'] = (
             img.astype('float32') * self.scale - self.mean) / self.std
         return data
 
 
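As a standalone sketch of the arithmetic on the `data['image']` line (the scale/mean/std values below are the common ImageNet defaults, assumed here for illustration only):

```python
import numpy as np

img = np.random.randint(0, 256, (64, 64, 3), dtype=np.uint8)
scale = 1.0 / 255.0                                # assumed default
mean = np.array([0.485, 0.456, 0.406], 'float32')  # assumed ImageNet mean
std = np.array([0.229, 0.224, 0.225], 'float32')   # assumed ImageNet std
normalized = (img.astype('float32') * scale - mean) / std
```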
@@ -163,7 +163,7 @@ class DetResizeForTest(object):
             img, (ratio_h, ratio_w)
         """
         limit_side_len = self.limit_side_len
-        h, w, _ = img.shape
+        h, w, c = img.shape
 
         # limit the max side
         if self.limit_type == 'max':
@@ -174,7 +174,7 @@ class DetResizeForTest(object):
                 ratio = float(limit_side_len) / w
             else:
                 ratio = 1.
-        else:
+        elif self.limit_type == 'min':
             if min(h, w) < limit_side_len:
                 if h < w:
                     ratio = float(limit_side_len) / h
@@ -182,6 +182,10 @@ class DetResizeForTest(object):
                     ratio = float(limit_side_len) / w
             else:
                 ratio = 1.
+        elif self.limit_type == 'resize_long':
+            ratio = float(limit_side_len) / max(h, w)
+        else:
+            raise Exception('not support limit type, image ')
         resize_h = int(h * ratio)
         resize_w = int(w * ratio)
 
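Taken together, the three hunks above turn the resize policy into an explicit three-way switch on `limit_type`. A compact standalone sketch of the resulting ratio selection (a rewrite for illustration, not the class method itself):

```python
def det_resize_ratio(h, w, limit_side_len, limit_type):
    if limit_type == 'max':          # shrink only if the longer side exceeds the limit
        return float(limit_side_len) / max(h, w) if max(h, w) > limit_side_len else 1.0
    if limit_type == 'min':          # enlarge only if the shorter side is under the limit
        return float(limit_side_len) / min(h, w) if min(h, w) < limit_side_len else 1.0
    if limit_type == 'resize_long':  # always pin the longer side to the limit
        return float(limit_side_len) / max(h, w)
    raise Exception('not support limit type, image ')
```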
|
@ -44,16 +44,16 @@ class BaseRecLabelDecode(object):
|
||||||
self.character_str = string.printable[:-6]
|
self.character_str = string.printable[:-6]
|
||||||
dict_character = list(self.character_str)
|
dict_character = list(self.character_str)
|
||||||
elif character_type in support_character_type:
|
elif character_type in support_character_type:
|
||||||
self.character_str = ""
|
self.character_str = []
|
||||||
assert character_dict_path is not None, "character_dict_path should not be None when character_type is {}".format(
|
assert character_dict_path is not None, "character_dict_path should not be None when character_type is {}".format(
|
||||||
character_type)
|
character_type)
|
||||||
with open(character_dict_path, "rb") as fin:
|
with open(character_dict_path, "rb") as fin:
|
||||||
lines = fin.readlines()
|
lines = fin.readlines()
|
||||||
for line in lines:
|
for line in lines:
|
||||||
line = line.decode('utf-8').strip("\n").strip("\r\n")
|
line = line.decode('utf-8').strip("\n").strip("\r\n")
|
||||||
self.character_str += line
|
self.character_str.append(line)
|
||||||
if use_space_char:
|
if use_space_char:
|
||||||
self.character_str += " "
|
self.character_str.append(" ")
|
||||||
dict_character = list(self.character_str)
|
dict_character = list(self.character_str)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
|
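The string-to-list switch is not cosmetic: `list(...)` on a plain string explodes a multi-character dictionary entry into single characters, while a list keeps it as one token, which matters once dictionary lines can be multi-character tokens such as the table tags added below. Two lines make the difference visible:

```python
chars_as_str = "" + "<td>"     # list(chars_as_str) -> ['<', 't', 'd', '>']
chars_as_list = [] + ["<td>"]  # list(chars_as_list) -> ['<td>']
print(list(chars_as_str), list(chars_as_list))
```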
@@ -288,3 +288,156 @@ class SRNLabelDecode(BaseRecLabelDecode):
             assert False, "unsupport type %s in get_beg_end_flag_idx" \
                           % beg_or_end
         return idx
+
+
+class TableLabelDecode(object):
+    """ """
+
+    def __init__(self,
+                 max_text_length,
+                 max_elem_length,
+                 max_cell_num,
+                 character_dict_path,
+                 **kwargs):
+        self.max_text_length = max_text_length
+        self.max_elem_length = max_elem_length
+        self.max_cell_num = max_cell_num
+        list_character, list_elem = self.load_char_elem_dict(character_dict_path)
+        list_character = self.add_special_char(list_character)
+        list_elem = self.add_special_char(list_elem)
+        self.dict_character = {}
+        self.dict_idx_character = {}
+        for i, char in enumerate(list_character):
+            self.dict_idx_character[i] = char
+            self.dict_character[char] = i
+        self.dict_elem = {}
+        self.dict_idx_elem = {}
+        for i, elem in enumerate(list_elem):
+            self.dict_idx_elem[i] = elem
+            self.dict_elem[elem] = i
+
+    def load_char_elem_dict(self, character_dict_path):
+        list_character = []
+        list_elem = []
+        with open(character_dict_path, "rb") as fin:
+            lines = fin.readlines()
+            substr = lines[0].decode('utf-8').strip("\n").split("\t")
+            character_num = int(substr[0])
+            elem_num = int(substr[1])
+            for cno in range(1, 1 + character_num):
+                character = lines[cno].decode('utf-8').strip("\n")
+                list_character.append(character)
+            for eno in range(1 + character_num, 1 + character_num + elem_num):
+                elem = lines[eno].decode('utf-8').strip("\n")
+                list_elem.append(elem)
+        return list_character, list_elem
+
+    def add_special_char(self, list_character):
+        self.beg_str = "sos"
+        self.end_str = "eos"
+        list_character = [self.beg_str] + list_character + [self.end_str]
+        return list_character
+
+    def get_sp_tokens(self):
+        char_beg_idx = self.get_beg_end_flag_idx('beg', 'char')
+        char_end_idx = self.get_beg_end_flag_idx('end', 'char')
+        elem_beg_idx = self.get_beg_end_flag_idx('beg', 'elem')
+        elem_end_idx = self.get_beg_end_flag_idx('end', 'elem')
+        elem_char_idx1 = self.dict_elem['<td>']
+        elem_char_idx2 = self.dict_elem['<td']
+        sp_tokens = np.array([char_beg_idx, char_end_idx, elem_beg_idx,
+                              elem_end_idx, elem_char_idx1, elem_char_idx2,
+                              self.max_text_length, self.max_elem_length,
+                              self.max_cell_num])
+        return sp_tokens
+
+    def __call__(self, preds):
+        structure_probs = preds['structure_probs']
+        loc_preds = preds['loc_preds']
+        if isinstance(structure_probs, paddle.Tensor):
+            structure_probs = structure_probs.numpy()
+        if isinstance(loc_preds, paddle.Tensor):
+            loc_preds = loc_preds.numpy()
+        structure_idx = structure_probs.argmax(axis=2)
+        structure_probs = structure_probs.max(axis=2)
+        structure_str, structure_pos, result_score_list, result_elem_idx_list = self.decode(
+            structure_idx, structure_probs, 'elem')
+        res_html_code_list = []
+        res_loc_list = []
+        batch_num = len(structure_str)
+        for bno in range(batch_num):
+            res_loc = []
+            for sno in range(len(structure_str[bno])):
+                text = structure_str[bno][sno]
+                if text in ['<td>', '<td']:
+                    pos = structure_pos[bno][sno]
+                    res_loc.append(loc_preds[bno, pos])
+            res_html_code = ''.join(structure_str[bno])
+            res_loc = np.array(res_loc)
+            res_html_code_list.append(res_html_code)
+            res_loc_list.append(res_loc)
+        return {'res_html_code': res_html_code_list, 'res_loc': res_loc_list,
+                'res_score_list': result_score_list,
+                'res_elem_idx_list': result_elem_idx_list,
+                'structure_str_list': structure_str}
+
+    def decode(self, text_index, structure_probs, char_or_elem):
+        """convert text-label into text-index.
+        """
+        if char_or_elem == "char":
+            current_dict = self.dict_idx_character
+        else:
+            current_dict = self.dict_idx_elem
+        ignored_tokens = self.get_ignored_tokens('elem')
+        beg_idx, end_idx = ignored_tokens
+
+        result_list = []
+        result_pos_list = []
+        result_score_list = []
+        result_elem_idx_list = []
+        batch_size = len(text_index)
+        for batch_idx in range(batch_size):
+            char_list = []
+            elem_pos_list = []
+            elem_idx_list = []
+            score_list = []
+            for idx in range(len(text_index[batch_idx])):
+                tmp_elem_idx = int(text_index[batch_idx][idx])
+                if idx > 0 and tmp_elem_idx == end_idx:
+                    break
+                if tmp_elem_idx in ignored_tokens:
+                    continue
+
+                char_list.append(current_dict[tmp_elem_idx])
+                elem_pos_list.append(idx)
+                score_list.append(structure_probs[batch_idx, idx])
+                elem_idx_list.append(tmp_elem_idx)
+            result_list.append(char_list)
+            result_pos_list.append(elem_pos_list)
+            result_score_list.append(score_list)
+            result_elem_idx_list.append(elem_idx_list)
+        return result_list, result_pos_list, result_score_list, result_elem_idx_list
+
+    def get_ignored_tokens(self, char_or_elem):
+        beg_idx = self.get_beg_end_flag_idx("beg", char_or_elem)
+        end_idx = self.get_beg_end_flag_idx("end", char_or_elem)
+        return [beg_idx, end_idx]
+
+    def get_beg_end_flag_idx(self, beg_or_end, char_or_elem):
+        if char_or_elem == "char":
+            if beg_or_end == "beg":
+                idx = self.dict_character[self.beg_str]
+            elif beg_or_end == "end":
+                idx = self.dict_character[self.end_str]
+            else:
+                assert False, "Unsupport type %s in get_beg_end_flag_idx of char" \
+                              % beg_or_end
+        elif char_or_elem == "elem":
+            if beg_or_end == "beg":
+                idx = self.dict_elem[self.beg_str]
+            elif beg_or_end == "end":
+                idx = self.dict_elem[self.end_str]
+            else:
+                assert False, "Unsupport type %s in get_beg_end_flag_idx of elem" \
+                              % beg_or_end
+        else:
+            assert False, "Unsupport type %s in char_or_elem" \
+                          % char_or_elem
+        return idx
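`load_char_elem_dict` fixes the layout of the table dictionary file: a header line holding the character and element counts separated by a tab, then that many characters, then that many structure elements. A sketch of a minimal valid file (contents illustrative):

```python
# header says: 2 characters, 3 elements
sample = "2\t3\na\nb\n<td>\n</td>\n<tr>\n"
with open("tiny_table_dict.txt", "w", encoding="utf-8") as f:
    f.write(sample)
```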
@@ -0,0 +1,277 @@
←
</overline>
☆
─
α


⋅
$
ω
ψ
χ
(
υ
≥
σ
,
ρ
ε
0
■
4
8
✗
b
<
✓
Ψ
Ω
€
D
3
Π
H
║
</strike>
L
Φ
Χ
θ
P
κ
λ
μ
T
ξ
X
β
γ
δ
\
ζ
η
`
d
<strike>
h
f
l
Θ
p
√
t
</sub>
x
Β
Γ
Δ

ǂ
ɛ
j
̧
➢

̌
′
«
△
▲
#
</b>
'
Ι
+
¶
/
▼
⇑
□
·
7
▪
;
?
➔
∩
C
÷
G
⇒
K
<sup>
O
S
С
W
Α
[
○
_
●
‡
c
z
g
<i>
o
<sub>
〈
〉
s
⩽
w
φ
ʹ
{
»
∣
̆
e
ˆ
∈
τ
◆
ι
∅
∆
∙
∘
Ø
ß
✔
∞
∑
−
×
◊
∗
∖
˃
˂
∫
"
i
&
π
↔
*
∥
æ
∧
.
⁄
ø
Q
∼
6
⁎
:
★
>
a
B
≈
F
J
̄
N
♯
R
V
<overline>
―
Z
♣
^
¤
¥
§
<underline>
¢
£
≦

≤
‖
Λ
©
n
↓
→
↑
r
°
±
v
<b>
♂
k
♀
~
ᅟ
̇
@
”
♦
ł
®
⊕
„
!
</sup>
%
⇓
)
-
1
5
9
=
А
A
‰
⋆
Σ
E
◦
I
※
M
m
̨
⩾
†
</i>
•
U
Y

]
̸
2
‐
–
‒
̂
—
̀
́
’
‘
⋮
⋯
̊
“
̈
≧
q
u
ı
y
</underline>

̃
}
ν
File diff suppressed because it is too large
@@ -0,0 +1,82 @@
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import sys
import tarfile
import requests
from tqdm import tqdm

from ppocr.utils.logging import get_logger


def download_with_progressbar(url, save_path):
    logger = get_logger()
    response = requests.get(url, stream=True)
    total_size_in_bytes = int(response.headers.get('content-length', 0))
    block_size = 1024  # 1 Kibibyte
    progress_bar = tqdm(total=total_size_in_bytes, unit='iB', unit_scale=True)
    with open(save_path, 'wb') as file:
        for data in response.iter_content(block_size):
            progress_bar.update(len(data))
            file.write(data)
    progress_bar.close()
    if total_size_in_bytes == 0 or progress_bar.n != total_size_in_bytes:
        logger.error("Something went wrong while downloading models")
        sys.exit(0)


def maybe_download(model_storage_directory, url):
    # using custom model
    tar_file_name_list = [
        'inference.pdiparams', 'inference.pdiparams.info', 'inference.pdmodel'
    ]
    if not os.path.exists(
            os.path.join(model_storage_directory, 'inference.pdiparams')
    ) or not os.path.exists(
            os.path.join(model_storage_directory, 'inference.pdmodel')):
        assert url.endswith('.tar'), 'Only supports tar compressed package'
        tmp_path = os.path.join(model_storage_directory, url.split('/')[-1])
        print('download {} to {}'.format(url, tmp_path))
        os.makedirs(model_storage_directory, exist_ok=True)
        download_with_progressbar(url, tmp_path)
        with tarfile.open(tmp_path, 'r') as tarObj:
            for member in tarObj.getmembers():
                filename = None
                for tar_file_name in tar_file_name_list:
                    if tar_file_name in member.name:
                        filename = tar_file_name
                if filename is None:
                    continue
                file = tarObj.extractfile(member)
                with open(
                        os.path.join(model_storage_directory, filename),
                        'wb') as f:
                    f.write(file.read())
        os.remove(tmp_path)


def is_link(s):
    return s is not None and s.startswith('http')


def confirm_model_dir_url(model_dir, default_model_dir, default_url):
    url = default_url
    if model_dir is None or is_link(model_dir):
        if is_link(model_dir):
            url = model_dir
        file_name = url.split('/')[-1][:-4]
        model_dir = default_model_dir
        model_dir = os.path.join(model_dir, file_name)
    return model_dir, url
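A short sketch of how these helpers fit together (the cache directory is illustrative; the URL is one of the detection model URLs listed earlier in this commit):

```python
import os
from ppocr.utils.network import confirm_model_dir_url, maybe_download

det_url = 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_det_infer.tar'
# None -> fall back to the default cache dir; passing a URL instead would override det_url
model_dir, url = confirm_model_dir_url(None, os.path.expanduser('~/.paddleocr/det'), det_url)
maybe_download(model_dir, url)  # no-op when inference.pdiparams/pdmodel already exist
```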
@@ -0,0 +1,9 @@
include LICENSE
include README.md

recursive-include ppocr/utils *.txt utility.py logging.py network.py
recursive-include ppocr/data/ *.py
recursive-include ppocr/postprocess *.py
recursive-include tools/infer *.py
recursive-include ppstructure *.py
@@ -0,0 +1,30 @@
# TableStructurer

1. Usage from Python code

```python
import cv2
from paddlestructure import PaddleStructure, draw_result

table_engine = PaddleStructure(
    output='./output/table',
    show_log=True)

img_path = '../doc/table/1.png'
img = cv2.imread(img_path)
result = table_engine(img)
for line in result:
    print(line)

from PIL import Image

font_path = 'path/to/PaddleOCR/doc/fonts/simfang.ttf'
image = Image.open(img_path).convert('RGB')
im_show = draw_result(image, result, font_path=font_path)
im_show = Image.fromarray(im_show)
im_show.save('result.jpg')
```

2. Usage from the command line

```bash
paddlestructure --image_dir=../doc/table/1.png
```
@@ -0,0 +1,17 @@
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .paddlestructure import PaddleStructure, draw_result, to_excel

__all__ = ['PaddleStructure', 'draw_result', 'to_excel']
@@ -0,0 +1,148 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
import sys

__dir__ = os.path.dirname(__file__)
sys.path.append(__dir__)
sys.path.append(os.path.join(__dir__, '..'))

import cv2
import numpy as np
from pathlib import Path

from ppocr.utils.logging import get_logger
from ppstructure.predict_system import OCRSystem, save_res
from ppstructure.table.predict_table import to_excel
from ppstructure.utility import init_args, draw_result

logger = get_logger()
from ppocr.utils.utility import check_and_read_gif, get_image_file_list
from ppocr.utils.network import maybe_download, download_with_progressbar, confirm_model_dir_url, is_link

__all__ = ['PaddleStructure', 'draw_result', 'to_excel']

VERSION = '2.1'
BASE_DIR = os.path.expanduser("~/.paddlestructure/")

model_urls = {
    'det': 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_det_infer.tar',
    'rec': 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_rec_infer.tar',
    'structure': 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_structure_infer.tar'
}


def parse_args(mMain=True):
    import argparse
    parser = init_args()
    parser.add_help = mMain

    for action in parser._actions:
        if action.dest in ['rec_char_dict_path', 'structure_char_dict_path']:
            action.default = None
    if mMain:
        return parser.parse_args()
    else:
        inference_args_dict = {}
        for action in parser._actions:
            inference_args_dict[action.dest] = action.default
        return argparse.Namespace(**inference_args_dict)


class PaddleStructure(OCRSystem):
    def __init__(self, **kwargs):
        params = parse_args(mMain=False)
        params.__dict__.update(**kwargs)
        if params.show_log:
            logger.setLevel(logging.DEBUG)
        params.use_angle_cls = False
        # init model dir
        params.det_model_dir, det_url = confirm_model_dir_url(params.det_model_dir,
                                                              os.path.join(BASE_DIR, VERSION, 'det'),
                                                              model_urls['det'])
        params.rec_model_dir, rec_url = confirm_model_dir_url(params.rec_model_dir,
                                                              os.path.join(BASE_DIR, VERSION, 'rec'),
                                                              model_urls['rec'])
        params.structure_model_dir, structure_url = confirm_model_dir_url(params.structure_model_dir,
                                                                          os.path.join(BASE_DIR, VERSION, 'structure'),
                                                                          model_urls['structure'])
        # download model
        maybe_download(params.det_model_dir, det_url)
        maybe_download(params.rec_model_dir, rec_url)
        maybe_download(params.structure_model_dir, structure_url)

        if params.rec_char_dict_path is None:
            params.rec_char_type = 'EN'
            if os.path.exists(str(Path(__file__).parent / 'ppocr/utils/dict/table_dict.txt')):
                params.rec_char_dict_path = str(Path(__file__).parent / 'ppocr/utils/dict/table_dict.txt')
            else:
                params.rec_char_dict_path = str(Path(__file__).parent.parent / 'ppocr/utils/dict/table_dict.txt')
        if params.structure_char_dict_path is None:
            if os.path.exists(str(Path(__file__).parent / 'ppocr/utils/dict/table_structure_dict.txt')):
                params.structure_char_dict_path = str(
                    Path(__file__).parent / 'ppocr/utils/dict/table_structure_dict.txt')
            else:
                params.structure_char_dict_path = str(
                    Path(__file__).parent.parent / 'ppocr/utils/dict/table_structure_dict.txt')

        print(params)
        super().__init__(params)

    def __call__(self, img):
        if isinstance(img, str):
            # download net image
            if img.startswith('http'):
                download_with_progressbar(img, 'tmp.jpg')
                img = 'tmp.jpg'
            image_file = img
            img, flag = check_and_read_gif(image_file)
            if not flag:
                with open(image_file, 'rb') as f:
                    np_arr = np.frombuffer(f.read(), dtype=np.uint8)
                    img = cv2.imdecode(np_arr, cv2.IMREAD_COLOR)
            if img is None:
                logger.error("error in loading image:{}".format(image_file))
                return None
        if isinstance(img, np.ndarray) and len(img.shape) == 2:
            img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)

        res = super().__call__(img)
        return res


def main():
    # for cmd
    args = parse_args(mMain=True)
    image_dir = args.image_dir
    save_folder = args.output
    if image_dir.startswith('http'):
        download_with_progressbar(image_dir, 'tmp.jpg')
        image_file_list = ['tmp.jpg']
    else:
        image_file_list = get_image_file_list(args.image_dir)
    if len(image_file_list) == 0:
        logger.error('no images find in {}'.format(args.image_dir))
        return

    structure_engine = PaddleStructure(**(args.__dict__))
    for img_path in image_file_list:
        img_name = os.path.basename(img_path).split('.')[0]
        logger.info('{}{}{}'.format('*' * 10, img_path, '*' * 10))
        result = structure_engine(img_path)
        for item in result:
            logger.info(item['res'])
        save_res(result, save_folder, img_name)
        logger.info('result save to {}'.format(os.path.join(save_folder, img_name)))
@@ -0,0 +1,132 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import sys
import subprocess

__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
sys.path.append(os.path.abspath(os.path.join(__dir__, '..')))

os.environ["FLAGS_allocator_strategy"] = 'auto_growth'
import cv2
import numpy as np
import time

import layoutparser as lp

from ppocr.utils.utility import get_image_file_list, check_and_read_gif
from ppocr.utils.logging import get_logger
from tools.infer.predict_system import TextSystem
from ppstructure.table.predict_table import TableSystem, to_excel
from ppstructure.utility import parse_args, draw_result

logger = get_logger()


class OCRSystem(object):
    def __init__(self, args):
        args.det_limit_type = 'resize_long'
        args.drop_score = 0
        self.text_system = TextSystem(args)
        self.table_system = TableSystem(args, self.text_system.text_detector, self.text_system.text_recognizer)
        self.table_layout = lp.PaddleDetectionLayoutModel("lp://PubLayNet/ppyolov2_r50vd_dcn_365e_publaynet/config",
                                                          threshold=0.5, enable_mkldnn=args.enable_mkldnn,
                                                          enforce_cpu=not args.use_gpu, thread_num=args.cpu_threads)
        self.use_angle_cls = args.use_angle_cls
        self.drop_score = args.drop_score

    def __call__(self, img):
        ori_im = img.copy()
        layout_res = self.table_layout.detect(img[..., ::-1])
        res_list = []
        for region in layout_res:
            x1, y1, x2, y2 = region.coordinates
            x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2)
            roi_img = ori_im[y1:y2, x1:x2, :]
            if region.type == 'Table':
                res = self.table_system(roi_img)
            elif region.type == 'Figure':
                continue
            else:
                filter_boxes, filter_rec_res = self.text_system(roi_img)
                filter_boxes = [x + [x1, y1] for x in filter_boxes]
                filter_boxes = [x.reshape(-1).tolist() for x in filter_boxes]

                res = (filter_boxes, filter_rec_res)
            res_list.append({'type': region.type, 'bbox': [x1, y1, x2, y2], 'res': res})
        return res_list


def save_res(res, save_folder, img_name):
    excel_save_folder = os.path.join(save_folder, img_name)
    os.makedirs(excel_save_folder, exist_ok=True)
    # save res
    for region in res:
        if region['type'] == 'Table':
            excel_path = os.path.join(excel_save_folder, '{}.xlsx'.format(region['bbox']))
            to_excel(region['res'], excel_path)
        elif region['type'] == 'Figure':
            pass
        else:
            with open(os.path.join(excel_save_folder, 'res.txt'), 'a', encoding='utf8') as f:
                for box, rec_res in zip(region['res'][0], region['res'][1]):
                    f.write('{}\t{}\n'.format(np.array(box).reshape(-1).tolist(), rec_res))


def main(args):
    image_file_list = get_image_file_list(args.image_dir)
    image_file_list = image_file_list
    image_file_list = image_file_list[args.process_id::args.total_process_num]
    save_folder = args.output
    os.makedirs(save_folder, exist_ok=True)

    structure_sys = OCRSystem(args)
    img_num = len(image_file_list)
    for i, image_file in enumerate(image_file_list):
        logger.info("[{}/{}] {}".format(i, img_num, image_file))
        img, flag = check_and_read_gif(image_file)
        img_name = os.path.basename(image_file).split('.')[0]

        if not flag:
            img = cv2.imread(image_file)
        if img is None:
            logger.error("error in loading image:{}".format(image_file))
            continue
        starttime = time.time()
        res = structure_sys(img)
        save_res(res, save_folder, img_name)
        draw_img = draw_result(img, res, args.vis_font_path)
        cv2.imwrite(os.path.join(save_folder, img_name, 'show.jpg'), draw_img)
        logger.info('result save to {}'.format(os.path.join(save_folder, img_name)))
        elapse = time.time() - starttime
        logger.info("Predict time : {:.3f}s".format(elapse))


if __name__ == "__main__":
    args = parse_args()
    if args.use_mp:
        p_list = []
        total_process_num = args.total_process_num
        for process_id in range(total_process_num):
            cmd = [sys.executable, "-u"] + sys.argv + [
                "--process_id={}".format(process_id),
                "--use_mp={}".format(False)
            ]
            p = subprocess.Popen(cmd, stdout=sys.stdout, stderr=sys.stdout)
            p_list.append(p)
        for p in p_list:
            p.wait()
    else:
        main(args)
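For orientation, each entry in the `res_list` returned by `OCRSystem.__call__` (and consumed by `save_res`) has the shape sketched below; the concrete values are illustrative:

```python
# illustrative entry for a text region; a 'Table' region carries the table
# system's output in 'res' instead, and 'Figure' regions are skipped entirely
example = {
    'type': 'Text',                                # layout region type from layoutparser
    'bbox': [12, 30, 410, 85],                     # region box in page coordinates
    'res': ([[12, 30, 410, 30, 410, 85, 12, 85]],  # flattened text boxes
            [('detected text', 0.98)]),            # (text, score) pairs
}
```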
@@ -0,0 +1,72 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os

from setuptools import setup
from io import open
import shutil

with open('../requirements.txt', encoding="utf-8-sig") as f:
    requirements = f.readlines()
    requirements.append('tqdm')
    requirements.append('layoutparser')
    requirements.append('iopath')


def readme():
    with open('README_ch.md', encoding="utf-8-sig") as f:
        README = f.read()
    return README


shutil.copytree('../ppstructure/table', './ppstructure/table')
shutil.copyfile('../ppstructure/predict_system.py', './ppstructure/predict_system.py')
shutil.copyfile('../ppstructure/utility.py', './ppstructure/utility.py')
shutil.copytree('../ppocr', './ppocr')
shutil.copytree('../tools', './tools')
shutil.copyfile('../LICENSE', './LICENSE')

setup(
    name='paddlestructure',
    packages=['paddlestructure'],
    package_dir={'paddlestructure': ''},
    include_package_data=True,
    entry_points={"console_scripts": ["paddlestructure= paddlestructure.paddlestructure:main"]},
    version='1.0',
    install_requires=requirements,
    license='Apache License 2.0',
    description='Awesome OCR toolkits based on PaddlePaddle (8.6M ultra-lightweight pre-trained model, support training and deployment among server, mobile, embedded and IoT devices)',
    long_description=readme(),
    long_description_content_type='text/markdown',
    url='https://github.com/PaddlePaddle/PaddleOCR',
    download_url='https://github.com/PaddlePaddle/PaddleOCR.git',
    keywords=[
        'ocr textdetection textrecognition paddleocr crnn east star-net rosetta ocrlite db chineseocr chinesetextdetection chinesetextrecognition'
    ],
    classifiers=[
        'Intended Audience :: Developers', 'Operating System :: OS Independent',
        'Natural Language :: Chinese (Simplified)',
        'Programming Language :: Python :: 3',
        'Programming Language :: Python :: 3.2',
        'Programming Language :: Python :: 3.3',
        'Programming Language :: Python :: 3.4',
        'Programming Language :: Python :: 3.5',
        'Programming Language :: Python :: 3.6',
        'Programming Language :: Python :: 3.7', 'Topic :: Utilities'
    ], )

shutil.rmtree('ppocr')
shutil.rmtree('tools')
shutil.rmtree('ppstructure')
os.remove('LICENSE')
@@ -0,0 +1,15 @@
# Table structure and content prediction

First cd into the PaddleOCR/ppstructure directory.

Prediction:
```bash
python3 table/predict_table.py --det_model_dir=../inference/db --rec_model_dir=../inference/rec_mv3_large1.0/infer --table_model_dir=../inference/explite3/infer --image_dir=../table/imgs/PMC3006023_004_00.png --rec_char_dict_path=../ppocr/utils/dict/table_dict.txt --table_char_dict_path=../ppocr/utils/dict/table_structure_dict.txt --rec_char_type=EN --det_limit_side_len=736 --det_limit_type=min --table_output ../output/table
```
After the run finishes, the Excel file for each image is saved to the directory given by the table_output argument.

Evaluation:

```bash
python3 table/eval_table.py --det_model_dir=../inference/db --rec_model_dir=../inference/rec_mv3_large1.0/infer --table_model_dir=../inference/explite3/infer --image_dir=../table/imgs --rec_char_dict_path=../ppocr/utils/dict/table_dict.txt --table_char_dict_path=../ppocr/utils/dict/table_structure_dict.txt --rec_char_type=EN --det_limit_side_len=736 --det_limit_type=min --gt_path=path/to/gt.json
```
@@ -0,0 +1,13 @@
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
@@ -0,0 +1,69 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
sys.path.append(os.path.abspath(os.path.join(__dir__, '../..')))

import cv2
import json
from tqdm import tqdm
from ppstructure.table.table_metric import TEDS
from ppstructure.table.predict_table import TableSystem
from ppstructure.utility import init_args


def parse_args():
    parser = init_args()
    parser.add_argument("--gt_path", type=str)
    return parser.parse_args()


def main(gt_path, img_root, args):
    teds = TEDS(n_jobs=16)

    text_sys = TableSystem(args)
    jsons_gt = json.load(open(gt_path))  # gt
    pred_htmls = []
    gt_htmls = []
    for img_name in tqdm(jsons_gt):
        # read image
        img = cv2.imread(os.path.join(img_root, img_name))
        pred_html = text_sys(img)
        pred_htmls.append(pred_html)

        gt_structures, gt_bboxes, gt_contents, contents_with_block = jsons_gt[img_name]
        gt_html, gt = get_gt_html(gt_structures, contents_with_block)
        gt_htmls.append(gt_html)
    scores = teds.batch_evaluate_html(gt_htmls, pred_htmls)
    print('teds:', sum(scores) / len(scores))


def get_gt_html(gt_structures, contents_with_block):
    end_html = []
    td_index = 0
    for tag in gt_structures:
        if '</td>' in tag:
            if contents_with_block[td_index] != []:
                end_html.extend(contents_with_block[td_index])
            end_html.append(tag)
            td_index += 1
        else:
            end_html.append(tag)
    return ''.join(end_html), end_html


if __name__ == '__main__':
    args = parse_args()
    main(args.gt_path, args.image_dir, args)
@@ -0,0 +1,192 @@
import json


def distance(box_1, box_2):
    x1, y1, x2, y2 = box_1
    x3, y3, x4, y4 = box_2
    dis = abs(x3 - x1) + abs(y3 - y1) + abs(x4 - x2) + abs(y4 - y2)
    dis_2 = abs(x3 - x1) + abs(y3 - y1)
    dis_3 = abs(x4 - x2) + abs(y4 - y2)
    return dis + min(dis_2, dis_3)


def compute_iou(rec1, rec2):
    """
    computing IoU
    :param rec1: (y0, x0, y1, x1), which reflects
            (top, left, bottom, right)
    :param rec2: (y0, x0, y1, x1)
    :return: scalar value of IoU
    """
    # compute the area of each rectangle
    S_rec1 = (rec1[2] - rec1[0]) * (rec1[3] - rec1[1])
    S_rec2 = (rec2[2] - rec2[0]) * (rec2[3] - rec2[1])

    # compute the sum of the two areas
    sum_area = S_rec1 + S_rec2

    # find each edge of the intersecting rectangle
    left_line = max(rec1[1], rec2[1])
    right_line = min(rec1[3], rec2[3])
    top_line = max(rec1[0], rec2[0])
    bottom_line = min(rec1[2], rec2[2])

    # judge whether there is an intersection
    if left_line >= right_line or top_line >= bottom_line:
        return 0.0
    else:
        intersect = (right_line - left_line) * (bottom_line - top_line)
        return (intersect / (sum_area - intersect)) * 1.0
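
# Quick sanity checks (hypothetical boxes): distance((0, 0, 10, 10), (1, 1, 11, 11))
# adds the L1 corner distance 4 to min(2, 2) and returns 6, while
# compute_iou((0, 0, 10, 10), (5, 5, 15, 15)) sees two boxes of area 100 with a
# 5 x 5 = 25 overlap, giving 25 / (100 + 100 - 25) ≈ 0.143.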


def matcher_merge(ocr_bboxes, pred_bboxes):
    all_dis = []
    ious = []
    matched = {}
    for i, gt_box in enumerate(ocr_bboxes):
        distances = []
        for j, pred_box in enumerate(pred_bboxes):
            # compute the L1 distance and IoU between the two boxes
            distances.append((distance(gt_box, pred_box), 1. - compute_iou(gt_box, pred_box)))
        sorted_distances = distances.copy()
        # select the nearest cell
        sorted_distances = sorted(sorted_distances, key=lambda item: (item[1], item[0]))
        if distances.index(sorted_distances[0]) not in matched.keys():
            matched[distances.index(sorted_distances[0])] = [i]
        else:
            matched[distances.index(sorted_distances[0])].append(i)
    return matched  # , sum(ious) / len(ious)


def complex_num(pred_bboxes):
    complex_nums = []
    for bbox in pred_bboxes:
        distances = []
        temp_ious = []
        for pred_bbox in pred_bboxes:
            if bbox != pred_bbox:
                distances.append(distance(bbox, pred_bbox))
                temp_ious.append(compute_iou(bbox, pred_bbox))
        complex_nums.append(temp_ious[distances.index(min(distances))])
    return sum(complex_nums) / len(complex_nums)


def get_rows(pred_bboxes):
    pre_bbox = pred_bboxes[0]
    res = []
    step = 0
    for i in range(len(pred_bboxes)):
        bbox = pred_bboxes[i]
        if bbox[1] - pre_bbox[1] > 2 or bbox[0] - pre_bbox[0] < 0:
            break
        else:
            res.append(bbox)
            step += 1
    for i in range(step):
        pred_bboxes.pop(0)
    return res, pred_bboxes


def refine_rows(pred_bboxes):  # fine-tune the boxes in a row so they lie on one horizontal line
    ys_1 = []
    ys_2 = []
    for box in pred_bboxes:
        ys_1.append(box[1])
        ys_2.append(box[3])
    min_y_1 = sum(ys_1) / len(ys_1)
    min_y_2 = sum(ys_2) / len(ys_2)
    re_boxes = []
    for box in pred_bboxes:
        box[1] = min_y_1
        box[3] = min_y_2
        re_boxes.append(box)
    return re_boxes


def matcher_refine_row(gt_bboxes, pred_bboxes):
    before_refine_pred_bboxes = pred_bboxes.copy()
    pred_bboxes = []
    while len(before_refine_pred_bboxes) != 0:
        row_bboxes, before_refine_pred_bboxes = get_rows(before_refine_pred_bboxes)
        print(row_bboxes)
        pred_bboxes.extend(refine_rows(row_bboxes))
    all_dis = []
    ious = []
    matched = {}
    for i, gt_box in enumerate(gt_bboxes):
        distances = []
        #temp_ious = []
        for j, pred_box in enumerate(pred_bboxes):
            distances.append(distance(gt_box, pred_box))
            #temp_ious.append(compute_iou(gt_box, pred_box))
        #all_dis.append(min(distances))
        #ious.append(temp_ious[distances.index(min(distances))])
        if distances.index(min(distances)) not in matched.keys():
            matched[distances.index(min(distances))] = [i]
        else:
            matched[distances.index(min(distances))].append(i)
    return matched  # , sum(ious) / len(ious)


# first pick out a row, then do the matching
def matcher_structure_1(gt_bboxes, pred_bboxes_rows, pred_bboxes):
    gt_box_index = 0
    delete_gt_bboxes = gt_bboxes.copy()
    match_bboxes_ready = []
    matched = {}
    while len(delete_gt_bboxes) != 0:
        row_bboxes, delete_gt_bboxes = get_rows(delete_gt_bboxes)
        row_bboxes = sorted(row_bboxes, key=lambda key: key[0])
        if len(pred_bboxes_rows) > 0:
            match_bboxes_ready.extend(pred_bboxes_rows.pop(0))
        print(row_bboxes)
        for i, gt_box in enumerate(row_bboxes):
            #print(gt_box)
            pred_distances = []
            distances = []
            for pred_bbox in pred_bboxes:
                pred_distances.append(distance(gt_box, pred_bbox))
            for j, pred_box in enumerate(match_bboxes_ready):
                distances.append(distance(gt_box, pred_box))
            index = pred_distances.index(min(distances))
            #print('index', index)
            if index not in matched.keys():
                matched[index] = [gt_box_index]
            else:
                matched[index].append(gt_box_index)
            gt_box_index += 1
    return matched


def matcher_structure(gt_bboxes, pred_bboxes_rows, pred_bboxes):
    '''
    gt_bboxes: already sorted
    pred_bboxes:
    '''
    pre_bbox = gt_bboxes[0]
    matched = {}
    match_bboxes_ready = []
    match_bboxes_ready.extend(pred_bboxes_rows.pop(0))
    for i, gt_box in enumerate(gt_bboxes):

        pred_distances = []
        for pred_bbox in pred_bboxes:
            pred_distances.append(distance(gt_box, pred_bbox))
        distances = []
        gap_pre = gt_box[1] - pre_bbox[1]
        gap_pre_1 = gt_box[0] - pre_bbox[2]
        #print(gap_pre, len(pred_bboxes_rows))
        if gap_pre_1 < 0 and len(pred_bboxes_rows) > 0:
            match_bboxes_ready.extend(pred_bboxes_rows.pop(0))
        if len(pred_bboxes_rows) == 1:
            match_bboxes_ready.extend(pred_bboxes_rows.pop(0))
        if len(match_bboxes_ready) == 0 and len(pred_bboxes_rows) > 0:
            match_bboxes_ready.extend(pred_bboxes_rows.pop(0))
        if len(match_bboxes_ready) == 0 and len(pred_bboxes_rows) == 0:
            break
        #print(match_bboxes_ready)
        for j, pred_box in enumerate(match_bboxes_ready):
            distances.append(distance(gt_box, pred_box))
        index = pred_distances.index(min(distances))
        #print(gt_box, index)
        #match_bboxes_ready.pop(distances.index(min(distances)))
        print(gt_box, match_bboxes_ready[distances.index(min(distances))])
        if index not in matched.keys():
            matched[index] = [i]
        else:
            matched[index].append(i)
        pre_bbox = gt_box
    return matched
@@ -0,0 +1,141 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys

__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
sys.path.append(os.path.abspath(os.path.join(__dir__, '../..')))

os.environ["FLAGS_allocator_strategy"] = 'auto_growth'

import cv2
import numpy as np
import math
import time
import traceback
import paddle

import tools.infer.utility as utility
from ppocr.data import create_operators, transform
from ppocr.postprocess import build_post_process
from ppocr.utils.logging import get_logger
from ppocr.utils.utility import get_image_file_list, check_and_read_gif

logger = get_logger()


class TableStructurer(object):
    def __init__(self, args):
        pre_process_list = [{
            'ResizeTableImage': {
                'max_len': args.structure_max_len
            }
        }, {
            'NormalizeImage': {
                'std': [0.229, 0.224, 0.225],
                'mean': [0.485, 0.456, 0.406],
                'scale': '1./255.',
                'order': 'hwc'
            }
        }, {
            'PaddingTableImage': None
        }, {
            'ToCHWImage': None
        }, {
            'KeepKeys': {
                'keep_keys': ['image']
            }
        }]
        postprocess_params = {
            'name': 'TableLabelDecode',
            "character_type": args.structure_char_type,
            "character_dict_path": args.structure_char_dict_path,
            "max_text_length": args.structure_max_text_length,
            "max_elem_length": args.structure_max_elem_length,
            "max_cell_num": args.structure_max_cell_num
        }

        self.preprocess_op = create_operators(pre_process_list)
        self.postprocess_op = build_post_process(postprocess_params)
        self.predictor, self.input_tensor, self.output_tensors = \
            utility.create_predictor(args, 'structure', logger)

    def __call__(self, img):
        ori_im = img.copy()
        data = {'image': img}
        data = transform(data, self.preprocess_op)
        img = data[0]
        if img is None:
            return None, 0
        img = np.expand_dims(img, axis=0)
        img = img.copy()
        starttime = time.time()

        self.input_tensor.copy_from_cpu(img)
        self.predictor.run()
        outputs = []
        for output_tensor in self.output_tensors:
            output = output_tensor.copy_to_cpu()
            outputs.append(output)

        preds = {}
        preds['structure_probs'] = outputs[1]
        preds['loc_preds'] = outputs[0]

        post_result = self.postprocess_op(preds)

        structure_str_list = post_result['structure_str_list']
        res_loc = post_result['res_loc']
        imgh, imgw = ori_im.shape[0:2]
        res_loc_final = []
        for rno in range(len(res_loc[0])):
            x0, y0, x1, y1 = res_loc[0][rno]
            left = max(int(imgw * x0), 0)
            top = max(int(imgh * y0), 0)
            right = min(int(imgw * x1), imgw - 1)
            bottom = min(int(imgh * y1), imgh - 1)
            res_loc_final.append([left, top, right, bottom])

        structure_str_list = structure_str_list[0][:-1]
        structure_str_list = ['<html>', '<body>', '<table>'] + structure_str_list + ['</table>', '</body>', '</html>']

        elapse = time.time() - starttime
        return (structure_str_list, res_loc_final), elapse


def main(args):
    image_file_list = get_image_file_list(args.image_dir)
    table_structurer = TableStructurer(args)
    count = 0
    total_time = 0
    for image_file in image_file_list:
        img, flag = check_and_read_gif(image_file)
        if not flag:
            img = cv2.imread(image_file)
        if img is None:
            logger.info("error in loading image:{}".format(image_file))
            continue
        structure_res, elapse = table_structurer(img)

        logger.info("result: {}".format(structure_res))

        if count > 0:
            total_time += elapse
        count += 1
        logger.info("Predict time of {}: {}".format(image_file, elapse))


if __name__ == "__main__":
    main(utility.parse_args())
@@ -0,0 +1,221 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import sys
import subprocess

__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
sys.path.append(os.path.abspath(os.path.join(__dir__, '..')))
sys.path.append(os.path.abspath(os.path.join(__dir__, '../..')))

os.environ["FLAGS_allocator_strategy"] = 'auto_growth'
import cv2
import copy
import numpy as np
import time
import tools.infer.predict_rec as predict_rec
import tools.infer.predict_det as predict_det
from ppocr.utils.utility import get_image_file_list, check_and_read_gif
from ppocr.utils.logging import get_logger
from ppstructure.table.matcher import distance, compute_iou
from ppstructure.utility import parse_args
import ppstructure.table.predict_structure as predict_structure

logger = get_logger()


def expand(pix, det_box, shape):
    x0, y0, x1, y1 = det_box
    # print(shape)
    h, w, c = shape
    tmp_x0 = x0 - pix
    tmp_x1 = x1 + pix
    tmp_y0 = y0 - pix
    tmp_y1 = y1 + pix
    x0_ = tmp_x0 if tmp_x0 >= 0 else 0
    x1_ = tmp_x1 if tmp_x1 <= w else w
    y0_ = tmp_y0 if tmp_y0 >= 0 else 0
    y1_ = tmp_y1 if tmp_y1 <= h else h
    return x0_, y0_, x1_, y1_
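
# For example (hypothetical box), expand(2, (5, 5, 20, 20), (100, 100, 3))
# pads the box by 2 pixels on every side and returns (3, 3, 22, 22), clipping
# to the image bounds when the padded box would fall outside them.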


class TableSystem(object):
    def __init__(self, args, text_detector=None, text_recognizer=None):
        self.text_detector = predict_det.TextDetector(args) if text_detector is None else text_detector
        self.text_recognizer = predict_rec.TextRecognizer(args) if text_recognizer is None else text_recognizer
        self.table_structurer = predict_structure.TableStructurer(args)

    def __call__(self, img):
        ori_im = img.copy()
        structure_res, elapse = self.table_structurer(copy.deepcopy(img))
        dt_boxes, elapse = self.text_detector(copy.deepcopy(img))
        dt_boxes = sorted_boxes(dt_boxes)

        r_boxes = []
        for box in dt_boxes:
            x_min = box[:, 0].min() - 1
            x_max = box[:, 0].max() + 1
            y_min = box[:, 1].min() - 1
            y_max = box[:, 1].max() + 1
            box = [x_min, y_min, x_max, y_max]
            r_boxes.append(box)
        dt_boxes = np.array(r_boxes)

        logger.debug("dt_boxes num : {}, elapse : {}".format(
            len(dt_boxes), elapse))
        if dt_boxes is None:
            return None, None
        img_crop_list = []

        for i in range(len(dt_boxes)):
            det_box = dt_boxes[i]
            x0, y0, x1, y1 = expand(2, det_box, ori_im.shape)
            text_rect = ori_im[int(y0):int(y1), int(x0):int(x1), :]
            img_crop_list.append(text_rect)
        rec_res, elapse = self.text_recognizer(img_crop_list)
        logger.debug("rec_res num : {}, elapse : {}".format(
            len(rec_res), elapse))

        pred_html, pred = self.rebuild_table(structure_res, dt_boxes, rec_res)
        return pred_html

    def rebuild_table(self, structure_res, dt_boxes, rec_res):
        pred_structures, pred_bboxes = structure_res
        matched_index = self.match_result(dt_boxes, pred_bboxes)
        pred_html, pred = self.get_pred_html(pred_structures, matched_index, rec_res)
        return pred_html, pred

    def match_result(self, dt_boxes, pred_bboxes):
        matched = {}
        for i, gt_box in enumerate(dt_boxes):
            # gt_box = [np.min(gt_box[:, 0]), np.min(gt_box[:, 1]), np.max(gt_box[:, 0]), np.max(gt_box[:, 1])]
            distances = []
            for j, pred_box in enumerate(pred_bboxes):
                distances.append(
                    (distance(gt_box, pred_box), 1. - compute_iou(gt_box, pred_box)))  # L1 distance and 1 - IoU between each pair of cells
            sorted_distances = distances.copy()
            # pick the "nearest" cell according to distance and IoU
            sorted_distances = sorted(sorted_distances, key=lambda item: (item[1], item[0]))
            if distances.index(sorted_distances[0]) not in matched.keys():
                matched[distances.index(sorted_distances[0])] = [i]
            else:
                matched[distances.index(sorted_distances[0])].append(i)
        return matched

    def get_pred_html(self, pred_structures, matched_index, ocr_contents):
        end_html = []
        td_index = 0
        for tag in pred_structures:
            if '</td>' in tag:
                if td_index in matched_index.keys():
                    b_with = False
                    if '<b>' in ocr_contents[matched_index[td_index][0]] and len(matched_index[td_index]) > 1:
                        b_with = True
                        end_html.extend('<b>')
                    for i, td_index_index in enumerate(matched_index[td_index]):
                        content = ocr_contents[td_index_index][0]
                        if len(matched_index[td_index]) > 1:
                            if len(content) == 0:
                                continue
                            if content[0] == ' ':
                                content = content[1:]
                            if '<b>' in content:
                                content = content[3:]
                            if '</b>' in content:
                                content = content[:-4]
                            if len(content) == 0:
                                continue
                            if i != len(matched_index[td_index]) - 1 and ' ' != content[-1]:
                                content += ' '
                        end_html.extend(content)
                    if b_with:
                        end_html.extend('</b>')

                end_html.append(tag)
                td_index += 1
            else:
                end_html.append(tag)
        return ''.join(end_html), end_html


def sorted_boxes(dt_boxes):
    """
    Sort text boxes in order from top to bottom, left to right
    args:
        dt_boxes(array): detected text boxes with shape [4, 2]
    return:
        sorted boxes(array) with shape [4, 2]
    """
    num_boxes = dt_boxes.shape[0]
    sorted_boxes = sorted(dt_boxes, key=lambda x: (x[0][1], x[0][0]))
    _boxes = list(sorted_boxes)

    for i in range(num_boxes - 1):
        if abs(_boxes[i + 1][0][1] - _boxes[i][0][1]) < 10 and \
                (_boxes[i + 1][0][0] < _boxes[i][0][0]):
            tmp = _boxes[i]
            _boxes[i] = _boxes[i + 1]
            _boxes[i + 1] = tmp
    return _boxes
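
# Illustrative behavior (hypothetical boxes): two boxes whose top-left corners
# are (50, 10) and (5, 12) start out sorted by y, but because their vertical
# gap is under 10 pixels the left-most box, (5, 12), is swapped in front.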


def to_excel(html_table, excel_path):
    from tablepyxl import tablepyxl
    tablepyxl.document_to_xl(html_table, excel_path)


def main(args):
    image_file_list = get_image_file_list(args.image_dir)
    image_file_list = image_file_list[args.process_id::args.total_process_num]
    os.makedirs(args.output, exist_ok=True)

    text_sys = TableSystem(args)
    img_num = len(image_file_list)
    for i, image_file in enumerate(image_file_list):
        logger.info("[{}/{}] {}".format(i, img_num, image_file))
        img, flag = check_and_read_gif(image_file)
        excel_path = os.path.join(args.table_output, os.path.basename(image_file).split('.')[0] + '.xlsx')
        if not flag:
            img = cv2.imread(image_file)
        if img is None:
            logger.error("error in loading image:{}".format(image_file))
            continue
        starttime = time.time()
        pred_html = text_sys(img)

        to_excel(pred_html, excel_path)
        logger.info('excel saved to {}'.format(excel_path))
        logger.info(pred_html)
        elapse = time.time() - starttime
        logger.info("Predict time : {:.3f}s".format(elapse))


if __name__ == "__main__":
    args = parse_args()
    if args.use_mp:
        p_list = []
        total_process_num = args.total_process_num
        for process_id in range(total_process_num):
            cmd = [sys.executable, "-u"] + sys.argv + [
                "--process_id={}".format(process_id),
                "--use_mp={}".format(False)
            ]
            p = subprocess.Popen(cmd, stdout=sys.stdout, stderr=sys.stdout)
            p_list.append(p)
        for p in p_list:
            p.wait()
    else:
        main(args)
@@ -0,0 +1,16 @@
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

__all__ = ['TEDS']
from .table_metric import TEDS
@@ -0,0 +1,51 @@
from tqdm import tqdm
from concurrent.futures import ProcessPoolExecutor, as_completed


def parallel_process(array, function, n_jobs=16, use_kwargs=False, front_num=0):
    """
    A parallel version of the map function with a progress bar.
    Args:
        array (array-like): An array to iterate over.
        function (function): A python function to apply to the elements of array
        n_jobs (int, default=16): The number of cores to use
        use_kwargs (boolean, default=False): Whether to consider the elements of array as dictionaries of
            keyword arguments to function
        front_num (int, default=0): The number of iterations to run serially before kicking off the parallel job.
            Useful for catching bugs
    Returns:
        [function(array[0]), function(array[1]), ...]
    """
    # We run the first few iterations serially to catch bugs
    if front_num > 0:
        front = [function(**a) if use_kwargs else function(a)
                 for a in array[:front_num]]
    else:
        front = []
    # If we set n_jobs to 1, just run a list comprehension. This is useful for benchmarking and debugging.
    if n_jobs == 1:
        return front + [function(**a) if use_kwargs else function(a) for a in tqdm(array[front_num:])]
    # Assemble the workers
    with ProcessPoolExecutor(max_workers=n_jobs) as pool:
        # Pass the elements of array into function
        if use_kwargs:
            futures = [pool.submit(function, **a) for a in array[front_num:]]
        else:
            futures = [pool.submit(function, a) for a in array[front_num:]]
        kwargs = {
            'total': len(futures),
            'unit': 'it',
            'unit_scale': True,
            'leave': True
        }
        # Print out the progress as tasks complete
        for f in tqdm(as_completed(futures), **kwargs):
            pass
    out = []
    # Get the results from the futures.
    for i, future in tqdm(enumerate(futures)):
        try:
            out.append(future.result())
        except Exception as e:
            out.append(e)
    return front + out
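
# Minimal usage sketch:
#   squares = parallel_process([1, 2, 3, 4], lambda x: x * x, n_jobs=1)
#   # -> [1, 4, 9, 16]
# With n_jobs > 1 the work runs in a ProcessPoolExecutor, so `function` must
# be picklable (a module-level function rather than a lambda).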
@@ -0,0 +1,247 @@
# Copyright 2020 IBM
# Author: peter.zhong@au1.ibm.com
#
# This is free software; you can redistribute it and/or modify
# it under the terms of the Apache 2.0 License.
#
# This software is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# Apache 2.0 License for more details.

import distance
from apted import APTED, Config
from apted.helpers import Tree
from lxml import etree, html
from collections import deque
from .parallel import parallel_process
from tqdm import tqdm


class TableTree(Tree):
    def __init__(self, tag, colspan=None, rowspan=None, content=None, *children):
        self.tag = tag
        self.colspan = colspan
        self.rowspan = rowspan
        self.content = content
        self.children = list(children)

    def bracket(self):
        """Show tree using brackets notation"""
        if self.tag == 'td':
            result = '"tag": %s, "colspan": %d, "rowspan": %d, "text": %s' % \
                     (self.tag, self.colspan, self.rowspan, self.content)
        else:
            result = '"tag": %s' % self.tag
        for child in self.children:
            result += child.bracket()
        return "{{{}}}".format(result)


class CustomConfig(Config):
    @staticmethod
    def maximum(*sequences):
        """Get maximum possible value
        """
        return max(map(len, sequences))

    def normalized_distance(self, *sequences):
        """Get distance from 0 to 1
        """
        return float(distance.levenshtein(*sequences)) / self.maximum(*sequences)

    def rename(self, node1, node2):
        """Compares attributes of trees"""
        #print(node1.tag)
        if (node1.tag != node2.tag) or (node1.colspan != node2.colspan) or (node1.rowspan != node2.rowspan):
            return 1.
        if node1.tag == 'td':
            if node1.content or node2.content:
                #print(node1.content, )
                return self.normalized_distance(node1.content, node2.content)
        return 0.


class CustomConfig_del_short(Config):
    @staticmethod
    def maximum(*sequences):
        """Get maximum possible value
        """
        return max(map(len, sequences))

    def normalized_distance(self, *sequences):
        """Get distance from 0 to 1
        """
        return float(distance.levenshtein(*sequences)) / self.maximum(*sequences)

    def rename(self, node1, node2):
        """Compares attributes of trees"""
        if (node1.tag != node2.tag) or (node1.colspan != node2.colspan) or (node1.rowspan != node2.rowspan):
            return 1.
        if node1.tag == 'td':
            if node1.content or node2.content:
                #print('before')
                #print(node1.content, node2.content)
                #print('after')
                node1_content = node1.content
                node2_content = node2.content
                if len(node1_content) < 3:
                    node1_content = ['####']
                if len(node2_content) < 3:
                    node2_content = ['####']
                return self.normalized_distance(node1_content, node2_content)
        return 0.


class CustomConfig_del_block(Config):
    @staticmethod
    def maximum(*sequences):
        """Get maximum possible value
        """
        return max(map(len, sequences))

    def normalized_distance(self, *sequences):
        """Get distance from 0 to 1
        """
        return float(distance.levenshtein(*sequences)) / self.maximum(*sequences)

    def rename(self, node1, node2):
        """Compares attributes of trees"""
        if (node1.tag != node2.tag) or (node1.colspan != node2.colspan) or (node1.rowspan != node2.rowspan):
            return 1.
        if node1.tag == 'td':
            if node1.content or node2.content:

                node1_content = node1.content
                node2_content = node2.content
                while ' ' in node1_content:
                    print(node1_content.index(' '))
                    node1_content.pop(node1_content.index(' '))
                while ' ' in node2_content:
                    print(node2_content.index(' '))
                    node2_content.pop(node2_content.index(' '))
                return self.normalized_distance(node1_content, node2_content)
        return 0.


class TEDS(object):
    ''' Tree Edit Distance based Similarity
    '''

    def __init__(self, structure_only=False, n_jobs=1, ignore_nodes=None):
        assert isinstance(n_jobs, int) and (
            n_jobs >= 1), 'n_jobs must be an integer greater than or equal to 1'
        self.structure_only = structure_only
        self.n_jobs = n_jobs
        self.ignore_nodes = ignore_nodes
        self.__tokens__ = []

    def tokenize(self, node):
        ''' Tokenizes table cells
        '''
        self.__tokens__.append('<%s>' % node.tag)
        if node.text is not None:
            self.__tokens__ += list(node.text)
        for n in node.getchildren():
            self.tokenize(n)
        if node.tag != 'unk':
            self.__tokens__.append('</%s>' % node.tag)
        if node.tag != 'td' and node.tail is not None:
            self.__tokens__ += list(node.tail)

    def load_html_tree(self, node, parent=None):
        ''' Converts HTML tree to the format required by apted
        '''
        global __tokens__
        if node.tag == 'td':
            if self.structure_only:
                cell = []
            else:
                self.__tokens__ = []
                self.tokenize(node)
                cell = self.__tokens__[1:-1].copy()
            new_node = TableTree(node.tag,
                                 int(node.attrib.get('colspan', '1')),
                                 int(node.attrib.get('rowspan', '1')),
                                 cell, *deque())
        else:
            new_node = TableTree(node.tag, None, None, None, *deque())
        if parent is not None:
            parent.children.append(new_node)
        if node.tag != 'td':
            for n in node.getchildren():
                self.load_html_tree(n, new_node)
        if parent is None:
            return new_node

    def evaluate(self, pred, true):
        ''' Computes TEDS score between the prediction and the ground truth of a
        given sample
        '''
        if (not pred) or (not true):
            return 0.0
        parser = html.HTMLParser(remove_comments=True, encoding='utf-8')
        pred = html.fromstring(pred, parser=parser)
        true = html.fromstring(true, parser=parser)
        if pred.xpath('body/table') and true.xpath('body/table'):
            pred = pred.xpath('body/table')[0]
            true = true.xpath('body/table')[0]
            if self.ignore_nodes:
                etree.strip_tags(pred, *self.ignore_nodes)
                etree.strip_tags(true, *self.ignore_nodes)
            n_nodes_pred = len(pred.xpath(".//*"))
            n_nodes_true = len(true.xpath(".//*"))
            n_nodes = max(n_nodes_pred, n_nodes_true)
            tree_pred = self.load_html_tree(pred)
            tree_true = self.load_html_tree(true)
            distance = APTED(tree_pred, tree_true,
                             CustomConfig()).compute_edit_distance()
            return 1.0 - (float(distance) / n_nodes)
        else:
            return 0.0
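
    # Sanity check: evaluating a prediction against an identical ground truth
    # yields a perfect score, e.g.
    #   doc = '<html><body><table><tr><td>a</td></tr></table></body></html>'
    #   TEDS().evaluate(doc, doc)  # -> 1.0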

    def batch_evaluate(self, pred_json, true_json):
        ''' Computes TEDS score between the prediction and the ground truth of
        a batch of samples
        @params pred_json: {'FILENAME': 'HTML CODE', ...}
        @params true_json: {'FILENAME': {'html': 'HTML CODE'}, ...}
        @output: {'FILENAME': 'TEDS SCORE', ...}
        '''
        samples = true_json.keys()
        if self.n_jobs == 1:
            scores = [self.evaluate(pred_json.get(
                filename, ''), true_json[filename]['html']) for filename in tqdm(samples)]
        else:
            inputs = [{'pred': pred_json.get(
                filename, ''), 'true': true_json[filename]['html']} for filename in samples]
            scores = parallel_process(
                inputs, self.evaluate, use_kwargs=True, n_jobs=self.n_jobs, front_num=1)
        scores = dict(zip(samples, scores))
        return scores

    def batch_evaluate_html(self, pred_htmls, true_htmls):
        ''' Computes TEDS score between the prediction and the ground truth of
        a batch of samples
        '''
        if self.n_jobs == 1:
            scores = [self.evaluate(pred_html, true_html) for (
                pred_html, true_html) in zip(pred_htmls, true_htmls)]
        else:
            inputs = [{"pred": pred_html, "true": true_html} for (
                pred_html, true_html) in zip(pred_htmls, true_htmls)]

            scores = parallel_process(
                inputs, self.evaluate, use_kwargs=True, n_jobs=self.n_jobs, front_num=1)
        return scores


if __name__ == '__main__':
    import json
    import pprint
    with open('sample_pred.json') as fp:
        pred_json = json.load(fp)
    with open('sample_gt.json') as fp:
        true_json = json.load(fp)
    teds = TEDS(n_jobs=4)
    scores = teds.batch_evaluate(pred_json, true_json)
    pp = pprint.PrettyPrinter()
    pp.pprint(scores)
@@ -0,0 +1,13 @@
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
@@ -0,0 +1,283 @@
# This is where we handle translating css styles into openpyxl styles
# and cascading those from parent to child in the dom.

from openpyxl.cell import cell
from openpyxl.styles import Font, Alignment, PatternFill, NamedStyle, Border, Side, Color
from openpyxl.styles.fills import FILL_SOLID
from openpyxl.styles.numbers import FORMAT_CURRENCY_USD_SIMPLE, FORMAT_PERCENTAGE
from openpyxl.styles.colors import BLACK

FORMAT_DATE_MMDDYYYY = 'mm/dd/yyyy'


def colormap(color):
    """
    Convenience for looking up known colors
    """
    cmap = {'black': BLACK}
    return cmap.get(color, color)


def style_string_to_dict(style):
    """
    Convert css style string to a python dictionary
    """
    def clean_split(string, delim):
        return (s.strip() for s in string.split(delim))
    styles = [clean_split(s, ":") for s in style.split(";") if ":" in s]
    return dict(styles)
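
# For example, style_string_to_dict("color: #ff0000; font-weight: bold")
# returns {'color': '#ff0000', 'font-weight': 'bold'}.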


def get_side(style, name):
    return {'border_style': style.get('border-{}-style'.format(name)),
            'color': colormap(style.get('border-{}-color'.format(name)))}


known_styles = {}


def style_dict_to_named_style(style_dict, number_format=None):
    """
    Change css style (stored in a python dictionary) to openpyxl NamedStyle
    """

    style_and_format_string = str({
        'style_dict': style_dict,
        'parent': style_dict.parent,
        'number_format': number_format,
    })

    if style_and_format_string not in known_styles:
        # Font
        font = Font(bold=style_dict.get('font-weight') == 'bold',
                    color=style_dict.get_color('color', None),
                    size=style_dict.get('font-size'))

        # Alignment
        alignment = Alignment(horizontal=style_dict.get('text-align', 'general'),
                              vertical=style_dict.get('vertical-align'),
                              wrap_text=style_dict.get('white-space', 'nowrap') == 'normal')

        # Fill
        bg_color = style_dict.get_color('background-color')
        fg_color = style_dict.get_color('foreground-color', Color())
        fill_type = style_dict.get('fill-type')
        if bg_color and bg_color != 'transparent':
            fill = PatternFill(fill_type=fill_type or FILL_SOLID,
                               start_color=bg_color,
                               end_color=fg_color)
        else:
            fill = PatternFill()

        # Border
        border = Border(left=Side(**get_side(style_dict, 'left')),
                        right=Side(**get_side(style_dict, 'right')),
                        top=Side(**get_side(style_dict, 'top')),
                        bottom=Side(**get_side(style_dict, 'bottom')),
                        diagonal=Side(**get_side(style_dict, 'diagonal')),
                        diagonal_direction=None,
                        outline=Side(**get_side(style_dict, 'outline')),
                        vertical=None,
                        horizontal=None)

        name = 'Style {}'.format(len(known_styles) + 1)

        pyxl_style = NamedStyle(name=name, font=font, fill=fill, alignment=alignment, border=border,
                                number_format=number_format)

        known_styles[style_and_format_string] = pyxl_style

    return known_styles[style_and_format_string]


class StyleDict(dict):
    """
    It's like a dictionary, but it looks for items in the parent dictionary
    """
    def __init__(self, *args, **kwargs):
        self.parent = kwargs.pop('parent', None)
        super(StyleDict, self).__init__(*args, **kwargs)

    def __getitem__(self, item):
        if item in self:
            return super(StyleDict, self).__getitem__(item)
        elif self.parent:
            return self.parent[item]
        else:
            raise KeyError('{} not found'.format(item))

    def __hash__(self):
        return hash(tuple([(k, self.get(k)) for k in self._keys()]))

    # Yielding the keys avoids creating unnecessary data structures
    # and happily works with both python2 and python3 where the
    # .keys() method is a dictionary_view in python3 and a list in python2.
    def _keys(self):
        yielded = set()
        for k in self.keys():
            yielded.add(k)
            yield k
        if self.parent:
            for k in self.parent._keys():
                if k not in yielded:
                    yielded.add(k)
                    yield k

    def get(self, k, d=None):
        try:
            return self[k]
        except KeyError:
            return d

    def get_color(self, k, d=None):
        """
        Strip the leading # off colors if necessary
        """
        color = self.get(k, d)
        if hasattr(color, 'startswith') and color.startswith('#'):
            color = color[1:]
            if len(color) == 3:  # Premailer reduces colors like #00ff00 to #0f0; openpyxl doesn't like that
                color = ''.join(2 * c for c in color)
        return color
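
    # For example, a StyleDict holding {'color': '#0f0'} returns
    # get_color('color') == '00ff00': the '#' is stripped and the 3-digit
    # shorthand is expanded.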


class Element(object):
    """
    Our base class for representing an html element along with a cascading style.
    The element is created along with a parent so that the StyleDict that we store
    can point to the parent's StyleDict.
    """
    def __init__(self, element, parent=None):
        self.element = element
        self.number_format = None
        parent_style = parent.style_dict if parent else None
        self.style_dict = StyleDict(style_string_to_dict(element.get('style', '')), parent=parent_style)
        self._style_cache = None

    def style(self):
        """
        Turn the css styles for this element into an openpyxl NamedStyle.
        """
        if not self._style_cache:
            self._style_cache = style_dict_to_named_style(self.style_dict, number_format=self.number_format)
        return self._style_cache

    def get_dimension(self, dimension_key):
        """
        Extracts the dimension from the style dict of the Element and returns it as a float.
        """
        dimension = self.style_dict.get(dimension_key)
        if dimension:
            if dimension[-2:] in ['px', 'em', 'pt', 'in', 'cm']:
                dimension = dimension[:-2]
            dimension = float(dimension)
        return dimension


class Table(Element):
    """
    The concrete implementations of Elements are semantically named for the types of elements we are interested in.
    This defines a very concrete tree structure for html tables that we expect to deal with. I prefer this compared to
    allowing Element to have an arbitrary number of children and dealing with an abstract element tree.
    """
    def __init__(self, table):
        """
        takes an html table object (from lxml)
        """
        super(Table, self).__init__(table)
        table_head = table.find('thead')
        self.head = TableHead(table_head, parent=self) if table_head is not None else None
        table_body = table.find('tbody')
        self.body = TableBody(table_body if table_body is not None else table, parent=self)


class TableHead(Element):
    """
    This class maps to the `<thead>` element of the html table.
    """
    def __init__(self, head, parent=None):
        super(TableHead, self).__init__(head, parent=parent)
        self.rows = [TableRow(tr, parent=self) for tr in head.findall('tr')]


class TableBody(Element):
    """
    This class maps to the `<tbody>` element of the html table.
    """
    def __init__(self, body, parent=None):
        super(TableBody, self).__init__(body, parent=parent)
        self.rows = [TableRow(tr, parent=self) for tr in body.findall('tr')]


class TableRow(Element):
    """
    This class maps to the `<tr>` element of the html table.
    """
    def __init__(self, tr, parent=None):
        super(TableRow, self).__init__(tr, parent=parent)
        self.cells = [TableCell(cell, parent=self) for cell in tr.findall('th') + tr.findall('td')]


def element_to_string(el):
    return _element_to_string(el).strip()


def _element_to_string(el):
    string = ''

    for x in el.iterchildren():
        string += '\n' + _element_to_string(x)

    text = el.text.strip() if el.text else ''
    tail = el.tail.strip() if el.tail else ''

    return text + string + '\n' + tail


class TableCell(Element):
    """
    This class maps to the `<td>` element of the html table.
    """
    CELL_TYPES = {'TYPE_STRING', 'TYPE_FORMULA', 'TYPE_NUMERIC', 'TYPE_BOOL', 'TYPE_CURRENCY', 'TYPE_PERCENTAGE',
                  'TYPE_NULL', 'TYPE_INLINE', 'TYPE_ERROR', 'TYPE_FORMULA_CACHE_STRING', 'TYPE_INTEGER'}

    def __init__(self, cell, parent=None):
        super(TableCell, self).__init__(cell, parent=parent)
        self.value = element_to_string(cell)
        self.number_format = self.get_number_format()

    def data_type(self):
        cell_types = self.CELL_TYPES & set(self.element.get('class', '').split())
        if cell_types:
            if 'TYPE_FORMULA' in cell_types:
                # Make sure TYPE_FORMULA takes precedence over the other classes in the set.
                cell_type = 'TYPE_FORMULA'
            elif cell_types & {'TYPE_CURRENCY', 'TYPE_INTEGER', 'TYPE_PERCENTAGE'}:
                cell_type = 'TYPE_NUMERIC'
            else:
                cell_type = cell_types.pop()
        else:
            cell_type = 'TYPE_STRING'
        return getattr(cell, cell_type)

    def get_number_format(self):
        if 'TYPE_CURRENCY' in self.element.get('class', '').split():
            return FORMAT_CURRENCY_USD_SIMPLE
        if 'TYPE_INTEGER' in self.element.get('class', '').split():
            return '#,##0'
        if 'TYPE_PERCENTAGE' in self.element.get('class', '').split():
            return FORMAT_PERCENTAGE
        if 'TYPE_DATE' in self.element.get('class', '').split():
            return FORMAT_DATE_MMDDYYYY
        if self.data_type() == cell.TYPE_NUMERIC:
            try:
                int(self.value)
            except ValueError:
                return '#,##0.##'
            else:
                return '#,##0'

    def format(self, cell):
        cell.style = self.style()
        data_type = self.data_type()
        if data_type:
            cell.data_type = data_type
@@ -0,0 +1,118 @@
# Do imports like python3 so our package works for 2 and 3
from __future__ import absolute_import

from lxml import html
from openpyxl import Workbook
from openpyxl.utils import get_column_letter
from premailer import Premailer
from tablepyxl.style import Table


def string_to_int(s):
    if s.isdigit():
        return int(s)
    return 0


def get_Tables(doc):
    tree = html.fromstring(doc)
    comments = tree.xpath('//comment()')
    for comment in comments:
        comment.drop_tag()
    return [Table(table) for table in tree.xpath('//table')]


def write_rows(worksheet, elem, row, column=1):
    """
    Writes every tr child element of elem to a row in the worksheet
    returns the next row after all rows are written
    """
    from openpyxl.cell.cell import MergedCell

    initial_column = column
    for table_row in elem.rows:
        for table_cell in table_row.cells:
            cell = worksheet.cell(row=row, column=column)
            while isinstance(cell, MergedCell):
                column += 1
                cell = worksheet.cell(row=row, column=column)

            colspan = string_to_int(table_cell.element.get("colspan", "1"))
            rowspan = string_to_int(table_cell.element.get("rowspan", "1"))
            if rowspan > 1 or colspan > 1:
                worksheet.merge_cells(start_row=row, start_column=column,
                                      end_row=row + rowspan - 1, end_column=column + colspan - 1)

            cell.value = table_cell.value
            table_cell.format(cell)
            min_width = table_cell.get_dimension('min-width')
            max_width = table_cell.get_dimension('max-width')

            if colspan == 1:
                # Initially, when iterating for the first time through the loop, the width of all the cells is None.
                # As we start filling in contents, the initial width of the cell (which can be retrieved by:
                # worksheet.column_dimensions[get_column_letter(column)].width) is equal to the width of the previous
                # cell in the same column (i.e. width of A2 = width of A1)
                width = max(worksheet.column_dimensions[get_column_letter(column)].width or 0, len(table_cell.value) + 2)
                if max_width and width > max_width:
                    width = max_width
                elif min_width and width < min_width:
                    width = min_width
                worksheet.column_dimensions[get_column_letter(column)].width = width
            column += colspan
        row += 1
        column = initial_column
    return row


def table_to_sheet(table, wb):
    """
    Takes a table and workbook and writes the table to a new sheet.
    The sheet title will be the same as the table attribute name.
    """
    ws = wb.create_sheet(title=table.element.get('name'))
    insert_table(table, ws, 1, 1)


def document_to_workbook(doc, wb=None, base_url=None):
    """
    Takes a string representation of an html document and writes one sheet for
    every table in the document.
    The workbook is returned
    """
    if not wb:
        wb = Workbook()
        wb.remove(wb.active)

    inline_styles_doc = Premailer(doc, base_url=base_url, remove_classes=False).transform()
    tables = get_Tables(inline_styles_doc)

    for table in tables:
        table_to_sheet(table, wb)

    return wb


def document_to_xl(doc, filename, base_url=None):
    """
    Takes a string representation of an html document and writes one sheet for
    every table in the document. The workbook is written out to a file called filename
    """
    wb = document_to_workbook(doc, base_url=base_url)
    wb.save(filename)
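
# Minimal usage sketch (hypothetical file name):
#   html_doc = "<table name='demo'><tr><td>a</td><td>b</td></tr></table>"
#   document_to_xl(html_doc, "demo.xlsx")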


def insert_table(table, worksheet, column, row):
    if table.head:
        row = write_rows(worksheet, table.head, row, column)
    if table.body:
        row = write_rows(worksheet, table.body, row, column)


def insert_table_at_cell(table, cell):
    """
    Inserts a table at the location of an openpyxl Cell object.
    """
    ws = cell.parent
    column, row = cell.column, cell.row
    insert_table(table, ws, column, row)
@ -0,0 +1,59 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from PIL import Image
import numpy as np
from tools.infer.utility import draw_ocr_box_txt, init_args as infer_args


def init_args():
    parser = infer_args()

    # params for output
    parser.add_argument("--output", type=str, default='./output/table')
    # params for table structure
    parser.add_argument("--structure_max_len", type=int, default=488)
    parser.add_argument("--structure_max_text_length", type=int, default=100)
    parser.add_argument("--structure_max_elem_length", type=int, default=800)
    parser.add_argument("--structure_max_cell_num", type=int, default=500)
    parser.add_argument("--structure_model_dir", type=str)
    parser.add_argument("--structure_char_type", type=str, default='en')
    parser.add_argument(
        "--structure_char_dict_path",
        type=str,
        default="../ppocr/utils/dict/table_structure_dict.txt")

    # params for layout detector
    parser.add_argument("--layout_model_dir", type=str)
    return parser


def parse_args():
    parser = init_args()
    return parser.parse_args()


def draw_result(image, result, font_path):
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)
    boxes, txts, scores = [], [], []
    for region in result:
        if region['type'] == 'Table':
            pass
        elif region['type'] == 'Figure':
            pass
        else:
            for box, rec_res in zip(region['res'][0], region['res'][1]):
                boxes.append(np.array(box).reshape(-1, 2))
                txts.append(rec_res[0])
                scores.append(rec_res[1])
    im_show = draw_ocr_box_txt(
        image, boxes, txts, scores, font_path=font_path, drop_score=0)
    return im_show
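A hedged usage sketch of draw_result (not part of this commit). The result layout is inferred from the loop above — each non-Table, non-Figure region carries parallel lists of boxes and (text, score) pairs under 'res' — and the font path is illustrative:

    import cv2
    import numpy as np

    page = np.full((200, 400, 3), 255, dtype=np.uint8)  # blank white page
    result = [{
        'type': 'Text',
        'res': (
            [[[10, 10], [160, 10], [160, 40], [10, 40]]],  # one quad box
            [("hello", 0.99)],                             # one (text, score)
        ),
    }]
    # im_show = draw_result(page, result, font_path="doc/fonts/simfang.ttf")
    # cv2.imwrite("show.jpg", np.array(im_show)[:, :, ::-1])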
@@ -0,0 +1,232 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import os
import time
import logging

import paddle
import paddle.inference as paddle_infer

from pathlib import Path

CUR_DIR = os.path.dirname(os.path.abspath(__file__))


class PaddleInferBenchmark(object):
    def __init__(self,
                 config,
                 model_info: dict={},
                 data_info: dict={},
                 perf_info: dict={},
                 resource_info: dict={},
                 save_log_path: str="",
                 **kwargs):
        """
        Construct PaddleInferBenchmark class to format logs.
        args:
            config(paddle.inference.Config): paddle inference config
            model_info(dict): basic model info
                {'model_name': 'resnet50',
                 'precision': 'fp32'}
            data_info(dict): input data info
                {'batch_size': 1,
                 'shape': '3,224,224',
                 'data_num': 1000}
            perf_info(dict): performance result
                {'preprocess_time_s': 1.0,
                 'inference_time_s': 2.0,
                 'postprocess_time_s': 1.0,
                 'total_time_s': 4.0}
            resource_info(dict): cpu and gpu resources
                {'cpu_rss': 100,
                 'gpu_rss': 100,
                 'gpu_util': 60}
        """
        # PaddleInferBenchmark log version
        self.log_version = 1.0

        # Paddle version
        self.paddle_version = paddle.__version__
        self.paddle_commit = paddle.__git_commit__
        paddle_infer_info = paddle_infer.get_version()
        self.paddle_branch = paddle_infer_info.strip().split(': ')[-1]

        # model info
        self.model_info = model_info

        # data info
        self.data_info = data_info

        # perf info
        self.perf_info = perf_info

        try:
            self.model_name = model_info['model_name']
            self.precision = model_info['precision']

            self.batch_size = data_info['batch_size']
            self.shape = data_info['shape']
            self.data_num = data_info['data_num']

            self.preprocess_time_s = round(perf_info['preprocess_time_s'], 4)
            self.inference_time_s = round(perf_info['inference_time_s'], 4)
            self.postprocess_time_s = round(perf_info['postprocess_time_s'], 4)
            self.total_time_s = round(perf_info['total_time_s'], 4)
        except (KeyError, TypeError):
            self.print_help()
            raise ValueError(
                "Wrong arguments, please check the input arguments and their types")

        # config info
        self.config_status = self.parse_config(config)
        self.save_log_path = save_log_path
        # mem info
        if isinstance(resource_info, dict):
            self.cpu_rss_mb = int(resource_info.get('cpu_rss_mb', 0))
            self.gpu_rss_mb = int(resource_info.get('gpu_rss_mb', 0))
            self.gpu_util = round(resource_info.get('gpu_util', 0), 2)
        else:
            self.cpu_rss_mb = 0
            self.gpu_rss_mb = 0
            self.gpu_util = 0

        # init benchmark logger
        self.benchmark_logger()

    def benchmark_logger(self):
        """
        benchmark logger
        """
        # Init logger
        FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
        log_output = f"{self.save_log_path}/{self.model_name}.log"
        Path(f"{self.save_log_path}").mkdir(parents=True, exist_ok=True)
        logging.basicConfig(
            level=logging.INFO,
            format=FORMAT,
            handlers=[
                logging.FileHandler(
                    filename=log_output, mode='w'),
                logging.StreamHandler(),
            ])
        self.logger = logging.getLogger(__name__)
        self.logger.info(
            f"Paddle Inference benchmark log will be saved to {log_output}")

    def parse_config(self, config) -> dict:
        """
        parse paddle predictor config
        args:
            config(paddle.inference.Config): paddle inference config
        return:
            config_status(dict): dict style config info
        """
        config_status = {}
        config_status['runtime_device'] = "gpu" if config.use_gpu() else "cpu"
        config_status['ir_optim'] = config.ir_optim()
        config_status['enable_tensorrt'] = config.tensorrt_engine_enabled()
        config_status['precision'] = self.precision
        config_status['enable_mkldnn'] = config.mkldnn_enabled()
        config_status[
            'cpu_math_library_num_threads'] = config.cpu_math_library_num_threads()
        return config_status

    def report(self, identifier=None):
        """
        print log report
        args:
            identifier(string): identify log
        """
        if identifier:
            identifier = f"[{identifier}]"
        else:
            identifier = ""

        self.logger.info("\n")
        self.logger.info(
            "---------------------- Paddle info ----------------------")
        self.logger.info(f"{identifier} paddle_version: {self.paddle_version}")
        self.logger.info(f"{identifier} paddle_commit: {self.paddle_commit}")
        self.logger.info(f"{identifier} paddle_branch: {self.paddle_branch}")
        self.logger.info(f"{identifier} log_api_version: {self.log_version}")
        self.logger.info(
            "----------------------- Conf info -----------------------")
        self.logger.info(
            f"{identifier} runtime_device: {self.config_status['runtime_device']}")
        self.logger.info(
            f"{identifier} ir_optim: {self.config_status['ir_optim']}")
        self.logger.info(f"{identifier} enable_memory_optim: {True}")
        self.logger.info(
            f"{identifier} enable_tensorrt: {self.config_status['enable_tensorrt']}")
        self.logger.info(
            f"{identifier} enable_mkldnn: {self.config_status['enable_mkldnn']}")
        self.logger.info(
            f"{identifier} cpu_math_library_num_threads: {self.config_status['cpu_math_library_num_threads']}")
        self.logger.info(
            "----------------------- Model info ----------------------")
        self.logger.info(f"{identifier} model_name: {self.model_name}")
        self.logger.info(f"{identifier} precision: {self.precision}")
        self.logger.info(
            "----------------------- Data info -----------------------")
        self.logger.info(f"{identifier} batch_size: {self.batch_size}")
        self.logger.info(f"{identifier} input_shape: {self.shape}")
        self.logger.info(f"{identifier} data_num: {self.data_num}")
        self.logger.info(
            "----------------------- Perf info -----------------------")
        self.logger.info(
            f"{identifier} cpu_rss(MB): {self.cpu_rss_mb}, gpu_rss(MB): {self.gpu_rss_mb}, gpu_util: {self.gpu_util}%")
        self.logger.info(
            f"{identifier} total time spent(s): {self.total_time_s}")
        self.logger.info(
            f"{identifier} preprocess_time(ms): {round(self.preprocess_time_s*1000, 1)}, inference_time(ms): {round(self.inference_time_s*1000, 1)}, postprocess_time(ms): {round(self.postprocess_time_s*1000, 1)}")

    def print_help(self):
        """
        print function help
        """
        print("""Usage:
            ==== Print inference benchmark logs. ====
            config = paddle.inference.Config()
            model_info = {'model_name': 'resnet50',
                          'precision': 'fp32'}
            data_info = {'batch_size': 1,
                         'shape': '3,224,224',
                         'data_num': 1000}
            perf_info = {'preprocess_time_s': 1.0,
                         'inference_time_s': 2.0,
                         'postprocess_time_s': 1.0,
                         'total_time_s': 4.0}
            resource_info = {'cpu_rss_mb': 100,
                             'gpu_rss_mb': 100,
                             'gpu_util': 60}
            log = PaddleInferBenchmark(config, model_info, data_info, perf_info, resource_info)
            log('Test')
            """)

    def __call__(self, identifier=None):
        """
        __call__
        args:
            identifier(string): identify log
        """
        self.report(identifier)
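A runnable variant of the usage shown in print_help (assumes paddle is installed; every number below is made up for illustration):

    import paddle.inference as paddle_infer
    import tools.infer.benchmark_utils as benchmark_utils

    config = paddle_infer.Config()  # an empty config is enough for parse_config()
    model_info = {'model_name': 'det_mv3_db', 'precision': 'fp32'}
    data_info = {'batch_size': 1, 'shape': 'dynamic_shape', 'data_num': 10}
    perf_info = {'preprocess_time_s': 0.1, 'inference_time_s': 0.2,
                 'postprocess_time_s': 0.05, 'total_time_s': 0.35}
    resource_info = {'cpu_rss_mb': 800, 'gpu_rss_mb': 1200, 'gpu_util': 45}

    log = benchmark_utils.PaddleInferBenchmark(
        config, model_info, data_info, perf_info, resource_info)
    log("Det")  # __call__ forwards to report(); "Det" prefixes every log line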
@@ -45,9 +45,11 @@ class TextClassifier(object):
             "label_list": args.label_list,
         }
         self.postprocess_op = build_post_process(postprocess_params)
-        self.predictor, self.input_tensor, self.output_tensors = \
+        self.predictor, self.input_tensor, self.output_tensors, _ = \
             utility.create_predictor(args, 'cls', logger)
+
+        self.cls_times = utility.Timer()

     def resize_norm_img(self, img):
         imgC, imgH, imgW = self.cls_image_shape
         h = img.shape[0]
@@ -83,7 +85,9 @@ class TextClassifier(object):
         cls_res = [['', 0.0]] * img_num
         batch_num = self.cls_batch_num
         elapse = 0
+        self.cls_times.total_time.start()
         for beg_img_no in range(0, img_num, batch_num):
+
             end_img_no = min(img_num, beg_img_no + batch_num)
             norm_img_batch = []
             max_wh_ratio = 0
@@ -91,6 +95,7 @@ class TextClassifier(object):
                 h, w = img_list[indices[ino]].shape[0:2]
                 wh_ratio = w * 1.0 / h
                 max_wh_ratio = max(max_wh_ratio, wh_ratio)
+            self.cls_times.preprocess_time.start()
             for ino in range(beg_img_no, end_img_no):
                 norm_img = self.resize_norm_img(img_list[indices[ino]])
                 norm_img = norm_img[np.newaxis, :]
@@ -98,11 +103,17 @@ class TextClassifier(object):
             norm_img_batch = np.concatenate(norm_img_batch)
             norm_img_batch = norm_img_batch.copy()
             starttime = time.time()
+            self.cls_times.preprocess_time.end()
+            self.cls_times.inference_time.start()
+
             self.input_tensor.copy_from_cpu(norm_img_batch)
             self.predictor.run()
             prob_out = self.output_tensors[0].copy_to_cpu()
+            self.cls_times.inference_time.end()
+            self.cls_times.postprocess_time.start()
             self.predictor.try_shrink_memory()
             cls_result = self.postprocess_op(prob_out)
+            self.cls_times.postprocess_time.end()
             elapse += time.time() - starttime
             for rno in range(len(cls_result)):
                 label, score = cls_result[rno]
@@ -110,6 +121,9 @@ class TextClassifier(object):
                 if '180' in label and score > self.cls_thresh:
                     img_list[indices[beg_img_no + rno]] = cv2.rotate(
                         img_list[indices[beg_img_no + rno]], 1)
+        self.cls_times.total_time.end()
+        self.cls_times.img_num += img_num
+        elapse = self.cls_times.total_time.value()
         return img_list, cls_res, elapse

@@ -141,8 +155,9 @@ def main(args):
     for ino in range(len(img_list)):
         logger.info("Predicts of {}:{}".format(valid_image_file_list[ino],
                                                cls_res[ino]))
-    logger.info("Total predict time for {} images, cost: {:.3f}".format(
-        len(img_list), predict_time))
+    logger.info(
+        "The prediction time of the text angle classification module is as follows:")
+    text_classifier.cls_times.info(average=False)


 if __name__ == "__main__":
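The instrumentation pattern above repeats in the detector and recognizer hunks below: every stage is bracketed by start()/end() on a utility.Timer (the class is added in the utility.py hunks at the end of this diff). A compressed sketch of the pattern, with sleeps standing in for real work (assumes the PaddleOCR tools package is importable):

    import time
    from tools.infer import utility

    t = utility.Timer()
    t.total_time.start()
    for _ in range(3):              # one iteration per batch
        t.preprocess_time.start()
        time.sleep(0.01)            # stand-in for resize/normalize
        t.preprocess_time.end()
        t.inference_time.start()
        time.sleep(0.02)            # stand-in for predictor.run()
        t.inference_time.end()
        t.postprocess_time.start()
        time.sleep(0.005)           # stand-in for decoding
        t.postprocess_time.end()
        t.img_num += 1
    t.total_time.end()
    print(t.report(average=True))   # per-image stage times plus totals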
@@ -31,6 +31,8 @@ from ppocr.utils.utility import get_image_file_list, check_and_read_gif
 from ppocr.data import create_operators, transform
 from ppocr.postprocess import build_post_process
+
+import tools.infer.benchmark_utils as benchmark_utils

 logger = get_logger()

@@ -41,7 +43,7 @@ class TextDetector(object):
         pre_process_list = [{
             'DetResizeForTest': {
                 'limit_side_len': args.det_limit_side_len,
-                'limit_type': args.det_limit_type
+                'limit_type': args.det_limit_type,
             }
         }, {
             'NormalizeImage': {
@@ -95,9 +97,10 @@ class TextDetector(object):

         self.preprocess_op = create_operators(pre_process_list)
         self.postprocess_op = build_post_process(postprocess_params)
-        self.predictor, self.input_tensor, self.output_tensors = utility.create_predictor(
-            args, 'det', logger)  # paddle.jit.load(args.det_model_dir)
-        # self.predictor.eval()
+        self.predictor, self.input_tensor, self.output_tensors, self.config = utility.create_predictor(
+            args, 'det', logger)
+
+        self.det_times = utility.Timer()

     def order_points_clockwise(self, pts):
         """
@@ -155,6 +158,8 @@ class TextDetector(object):
     def __call__(self, img):
         ori_im = img.copy()
         data = {'image': img}
+        self.det_times.total_time.start()
+        self.det_times.preprocess_time.start()
         data = transform(data, self.preprocess_op)
         img, shape_list = data
         if img is None:
@@ -162,7 +167,9 @@ class TextDetector(object):
         img = np.expand_dims(img, axis=0)
         shape_list = np.expand_dims(shape_list, axis=0)
         img = img.copy()
-        starttime = time.time()
+
+        self.det_times.preprocess_time.end()
+        self.det_times.inference_time.start()

         self.input_tensor.copy_from_cpu(img)
         self.predictor.run()
@@ -170,6 +177,7 @@ class TextDetector(object):
         for output_tensor in self.output_tensors:
             output = output_tensor.copy_to_cpu()
             outputs.append(output)
+        self.det_times.inference_time.end()

         preds = {}
         if self.det_algorithm == "EAST":
@@ -184,6 +192,9 @@ class TextDetector(object):
             preds['maps'] = outputs[0]
         else:
             raise NotImplementedError
+
+        self.det_times.postprocess_time.start()
+
         self.predictor.try_shrink_memory()
         post_result = self.postprocess_op(preds, shape_list)
         dt_boxes = post_result[0]['points']
@@ -191,8 +202,11 @@ class TextDetector(object):
             dt_boxes = self.filter_tag_det_res_only_clip(dt_boxes, ori_im.shape)
         else:
             dt_boxes = self.filter_tag_det_res(dt_boxes, ori_im.shape)
-        elapse = time.time() - starttime
-        return dt_boxes, elapse
+
+        self.det_times.postprocess_time.end()
+        self.det_times.total_time.end()
+        self.det_times.img_num += 1
+        return dt_boxes, self.det_times.total_time.value()


 if __name__ == "__main__":
@@ -202,6 +216,13 @@ if __name__ == "__main__":
    count = 0
    total_time = 0
    draw_img_save = "./inference_results"
+    cpu_mem, gpu_mem, gpu_util = 0, 0, 0
+
+    # warmup 10 times
+    fake_img = np.random.uniform(-1, 1, [640, 640, 3]).astype(np.float32)
+    for i in range(10):
+        dt_boxes, _ = text_detector(fake_img)
+
    if not os.path.exists(draw_img_save):
        os.makedirs(draw_img_save)
    for image_file in image_file_list:
@@ -211,16 +232,56 @@ if __name__ == "__main__":
        if img is None:
            logger.info("error in loading image:{}".format(image_file))
            continue
-        dt_boxes, elapse = text_detector(img)
+        st = time.time()
+        dt_boxes, _ = text_detector(img)
+        elapse = time.time() - st
        if count > 0:
            total_time += elapse
        count += 1
+
+        if args.benchmark:
+            cm, gm, gu = utility.get_current_memory_mb(0)
+            cpu_mem += cm
+            gpu_mem += gm
+            gpu_util += gu
+
        logger.info("Predict time of {}: {}".format(image_file, elapse))
        src_im = utility.draw_text_det_res(dt_boxes, image_file)
        img_name_pure = os.path.split(image_file)[-1]
        img_path = os.path.join(draw_img_save,
                                "det_res_{}".format(img_name_pure))
        cv2.imwrite(img_path, src_im)
        logger.info("The visualized image saved in {}".format(img_path))
-    if count > 1:
-        logger.info("Avg Time: {}".format(total_time / (count - 1)))
+
+    # print information about memory usage and time spent
+    if args.benchmark:
+        mems = {
+            'cpu_rss_mb': cpu_mem / count,
+            'gpu_rss_mb': gpu_mem / count,
+            'gpu_util': gpu_util * 100 / count
+        }
+    else:
+        mems = None
+    logger.info("The prediction time of the detection module is as follows:")
+    det_time_dict = text_detector.det_times.report(average=True)
+    det_model_name = args.det_model_dir
+
+    if args.benchmark:
+        # construct log information
+        model_info = {
+            'model_name': args.det_model_dir.split('/')[-1],
+            'precision': args.precision
+        }
+        data_info = {
+            'batch_size': 1,
+            'shape': 'dynamic_shape',
+            'data_num': det_time_dict['img_num']
+        }
+        perf_info = {
+            'preprocess_time_s': det_time_dict['preprocess_time'],
+            'inference_time_s': det_time_dict['inference_time'],
+            'postprocess_time_s': det_time_dict['postprocess_time'],
+            'total_time_s': det_time_dict['total_time']
+        }
+        benchmark_log = benchmark_utils.PaddleInferBenchmark(
+            text_detector.config, model_info, data_info, perf_info, mems)
+        benchmark_log("Det")
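The warmup loop added above exists because the first predictor runs pay one-off costs (memory allocation, kernel selection) that would skew the average. A toy illustration of the effect (assumption, not from this commit):

    import time

    def fake_infer(state):
        # the first call is slower, standing in for lazy initialization
        time.sleep(0.05 if not state['warm'] else 0.01)
        state['warm'] = True

    state = {'warm': False}
    for i in range(3):
        st = time.time()
        fake_infer(state)
        print("run {}: {:.3f}s".format(i, time.time() - st))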
@@ -28,6 +28,7 @@ import traceback
 import paddle

 import tools.infer.utility as utility
+import tools.infer.benchmark_utils as benchmark_utils
 from ppocr.postprocess import build_post_process
 from ppocr.utils.logging import get_logger
 from ppocr.utils.utility import get_image_file_list, check_and_read_gif
@@ -41,7 +42,6 @@ class TextRecognizer(object):
         self.character_type = args.rec_char_type
         self.rec_batch_num = args.rec_batch_num
         self.rec_algorithm = args.rec_algorithm
-        self.max_text_length = args.max_text_length
         postprocess_params = {
             'name': 'CTCLabelDecode',
             "character_type": args.rec_char_type,
@@ -63,9 +63,11 @@ class TextRecognizer(object):
             "use_space_char": args.use_space_char
         }
         self.postprocess_op = build_post_process(postprocess_params)
-        self.predictor, self.input_tensor, self.output_tensors = \
+        self.predictor, self.input_tensor, self.output_tensors, self.config = \
             utility.create_predictor(args, 'rec', logger)
+
+        self.rec_times = utility.Timer()

     def resize_norm_img(self, img, max_wh_ratio):
         imgC, imgH, imgW = self.rec_image_shape
         assert imgC == img.shape[2]
@@ -166,17 +168,15 @@ class TextRecognizer(object):
             width_list.append(img.shape[1] / float(img.shape[0]))
         # Sorting can speed up the recognition process
         indices = np.argsort(np.array(width_list))
-        # rec_res = []
+        self.rec_times.total_time.start()
         rec_res = [['', 0.0]] * img_num
         batch_num = self.rec_batch_num
-        elapse = 0
         for beg_img_no in range(0, img_num, batch_num):
             end_img_no = min(img_num, beg_img_no + batch_num)
             norm_img_batch = []
             max_wh_ratio = 0
+            self.rec_times.preprocess_time.start()
             for ino in range(beg_img_no, end_img_no):
-                # h, w = img_list[ino].shape[0:2]
                 h, w = img_list[indices[ino]].shape[0:2]
                 wh_ratio = w * 1.0 / h
                 max_wh_ratio = max(max_wh_ratio, wh_ratio)
@@ -187,9 +187,8 @@ class TextRecognizer(object):
                 norm_img = norm_img[np.newaxis, :]
                 norm_img_batch.append(norm_img)
             else:
-                norm_img = self.process_image_srn(img_list[indices[ino]],
-                                                  self.rec_image_shape, 8,
-                                                  self.max_text_length)
+                norm_img = self.process_image_srn(
+                    img_list[indices[ino]], self.rec_image_shape, 8, 25)
                 encoder_word_pos_list = []
                 gsrm_word_pos_list = []
                 gsrm_slf_attn_bias1_list = []
@@ -203,7 +202,6 @@ class TextRecognizer(object):
             norm_img_batch = norm_img_batch.copy()

             if self.rec_algorithm == "SRN":
-                starttime = time.time()
                 encoder_word_pos_list = np.concatenate(encoder_word_pos_list)
                 gsrm_word_pos_list = np.concatenate(gsrm_word_pos_list)
                 gsrm_slf_attn_bias1_list = np.concatenate(
@@ -218,19 +216,23 @@ class TextRecognizer(object):
                     gsrm_slf_attn_bias1_list,
                     gsrm_slf_attn_bias2_list,
                 ]
+                self.rec_times.preprocess_time.end()
+                self.rec_times.inference_time.start()
                 input_names = self.predictor.get_input_names()
                 for i in range(len(input_names)):
                     input_tensor = self.predictor.get_input_handle(input_names[
                         i])
                     input_tensor.copy_from_cpu(inputs[i])
                 self.predictor.run()
+                self.rec_times.inference_time.end()
                 outputs = []
                 for output_tensor in self.output_tensors:
                     output = output_tensor.copy_to_cpu()
                     outputs.append(output)
                 preds = {"predict": outputs[2]}
             else:
-                starttime = time.time()
+                self.rec_times.preprocess_time.end()
+                self.rec_times.inference_time.start()
                 self.input_tensor.copy_from_cpu(norm_img_batch)
                 self.predictor.run()
@@ -239,22 +241,31 @@ class TextRecognizer(object):
                     output = output_tensor.copy_to_cpu()
                     outputs.append(output)
                 preds = outputs[0]
-            self.predictor.try_shrink_memory()
+                self.rec_times.inference_time.end()
+            self.rec_times.postprocess_time.start()
             rec_result = self.postprocess_op(preds)
             for rno in range(len(rec_result)):
                 rec_res[indices[beg_img_no + rno]] = rec_result[rno]
-            elapse += time.time() - starttime
-        return rec_res, elapse
+            self.rec_times.postprocess_time.end()
+            self.rec_times.img_num += int(norm_img_batch.shape[0])
+        self.rec_times.total_time.end()
+        return rec_res, self.rec_times.total_time.value()


 def main(args):
     image_file_list = get_image_file_list(args.image_dir)
     text_recognizer = TextRecognizer(args)
-    total_run_time = 0.0
-    total_images_num = 0
     valid_image_file_list = []
     img_list = []
-    for idx, image_file in enumerate(image_file_list):
+    cpu_mem, gpu_mem, gpu_util = 0, 0, 0
+    count = 0
+
+    # warmup 10 times
+    fake_img = np.random.uniform(-1, 1, [1, 32, 320, 3]).astype(np.float32)
+    for i in range(10):
+        dt_boxes, _ = text_recognizer(fake_img)
+
+    for image_file in image_file_list:
         img, flag = check_and_read_gif(image_file)
         if not flag:
             img = cv2.imread(image_file)
@@ -263,29 +274,54 @@ def main(args):
             continue
         valid_image_file_list.append(image_file)
         img_list.append(img)
-        if len(img_list) >= args.rec_batch_num or idx == len(
-                image_file_list) - 1:
-            try:
-                rec_res, predict_time = text_recognizer(img_list)
-                total_run_time += predict_time
-            except:
-                logger.info(traceback.format_exc())
-                logger.info(
-                    "ERROR!!!! \n"
-                    "Please read the FAQ:https://github.com/PaddlePaddle/PaddleOCR#faq \n"
-                    "If your model has tps module: "
-                    "TPS does not support variable shape.\n"
-                    "Please set --rec_image_shape='3,32,100' and --rec_char_type='en' ")
-                exit()
-            for ino in range(len(img_list)):
-                logger.info("Predicts of {}:{}".format(valid_image_file_list[
-                    ino], rec_res[ino]))
-            total_images_num += len(valid_image_file_list)
-            valid_image_file_list = []
-            img_list = []
-    logger.info("Total predict time for {} images, cost: {:.3f}".format(
-        total_images_num, total_run_time))
+    try:
+        rec_res, _ = text_recognizer(img_list)
+        if args.benchmark:
+            cm, gm, gu = utility.get_current_memory_mb(0)
+            cpu_mem += cm
+            gpu_mem += gm
+            gpu_util += gu
+            count += 1
+
+    except Exception as E:
+        logger.info(traceback.format_exc())
+        logger.info(E)
+        exit()
+    for ino in range(len(img_list)):
+        logger.info("Predicts of {}:{}".format(valid_image_file_list[ino],
+                                               rec_res[ino]))
+    if args.benchmark:
+        mems = {
+            'cpu_rss_mb': cpu_mem / count,
+            'gpu_rss_mb': gpu_mem / count,
+            'gpu_util': gpu_util * 100 / count
+        }
+    else:
+        mems = None
+    logger.info("The prediction time of the recognition module is as follows:")
+    rec_time_dict = text_recognizer.rec_times.report(average=True)
+    rec_model_name = args.rec_model_dir
+
+    if args.benchmark:
+        # construct log information
+        model_info = {
+            'model_name': args.rec_model_dir.split('/')[-1],
+            'precision': args.precision
+        }
+        data_info = {
+            'batch_size': args.rec_batch_num,
+            'shape': 'dynamic_shape',
+            'data_num': rec_time_dict['img_num']
+        }
+        perf_info = {
+            'preprocess_time_s': rec_time_dict['preprocess_time'],
+            'inference_time_s': rec_time_dict['inference_time'],
+            'postprocess_time_s': rec_time_dict['postprocess_time'],
+            'total_time_s': rec_time_dict['total_time']
+        }
+        benchmark_log = benchmark_utils.PaddleInferBenchmark(
+            text_recognizer.config, model_info, data_info, perf_info, mems)
+        benchmark_log("Rec")


 if __name__ == "__main__":
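The "Sorting can speed up the recognition process" comment above is doing real work: crops are batched in ascending aspect-ratio order, so each batch only pads to the width of its widest member. A toy illustration (assumption, not from this commit):

    import numpy as np

    widths = np.array([300, 40, 280, 35, 120])  # crop widths at a fixed height
    indices = np.argsort(widths)                # narrow crops batch together
    results = [None] * len(widths)
    for batch in np.array_split(indices, 2):    # two batches of similar widths
        pad_to = widths[batch].max()            # per-batch padding target
        for i in batch:
            results[i] = "padded {} -> {}".format(widths[i], pad_to)
    print(results)  # restored to the original input order, as rec_res is above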
@@ -13,7 +13,6 @@
 # limitations under the License.
 import os
 import sys
-import subprocess

 __dir__ = os.path.dirname(os.path.abspath(__file__))
 sys.path.append(__dir__)
@@ -32,8 +31,8 @@ import tools.infer.predict_det as predict_det
 import tools.infer.predict_cls as predict_cls
 from ppocr.utils.utility import get_image_file_list, check_and_read_gif
 from ppocr.utils.logging import get_logger
-from tools.infer.utility import draw_ocr_box_txt
+from tools.infer.utility import draw_ocr_box_txt, get_current_memory_mb
+import tools.infer.benchmark_utils as benchmark_utils
 logger = get_logger()


@@ -88,7 +87,8 @@ class TextSystem(object):
     def __call__(self, img, cls=True):
         ori_im = img.copy()
         dt_boxes, elapse = self.text_detector(img)
-        logger.info("dt_boxes num : {}, elapse : {}".format(
+
+        logger.debug("dt_boxes num : {}, elapse : {}".format(
             len(dt_boxes), elapse))
         if dt_boxes is None:
             return None, None
@@ -103,11 +103,11 @@ class TextSystem(object):
         if self.use_angle_cls and cls:
             img_crop_list, angle_list, elapse = self.text_classifier(
                 img_crop_list)
-            logger.info("cls num : {}, elapse : {}".format(
+            logger.debug("cls num : {}, elapse : {}".format(
                 len(img_crop_list), elapse))

         rec_res, elapse = self.text_recognizer(img_crop_list)
-        logger.info("rec_res num : {}, elapse : {}".format(
+        logger.debug("rec_res num : {}, elapse : {}".format(
             len(rec_res), elapse))
         # self.print_draw_crop_rec_res(img_crop_list, rec_res)
         filter_boxes, filter_rec_res = [], []
@@ -142,23 +142,34 @@ def sorted_boxes(dt_boxes):

 def main(args):
     image_file_list = get_image_file_list(args.image_dir)
-    image_file_list = image_file_list[args.process_id::args.total_process_num]
     text_sys = TextSystem(args)
     is_visualize = True
     font_path = args.vis_font_path
     drop_score = args.drop_score
-    for image_file in image_file_list:
+    total_time = 0
+    cpu_mem, gpu_mem, gpu_util = 0, 0, 0
+    _st = time.time()
+    count = 0
+    for idx, image_file in enumerate(image_file_list):
         img, flag = check_and_read_gif(image_file)
         if not flag:
             img = cv2.imread(image_file)
         if img is None:
-            logger.info("error in loading image:{}".format(image_file))
+            logger.error("error in loading image:{}".format(image_file))
             continue
         starttime = time.time()
         dt_boxes, rec_res = text_sys(img)
         elapse = time.time() - starttime
-        logger.info("Predict time of %s: %.3fs" % (image_file, elapse))
+        total_time += elapse
+        if args.benchmark and idx % 20 == 0:
+            cm, gm, gu = get_current_memory_mb(0)
+            cpu_mem += cm
+            gpu_mem += gm
+            gpu_util += gu
+            count += 1
+
+        logger.info(
+            str(idx) + "  Predict time of %s: %.3fs" % (image_file, elapse))
         for text, score in rec_res:
             logger.info("{}, {:.3f}".format(text, score))

@@ -178,26 +189,74 @@ def main(args):
         draw_img_save = "./inference_results/"
         if not os.path.exists(draw_img_save):
             os.makedirs(draw_img_save)
+        if flag:
+            image_file = image_file[:-3] + "png"
         cv2.imwrite(
             os.path.join(draw_img_save, os.path.basename(image_file)),
             draw_img[:, :, ::-1])
         logger.info("The visualized image saved in {}".format(
             os.path.join(draw_img_save, os.path.basename(image_file))))

+    logger.info("The predict total time is {}".format(time.time() - _st))
+    logger.info("\nThe predict total time is {}".format(total_time))
+
+    img_num = text_sys.text_detector.det_times.img_num
+    if args.benchmark:
+        mems = {
+            'cpu_rss_mb': cpu_mem / count,
+            'gpu_rss_mb': gpu_mem / count,
+            'gpu_util': gpu_util * 100 / count
+        }
+    else:
+        mems = None
+    det_time_dict = text_sys.text_detector.det_times.report(average=True)
+    rec_time_dict = text_sys.text_recognizer.rec_times.report(average=True)
+    det_model_name = args.det_model_dir
+    rec_model_name = args.rec_model_dir
+
+    # construct det log information
+    model_info = {
+        'model_name': args.det_model_dir.split('/')[-1],
+        'precision': args.precision
+    }
+    data_info = {
+        'batch_size': 1,
+        'shape': 'dynamic_shape',
+        'data_num': det_time_dict['img_num']
+    }
+    perf_info = {
+        'preprocess_time_s': det_time_dict['preprocess_time'],
+        'inference_time_s': det_time_dict['inference_time'],
+        'postprocess_time_s': det_time_dict['postprocess_time'],
+        'total_time_s': det_time_dict['total_time']
+    }
+
+    benchmark_log = benchmark_utils.PaddleInferBenchmark(
+        text_sys.text_detector.config, model_info, data_info, perf_info, mems,
+        args.save_log_path)
+    benchmark_log("Det")
+
+    # construct rec log information
+    model_info = {
+        'model_name': args.rec_model_dir.split('/')[-1],
+        'precision': args.precision
+    }
+    data_info = {
+        'batch_size': args.rec_batch_num,
+        'shape': 'dynamic_shape',
+        'data_num': rec_time_dict['img_num']
+    }
+    perf_info = {
+        'preprocess_time_s': rec_time_dict['preprocess_time'],
+        'inference_time_s': rec_time_dict['inference_time'],
+        'postprocess_time_s': rec_time_dict['postprocess_time'],
+        'total_time_s': rec_time_dict['total_time']
+    }
+    benchmark_log = benchmark_utils.PaddleInferBenchmark(
+        text_sys.text_recognizer.config, model_info, data_info, perf_info, mems,
+        args.save_log_path)
+    benchmark_log("Rec")


 if __name__ == "__main__":
-    args = utility.parse_args()
-    if args.use_mp:
-        p_list = []
-        total_process_num = args.total_process_num
-        for process_id in range(total_process_num):
-            cmd = [sys.executable, "-u"] + sys.argv + [
-                "--process_id={}".format(process_id),
-                "--use_mp={}".format(False)
-            ]
-            p = subprocess.Popen(cmd, stdout=sys.stdout, stderr=sys.stdout)
-            p_list.append(p)
-        for p in p_list:
-            p.wait()
-    else:
-        main(args)
+    main(utility.parse_args())
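Note the idx % 20 == 0 guard in the hunks above: get_current_memory_mb (added to utility.py below) is itself expensive, so the end-to-end script samples it on every 20th image instead of every iteration. A toy version of the idea (assumption, not from this commit):

    readings, count = 0.0, 0
    for idx in range(100):
        # ... run detection and recognition on image idx ...
        if idx % 20 == 0:        # probe sparsely to bound profiling overhead
            readings += 1234.5   # stand-in for a get_current_memory_mb() call
            count += 1
    print("avg cpu_rss_mb:", readings / count)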
@@ -37,7 +37,7 @@ def init_args():
     parser.add_argument("--use_gpu", type=str2bool, default=True)
     parser.add_argument("--ir_optim", type=str2bool, default=True)
     parser.add_argument("--use_tensorrt", type=str2bool, default=False)
-    parser.add_argument("--use_fp16", type=str2bool, default=False)
+    parser.add_argument("--precision", type=str, default="fp32")
     parser.add_argument("--gpu_mem", type=int, default=500)

     # params for text detector
@@ -109,6 +109,11 @@ def init_args():
     parser.add_argument("--use_mp", type=str2bool, default=False)
     parser.add_argument("--total_process_num", type=int, default=1)
     parser.add_argument("--process_id", type=int, default=0)
+
+    parser.add_argument("--benchmark", type=bool, default=False)
+    parser.add_argument("--save_log_path", type=str, default="./log_output/")
+
+    parser.add_argument("--show_log", type=str2bool, default=True)

     return parser
@@ -118,6 +123,76 @@ def parse_args():
     return parser.parse_args()


+class Times(object):
+    def __init__(self):
+        self.time = 0.
+        self.st = 0.
+        self.et = 0.
+
+    def start(self):
+        self.st = time.time()
+
+    def end(self, accumulative=True):
+        self.et = time.time()
+        if accumulative:
+            self.time += self.et - self.st
+        else:
+            self.time = self.et - self.st
+
+    def reset(self):
+        self.time = 0.
+        self.st = 0.
+        self.et = 0.
+
+    def value(self):
+        return round(self.time, 4)
+
+
+class Timer(Times):
+    def __init__(self):
+        super(Timer, self).__init__()
+        self.total_time = Times()
+        self.preprocess_time = Times()
+        self.inference_time = Times()
+        self.postprocess_time = Times()
+        self.img_num = 0
+
+    def info(self, average=False):
+        logger.info("----------------------- Perf info -----------------------")
+        logger.info("total_time: {}, img_num: {}".format(
+            self.total_time.value(), self.img_num))
+        preprocess_time = round(self.preprocess_time.value() / self.img_num,
+                                4) if average else self.preprocess_time.value()
+        postprocess_time = round(
+            self.postprocess_time.value() / self.img_num,
+            4) if average else self.postprocess_time.value()
+        inference_time = round(self.inference_time.value() / self.img_num,
+                               4) if average else self.inference_time.value()
+
+        average_latency = self.total_time.value() / self.img_num
+        logger.info("average_latency(ms): {:.2f}, QPS: {:.2f}".format(
+            average_latency * 1000, 1 / average_latency))
+        logger.info(
+            "preprocess_latency(ms): {:.2f}, inference_latency(ms): {:.2f}, postprocess_latency(ms): {:.2f}".
+            format(preprocess_time * 1000, inference_time * 1000,
+                   postprocess_time * 1000))
+
+    def report(self, average=False):
+        dic = {}
+        dic['preprocess_time'] = round(
+            self.preprocess_time.value() / self.img_num,
+            4) if average else self.preprocess_time.value()
+        dic['postprocess_time'] = round(
+            self.postprocess_time.value() / self.img_num,
+            4) if average else self.postprocess_time.value()
+        dic['inference_time'] = round(
+            self.inference_time.value() / self.img_num,
+            4) if average else self.inference_time.value()
+        dic['img_num'] = self.img_num
+        dic['total_time'] = round(self.total_time.value(), 4)
+        return dic
+
+
 def create_predictor(args, mode, logger):
     if mode == "det":
         model_dir = args.det_model_dir
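One note on semantics, with a sketch (assumes the Timer class above is importable): report(average=True) divides the stage times by img_num but leaves total_time absolute, so callers wanting per-image latency divide it themselves:

    from tools.infer.utility import Timer

    t = Timer()
    t.img_num = 4
    t.preprocess_time.time = 0.4   # pretend 0.4 s of preprocessing accumulated
    t.inference_time.time = 1.2
    t.postprocess_time.time = 0.2
    t.total_time.time = 1.8

    d = t.report(average=True)
    # d['preprocess_time'] == 0.1 (per image), d['total_time'] == 1.8 (absolute)
    per_image_latency = d['total_time'] / d['img_num']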
@@ -125,6 +200,8 @@ def create_predictor(args, mode, logger):
         model_dir = args.cls_model_dir
     elif mode == 'rec':
         model_dir = args.rec_model_dir
+    elif mode == 'structure':
+        model_dir = args.structure_model_dir
     else:
         model_dir = args.e2e_model_dir

@@ -142,6 +219,16 @@ def create_predictor(args, mode, logger):

     config = inference.Config(model_file_path, params_file_path)

+    if hasattr(args, 'precision'):
+        if args.precision == "fp16" and args.use_tensorrt:
+            precision = inference.PrecisionType.Half
+        elif args.precision == "int8":
+            precision = inference.PrecisionType.Int8
+        else:
+            precision = inference.PrecisionType.Float32
+    else:
+        precision = inference.PrecisionType.Float32
+
     if args.use_gpu:
         config.enable_use_gpu(args.gpu_mem, 0)
         if args.use_tensorrt:
@@ -244,7 +331,9 @@ def create_predictor(args, mode, logger):

     config.delete_pass("conv_transpose_eltwiseadd_bn_fuse_pass")
     config.switch_use_feed_fetch_ops(False)
+    config.switch_ir_optim(True)
+    if mode == 'structure':
+        config.switch_ir_optim(False)
     # create predictor
     predictor = inference.create_predictor(config)
     input_names = predictor.get_input_names()
@@ -255,7 +344,7 @@ def create_predictor(args, mode, logger):
     for output_name in output_names:
         output_tensor = predictor.get_output_handle(output_name)
         output_tensors.append(output_tensor)
-    return predictor, input_tensor, output_tensors
+    return predictor, input_tensor, output_tensors, config


 def draw_e2e_res(dt_boxes, strs, img_path):
@@ -506,5 +595,30 @@ def draw_boxes(image, boxes, scores=None, drop_score=0.5):
     return image


+def get_current_memory_mb(gpu_id=None):
+    """
+    Obtain the memory usage of the CPU and GPU while the program is running.
+    Note that this function is itself time-consuming.
+    """
+    import pynvml
+    import psutil
+    import GPUtil
+    pid = os.getpid()
+    p = psutil.Process(pid)
+    info = p.memory_full_info()
+    cpu_mem = info.uss / 1024. / 1024.
+    gpu_mem = 0
+    gpu_percent = 0
+    if gpu_id is not None:
+        GPUs = GPUtil.getGPUs()
+        gpu_load = GPUs[gpu_id].load
+        gpu_percent = gpu_load
+        pynvml.nvmlInit()
+        handle = pynvml.nvmlDeviceGetHandleByIndex(0)
+        meminfo = pynvml.nvmlDeviceGetMemoryInfo(handle)
+        gpu_mem = meminfo.used / 1024. / 1024.
+    return round(cpu_mem, 4), round(gpu_mem, 4), round(gpu_percent, 4)
+
+
 if __name__ == '__main__':
     pass
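A hedged usage sketch of the new helper (assumes psutil is installed; pass a GPU id only on a machine where pynvml and GPUtil are available and a GPU is present):

    from tools.infer.utility import get_current_memory_mb

    cpu_mb, gpu_mb, gpu_util = get_current_memory_mb()     # CPU-only probe
    # cpu_mb, gpu_mb, gpu_util = get_current_memory_mb(0)  # also probe GPU 0
    print("cpu_rss_mb: {:.1f}".format(cpu_mb))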