merge dygraph
|
@ -1,7 +1,7 @@
|
|||
include LICENSE.txt
|
||||
include LICENSE
|
||||
include README.md
|
||||
|
||||
recursive-include ppocr/utils *.txt utility.py logging.py
|
||||
recursive-include ppocr/utils *.txt utility.py logging.py network.py
|
||||
recursive-include ppocr/data/ *.py
|
||||
recursive-include ppocr/postprocess *.py
|
||||
recursive-include tools/infer *.py
|
||||
|
|
|
@ -52,9 +52,10 @@ Architecture:
|
|||
Neck:
|
||||
name: SequenceEncoder
|
||||
encoder_type: rnn
|
||||
hidden_size: 48
|
||||
hidden_size: 64
|
||||
Head:
|
||||
name: CTCHead
|
||||
mid_channels: 96
|
||||
fc_decay: 0.00001
|
||||
Teacher:
|
||||
pretrained:
|
||||
|
@ -71,9 +72,10 @@ Architecture:
|
|||
Neck:
|
||||
name: SequenceEncoder
|
||||
encoder_type: rnn
|
||||
hidden_size: 48
|
||||
hidden_size: 64
|
||||
Head:
|
||||
name: CTCHead
|
||||
mid_channels: 96
|
||||
fc_decay: 0.00001
|
||||
|
||||
|
||||
|
|
|
@ -0,0 +1,116 @@
|
|||
Global:
|
||||
use_gpu: true
|
||||
epoch_num: 50
|
||||
log_smooth_window: 20
|
||||
print_batch_step: 5
|
||||
save_model_dir: ./output/table_mv3/
|
||||
save_epoch_step: 5
|
||||
# evaluation is run every 400 iterations after the 0th iteration
|
||||
eval_batch_step: [0, 400]
|
||||
cal_metric_during_train: True
|
||||
pretrained_model:
|
||||
checkpoints:
|
||||
save_inference_dir:
|
||||
use_visualdl: False
|
||||
infer_img: doc/imgs_words/ch/word_1.jpg
|
||||
# for data or label process
|
||||
character_dict_path: ppocr/utils/dict/table_structure_dict.txt
|
||||
character_type: en
|
||||
max_text_length: 100
|
||||
max_elem_length: 500
|
||||
max_cell_num: 500
|
||||
infer_mode: False
|
||||
process_total_num: 0
|
||||
process_cut_num: 0
|
||||
|
||||
|
||||
Optimizer:
|
||||
name: Adam
|
||||
beta1: 0.9
|
||||
beta2: 0.999
|
||||
clip_norm: 5.0
|
||||
lr:
|
||||
learning_rate: 0.001
|
||||
regularizer:
|
||||
name: 'L2'
|
||||
factor: 0.00000
|
||||
|
||||
Architecture:
|
||||
model_type: table
|
||||
algorithm: TableAttn
|
||||
Backbone:
|
||||
name: MobileNetV3
|
||||
scale: 1.0
|
||||
model_name: small
|
||||
disable_se: True
|
||||
Head:
|
||||
name: TableAttentionHead
|
||||
hidden_size: 256
|
||||
l2_decay: 0.00001
|
||||
loc_type: 2
|
||||
|
||||
Loss:
|
||||
name: TableAttentionLoss
|
||||
structure_weight: 100.0
|
||||
loc_weight: 10000.0
|
||||
|
||||
PostProcess:
|
||||
name: TableLabelDecode
|
||||
|
||||
Metric:
|
||||
name: TableMetric
|
||||
main_indicator: acc
|
||||
|
||||
Train:
|
||||
dataset:
|
||||
name: PubTabDataSet
|
||||
data_dir: train_data/table/pubtabnet/train/
|
||||
label_file_path: train_data/table/pubtabnet/PubTabNet_2.0.0_train.jsonl
|
||||
transforms:
|
||||
- DecodeImage: # load image
|
||||
img_mode: BGR
|
||||
channel_first: False
|
||||
- ResizeTableImage:
|
||||
max_len: 488
|
||||
- TableLabelEncode:
|
||||
- NormalizeImage:
|
||||
scale: 1./255.
|
||||
mean: [0.485, 0.456, 0.406]
|
||||
std: [0.229, 0.224, 0.225]
|
||||
order: 'hwc'
|
||||
- PaddingTableImage:
|
||||
- ToCHWImage:
|
||||
- KeepKeys:
|
||||
keep_keys: ['image', 'structure', 'bbox_list', 'sp_tokens', 'bbox_list_mask']
|
||||
loader:
|
||||
shuffle: True
|
||||
batch_size_per_card: 32
|
||||
drop_last: True
|
||||
num_workers: 1
|
||||
|
||||
Eval:
|
||||
dataset:
|
||||
name: PubTabDataSet
|
||||
data_dir: train_data/table/pubtabnet/val/
|
||||
label_file_path: train_data/table/pubtabnet/PubTabNet_2.0.0_val.jsonl
|
||||
transforms:
|
||||
- DecodeImage: # load image
|
||||
img_mode: BGR
|
||||
channel_first: False
|
||||
- ResizeTableImage:
|
||||
max_len: 488
|
||||
- TableLabelEncode:
|
||||
- NormalizeImage:
|
||||
scale: 1./255.
|
||||
mean: [0.485, 0.456, 0.406]
|
||||
std: [0.229, 0.224, 0.225]
|
||||
order: 'hwc'
|
||||
- PaddingTableImage:
|
||||
- ToCHWImage:
|
||||
- KeepKeys:
|
||||
keep_keys: ['image', 'structure', 'bbox_list', 'sp_tokens', 'bbox_list_mask']
|
||||
loader:
|
||||
shuffle: False
|
||||
drop_last: False
|
||||
batch_size_per_card: 16
|
||||
num_workers: 1
|
|
@ -47,16 +47,13 @@ void Normalize::Run(cv::Mat *im, const std::vector<float> &mean,
|
|||
e /= 255.0;
|
||||
}
|
||||
(*im).convertTo(*im, CV_32FC3, e);
|
||||
for (int h = 0; h < im->rows; h++) {
|
||||
for (int w = 0; w < im->cols; w++) {
|
||||
im->at<cv::Vec3f>(h, w)[0] =
|
||||
(im->at<cv::Vec3f>(h, w)[0] - mean[0]) * scale[0];
|
||||
im->at<cv::Vec3f>(h, w)[1] =
|
||||
(im->at<cv::Vec3f>(h, w)[1] - mean[1]) * scale[1];
|
||||
im->at<cv::Vec3f>(h, w)[2] =
|
||||
(im->at<cv::Vec3f>(h, w)[2] - mean[2]) * scale[2];
|
||||
}
|
||||
std::vector<cv::Mat> bgr_channels(3);
|
||||
cv::split(*im, bgr_channels);
|
||||
for (auto i = 0; i < bgr_channels.size(); i++) {
|
||||
bgr_channels[i].convertTo(bgr_channels[i], CV_32FC1, 1.0 * scale[i],
|
||||
(0.0 - mean[i]) * scale[i]);
|
||||
}
|
||||
cv::merge(bgr_channels, *im);
|
||||
}
|
||||
|
||||
void ResizeImgType0::Run(const cv::Mat &img, cv::Mat &resize_img,
|
||||
|
|
|
@ -355,3 +355,4 @@ im_show.save('result.jpg')
|
|||
| det | 前向时使用启动检测 | TRUE |
|
||||
| rec | 前向时是否启动识别 | TRUE |
|
||||
| cls | 前向时是否启动分类 (命令行模式下使用use_angle_cls控制前向是否启动分类) | FALSE |
|
||||
| show_log | 是否打印det和rec等信息 | FALSE |
|
||||
|
|
|
@ -362,3 +362,5 @@ im_show.save('result.jpg')
|
|||
| det | Enable detction when `ppocr.ocr` func exec | TRUE |
|
||||
| rec | Enable recognition when `ppocr.ocr` func exec | TRUE |
|
||||
| cls | Enable classification when `ppocr.ocr` func exec((Use use_angle_cls in command line mode to control whether to start classification in the forward direction) | FALSE |
|
||||
| show_log | Whether to print log in det and rec
|
||||
| FALSE |
|
BIN
doc/joinus.PNG
Before Width: | Height: | Size: 78 KiB After Width: | Height: | Size: 205 KiB |
After Width: | Height: | Size: 263 KiB |
After Width: | Height: | Size: 55 KiB |
After Width: | Height: | Size: 672 KiB |
After Width: | Height: | Size: 116 KiB |
After Width: | Height: | Size: 386 KiB |
After Width: | Height: | Size: 388 KiB |
After Width: | Height: | Size: 26 KiB |
138
paddleocr.py
|
@ -19,17 +19,16 @@ __dir__ = os.path.dirname(__file__)
|
|||
sys.path.append(os.path.join(__dir__, ''))
|
||||
|
||||
import cv2
|
||||
import logging
|
||||
import numpy as np
|
||||
from pathlib import Path
|
||||
import tarfile
|
||||
import requests
|
||||
from tqdm import tqdm
|
||||
|
||||
from tools.infer import predict_system
|
||||
from ppocr.utils.logging import get_logger
|
||||
|
||||
logger = get_logger()
|
||||
from ppocr.utils.utility import check_and_read_gif, get_image_file_list
|
||||
from ppocr.utils.network import maybe_download, download_with_progressbar, is_link, confirm_model_dir_url
|
||||
from tools.infer.utility import draw_ocr, init_args, str2bool
|
||||
|
||||
__all__ = ['PaddleOCR']
|
||||
|
@ -37,84 +36,84 @@ __all__ = ['PaddleOCR']
|
|||
model_urls = {
|
||||
'det': {
|
||||
'ch':
|
||||
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_det_infer.tar',
|
||||
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_det_infer.tar',
|
||||
'en':
|
||||
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/en_ppocr_mobile_v2.0_det_infer.tar'
|
||||
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/en_ppocr_mobile_v2.0_det_infer.tar'
|
||||
},
|
||||
'rec': {
|
||||
'ch': {
|
||||
'url':
|
||||
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_rec_infer.tar',
|
||||
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_rec_infer.tar',
|
||||
'dict_path': './ppocr/utils/ppocr_keys_v1.txt'
|
||||
},
|
||||
'en': {
|
||||
'url':
|
||||
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/en_number_mobile_v2.0_rec_infer.tar',
|
||||
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/en_number_mobile_v2.0_rec_infer.tar',
|
||||
'dict_path': './ppocr/utils/en_dict.txt'
|
||||
},
|
||||
'french': {
|
||||
'url':
|
||||
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/french_mobile_v2.0_rec_infer.tar',
|
||||
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/french_mobile_v2.0_rec_infer.tar',
|
||||
'dict_path': './ppocr/utils/dict/french_dict.txt'
|
||||
},
|
||||
'german': {
|
||||
'url':
|
||||
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/german_mobile_v2.0_rec_infer.tar',
|
||||
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/german_mobile_v2.0_rec_infer.tar',
|
||||
'dict_path': './ppocr/utils/dict/german_dict.txt'
|
||||
},
|
||||
'korean': {
|
||||
'url':
|
||||
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/korean_mobile_v2.0_rec_infer.tar',
|
||||
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/korean_mobile_v2.0_rec_infer.tar',
|
||||
'dict_path': './ppocr/utils/dict/korean_dict.txt'
|
||||
},
|
||||
'japan': {
|
||||
'url':
|
||||
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/japan_mobile_v2.0_rec_infer.tar',
|
||||
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/japan_mobile_v2.0_rec_infer.tar',
|
||||
'dict_path': './ppocr/utils/dict/japan_dict.txt'
|
||||
},
|
||||
'chinese_cht': {
|
||||
'url':
|
||||
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/chinese_cht_mobile_v2.0_rec_infer.tar',
|
||||
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/chinese_cht_mobile_v2.0_rec_infer.tar',
|
||||
'dict_path': './ppocr/utils/dict/chinese_cht_dict.txt'
|
||||
},
|
||||
'ta': {
|
||||
'url':
|
||||
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/ta_mobile_v2.0_rec_infer.tar',
|
||||
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/ta_mobile_v2.0_rec_infer.tar',
|
||||
'dict_path': './ppocr/utils/dict/ta_dict.txt'
|
||||
},
|
||||
'te': {
|
||||
'url':
|
||||
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/te_mobile_v2.0_rec_infer.tar',
|
||||
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/te_mobile_v2.0_rec_infer.tar',
|
||||
'dict_path': './ppocr/utils/dict/te_dict.txt'
|
||||
},
|
||||
'ka': {
|
||||
'url':
|
||||
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/ka_mobile_v2.0_rec_infer.tar',
|
||||
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/ka_mobile_v2.0_rec_infer.tar',
|
||||
'dict_path': './ppocr/utils/dict/ka_dict.txt'
|
||||
},
|
||||
'latin': {
|
||||
'url':
|
||||
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/latin_ppocr_mobile_v2.0_rec_infer.tar',
|
||||
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/latin_ppocr_mobile_v2.0_rec_infer.tar',
|
||||
'dict_path': './ppocr/utils/dict/latin_dict.txt'
|
||||
},
|
||||
'arabic': {
|
||||
'url':
|
||||
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/arabic_ppocr_mobile_v2.0_rec_infer.tar',
|
||||
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/arabic_ppocr_mobile_v2.0_rec_infer.tar',
|
||||
'dict_path': './ppocr/utils/dict/arabic_dict.txt'
|
||||
},
|
||||
'cyrillic': {
|
||||
'url':
|
||||
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/cyrillic_ppocr_mobile_v2.0_rec_infer.tar',
|
||||
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/cyrillic_ppocr_mobile_v2.0_rec_infer.tar',
|
||||
'dict_path': './ppocr/utils/dict/cyrillic_dict.txt'
|
||||
},
|
||||
'devanagari': {
|
||||
'url':
|
||||
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/devanagari_ppocr_mobile_v2.0_rec_infer.tar',
|
||||
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/devanagari_ppocr_mobile_v2.0_rec_infer.tar',
|
||||
'dict_path': './ppocr/utils/dict/devanagari_dict.txt'
|
||||
}
|
||||
},
|
||||
'cls':
|
||||
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_cls_infer.tar'
|
||||
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_cls_infer.tar'
|
||||
}
|
||||
|
||||
SUPPORT_DET_MODEL = ['DB']
|
||||
|
@ -123,50 +122,6 @@ SUPPORT_REC_MODEL = ['CRNN']
|
|||
BASE_DIR = os.path.expanduser("~/.paddleocr/")
|
||||
|
||||
|
||||
def download_with_progressbar(url, save_path):
|
||||
response = requests.get(url, stream=True)
|
||||
total_size_in_bytes = int(response.headers.get('content-length', 0))
|
||||
block_size = 1024 # 1 Kibibyte
|
||||
progress_bar = tqdm(total=total_size_in_bytes, unit='iB', unit_scale=True)
|
||||
with open(save_path, 'wb') as file:
|
||||
for data in response.iter_content(block_size):
|
||||
progress_bar.update(len(data))
|
||||
file.write(data)
|
||||
progress_bar.close()
|
||||
if total_size_in_bytes == 0 or progress_bar.n != total_size_in_bytes:
|
||||
logger.error("Something went wrong while downloading models")
|
||||
sys.exit(0)
|
||||
|
||||
|
||||
def maybe_download(model_storage_directory, url):
|
||||
# using custom model
|
||||
tar_file_name_list = [
|
||||
'inference.pdiparams', 'inference.pdiparams.info', 'inference.pdmodel'
|
||||
]
|
||||
if not os.path.exists(
|
||||
os.path.join(model_storage_directory, 'inference.pdiparams')
|
||||
) or not os.path.exists(
|
||||
os.path.join(model_storage_directory, 'inference.pdmodel')):
|
||||
tmp_path = os.path.join(model_storage_directory, url.split('/')[-1])
|
||||
print('download {} to {}'.format(url, tmp_path))
|
||||
os.makedirs(model_storage_directory, exist_ok=True)
|
||||
download_with_progressbar(url, tmp_path)
|
||||
with tarfile.open(tmp_path, 'r') as tarObj:
|
||||
for member in tarObj.getmembers():
|
||||
filename = None
|
||||
for tar_file_name in tar_file_name_list:
|
||||
if tar_file_name in member.name:
|
||||
filename = tar_file_name
|
||||
if filename is None:
|
||||
continue
|
||||
file = tarObj.extractfile(member)
|
||||
with open(
|
||||
os.path.join(model_storage_directory, filename),
|
||||
'wb') as f:
|
||||
f.write(file.read())
|
||||
os.remove(tmp_path)
|
||||
|
||||
|
||||
def parse_args(mMain=True):
|
||||
import argparse
|
||||
parser = init_args()
|
||||
|
@ -194,10 +149,12 @@ class PaddleOCR(predict_system.TextSystem):
|
|||
args:
|
||||
**kwargs: other params show in paddleocr --help
|
||||
"""
|
||||
postprocess_params = parse_args(mMain=False)
|
||||
postprocess_params.__dict__.update(**kwargs)
|
||||
self.use_angle_cls = postprocess_params.use_angle_cls
|
||||
lang = postprocess_params.lang
|
||||
params = parse_args(mMain=False)
|
||||
params.__dict__.update(**kwargs)
|
||||
if not params.show_log:
|
||||
logger.setLevel(logging.INFO)
|
||||
self.use_angle_cls = params.use_angle_cls
|
||||
lang = params.lang
|
||||
latin_lang = [
|
||||
'af', 'az', 'bs', 'cs', 'cy', 'da', 'de', 'es', 'et', 'fr', 'ga',
|
||||
'hr', 'hu', 'id', 'is', 'it', 'ku', 'la', 'lt', 'lv', 'mi', 'ms',
|
||||
|
@ -223,46 +180,45 @@ class PaddleOCR(predict_system.TextSystem):
|
|||
lang = "devanagari"
|
||||
assert lang in model_urls[
|
||||
'rec'], 'param lang must in {}, but got {}'.format(
|
||||
model_urls['rec'].keys(), lang)
|
||||
model_urls['rec'].keys(), lang)
|
||||
if lang == "ch":
|
||||
det_lang = "ch"
|
||||
else:
|
||||
det_lang = "en"
|
||||
use_inner_dict = False
|
||||
if postprocess_params.rec_char_dict_path is None:
|
||||
if params.rec_char_dict_path is None:
|
||||
use_inner_dict = True
|
||||
postprocess_params.rec_char_dict_path = model_urls['rec'][lang][
|
||||
params.rec_char_dict_path = model_urls['rec'][lang][
|
||||
'dict_path']
|
||||
|
||||
# init model dir
|
||||
if postprocess_params.det_model_dir is None:
|
||||
postprocess_params.det_model_dir = os.path.join(BASE_DIR, VERSION,
|
||||
'det', det_lang)
|
||||
if postprocess_params.rec_model_dir is None:
|
||||
postprocess_params.rec_model_dir = os.path.join(BASE_DIR, VERSION,
|
||||
'rec', lang)
|
||||
if postprocess_params.cls_model_dir is None:
|
||||
postprocess_params.cls_model_dir = os.path.join(BASE_DIR, 'cls')
|
||||
print(postprocess_params)
|
||||
params.det_model_dir, det_url = confirm_model_dir_url(params.det_model_dir,
|
||||
os.path.join(BASE_DIR, VERSION, 'det', det_lang),
|
||||
model_urls['det'][det_lang])
|
||||
params.rec_model_dir, rec_url = confirm_model_dir_url(params.rec_model_dir,
|
||||
os.path.join(BASE_DIR, VERSION, 'rec', lang),
|
||||
model_urls['rec'][lang]['url'])
|
||||
params.cls_model_dir, cls_url = confirm_model_dir_url(params.cls_model_dir,
|
||||
os.path.join(BASE_DIR, VERSION, 'cls'),
|
||||
model_urls['cls'])
|
||||
# download model
|
||||
maybe_download(postprocess_params.det_model_dir,
|
||||
model_urls['det'][det_lang])
|
||||
maybe_download(postprocess_params.rec_model_dir,
|
||||
model_urls['rec'][lang]['url'])
|
||||
maybe_download(postprocess_params.cls_model_dir, model_urls['cls'])
|
||||
maybe_download(params.det_model_dir, det_url)
|
||||
maybe_download(params.rec_model_dir, rec_url)
|
||||
maybe_download(params.cls_model_dir, cls_url)
|
||||
|
||||
if postprocess_params.det_algorithm not in SUPPORT_DET_MODEL:
|
||||
if params.det_algorithm not in SUPPORT_DET_MODEL:
|
||||
logger.error('det_algorithm must in {}'.format(SUPPORT_DET_MODEL))
|
||||
sys.exit(0)
|
||||
if postprocess_params.rec_algorithm not in SUPPORT_REC_MODEL:
|
||||
if params.rec_algorithm not in SUPPORT_REC_MODEL:
|
||||
logger.error('rec_algorithm must in {}'.format(SUPPORT_REC_MODEL))
|
||||
sys.exit(0)
|
||||
if use_inner_dict:
|
||||
postprocess_params.rec_char_dict_path = str(
|
||||
Path(__file__).parent / postprocess_params.rec_char_dict_path)
|
||||
params.rec_char_dict_path = str(
|
||||
Path(__file__).parent / params.rec_char_dict_path)
|
||||
|
||||
print(params)
|
||||
# init det_model and rec_model
|
||||
super().__init__(postprocess_params)
|
||||
super().__init__(params)
|
||||
|
||||
def ocr(self, img, det=True, rec=True, cls=True):
|
||||
"""
|
||||
|
@ -320,7 +276,7 @@ def main():
|
|||
# for cmd
|
||||
args = parse_args(mMain=True)
|
||||
image_dir = args.image_dir
|
||||
if image_dir.startswith('http'):
|
||||
if is_link(image_dir):
|
||||
download_with_progressbar(image_dir, 'tmp.jpg')
|
||||
image_file_list = ['tmp.jpg']
|
||||
else:
|
||||
|
|
|
@ -35,6 +35,7 @@ from ppocr.data.imaug import transform, create_operators
|
|||
from ppocr.data.simple_dataset import SimpleDataSet
|
||||
from ppocr.data.lmdb_dataset import LMDBDataSet
|
||||
from ppocr.data.pgnet_dataset import PGDataSet
|
||||
from ppocr.data.pubtab_dataset import PubTabDataSet
|
||||
|
||||
__all__ = ['build_dataloader', 'transform', 'create_operators']
|
||||
|
||||
|
@ -55,7 +56,7 @@ signal.signal(signal.SIGTERM, term_mp)
|
|||
def build_dataloader(config, mode, device, logger, seed=None):
|
||||
config = copy.deepcopy(config)
|
||||
|
||||
support_dict = ['SimpleDataSet', 'LMDBDataSet', 'PGDataSet']
|
||||
support_dict = ['SimpleDataSet', 'LMDBDataSet', 'PGDataSet', 'PubTabDataSet']
|
||||
module_name = config[mode]['dataset']['name']
|
||||
assert module_name in support_dict, Exception(
|
||||
'DataSet only support {}'.format(support_dict))
|
||||
|
|
|
@ -30,6 +30,7 @@ from .label_ops import *
|
|||
from .east_process import *
|
||||
from .sast_process import *
|
||||
from .pg_process import *
|
||||
from .gen_table_mask import *
|
||||
|
||||
|
||||
def transform(data, ops=None):
|
||||
|
|
|
@ -0,0 +1,244 @@
|
|||
"""
|
||||
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
from __future__ import unicode_literals
|
||||
|
||||
import sys
|
||||
import six
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
|
||||
class GenTableMask(object):
|
||||
""" gen table mask """
|
||||
|
||||
def __init__(self, shrink_h_max, shrink_w_max, mask_type=0, **kwargs):
|
||||
self.shrink_h_max = 5
|
||||
self.shrink_w_max = 5
|
||||
self.mask_type = mask_type
|
||||
|
||||
def projection(self, erosion, h, w, spilt_threshold=0):
|
||||
# 水平投影
|
||||
projection_map = np.ones_like(erosion)
|
||||
project_val_array = [0 for _ in range(0, h)]
|
||||
|
||||
for j in range(0, h):
|
||||
for i in range(0, w):
|
||||
if erosion[j, i] == 255:
|
||||
project_val_array[j] += 1
|
||||
# 根据数组,获取切割点
|
||||
start_idx = 0 # 记录进入字符区的索引
|
||||
end_idx = 0 # 记录进入空白区域的索引
|
||||
in_text = False # 是否遍历到了字符区内
|
||||
box_list = []
|
||||
for i in range(len(project_val_array)):
|
||||
if in_text == False and project_val_array[i] > spilt_threshold: # 进入字符区了
|
||||
in_text = True
|
||||
start_idx = i
|
||||
elif project_val_array[i] <= spilt_threshold and in_text == True: # 进入空白区了
|
||||
end_idx = i
|
||||
in_text = False
|
||||
if end_idx - start_idx <= 2:
|
||||
continue
|
||||
box_list.append((start_idx, end_idx + 1))
|
||||
|
||||
if in_text:
|
||||
box_list.append((start_idx, h - 1))
|
||||
# 绘制投影直方图
|
||||
for j in range(0, h):
|
||||
for i in range(0, project_val_array[j]):
|
||||
projection_map[j, i] = 0
|
||||
return box_list, projection_map
|
||||
|
||||
def projection_cx(self, box_img):
|
||||
box_gray_img = cv2.cvtColor(box_img, cv2.COLOR_BGR2GRAY)
|
||||
h, w = box_gray_img.shape
|
||||
# 灰度图片进行二值化处理
|
||||
ret, thresh1 = cv2.threshold(box_gray_img, 200, 255, cv2.THRESH_BINARY_INV)
|
||||
# 纵向腐蚀
|
||||
if h < w:
|
||||
kernel = np.ones((2, 1), np.uint8)
|
||||
erode = cv2.erode(thresh1, kernel, iterations=1)
|
||||
else:
|
||||
erode = thresh1
|
||||
# 水平膨胀
|
||||
kernel = np.ones((1, 5), np.uint8)
|
||||
erosion = cv2.dilate(erode, kernel, iterations=1)
|
||||
# 水平投影
|
||||
projection_map = np.ones_like(erosion)
|
||||
project_val_array = [0 for _ in range(0, h)]
|
||||
|
||||
for j in range(0, h):
|
||||
for i in range(0, w):
|
||||
if erosion[j, i] == 255:
|
||||
project_val_array[j] += 1
|
||||
# 根据数组,获取切割点
|
||||
start_idx = 0 # 记录进入字符区的索引
|
||||
end_idx = 0 # 记录进入空白区域的索引
|
||||
in_text = False # 是否遍历到了字符区内
|
||||
box_list = []
|
||||
spilt_threshold = 0
|
||||
for i in range(len(project_val_array)):
|
||||
if in_text == False and project_val_array[i] > spilt_threshold: # 进入字符区了
|
||||
in_text = True
|
||||
start_idx = i
|
||||
elif project_val_array[i] <= spilt_threshold and in_text == True: # 进入空白区了
|
||||
end_idx = i
|
||||
in_text = False
|
||||
if end_idx - start_idx <= 2:
|
||||
continue
|
||||
box_list.append((start_idx, end_idx + 1))
|
||||
|
||||
if in_text:
|
||||
box_list.append((start_idx, h - 1))
|
||||
# 绘制投影直方图
|
||||
for j in range(0, h):
|
||||
for i in range(0, project_val_array[j]):
|
||||
projection_map[j, i] = 0
|
||||
split_bbox_list = []
|
||||
if len(box_list) > 1:
|
||||
for i, (h_start, h_end) in enumerate(box_list):
|
||||
if i == 0:
|
||||
h_start = 0
|
||||
if i == len(box_list):
|
||||
h_end = h
|
||||
word_img = erosion[h_start:h_end + 1, :]
|
||||
word_h, word_w = word_img.shape
|
||||
w_split_list, w_projection_map = self.projection(word_img.T, word_w, word_h)
|
||||
w_start, w_end = w_split_list[0][0], w_split_list[-1][1]
|
||||
if h_start > 0:
|
||||
h_start -= 1
|
||||
h_end += 1
|
||||
word_img = box_img[h_start:h_end + 1:, w_start:w_end + 1, :]
|
||||
split_bbox_list.append([w_start, h_start, w_end, h_end])
|
||||
else:
|
||||
split_bbox_list.append([0, 0, w, h])
|
||||
return split_bbox_list
|
||||
|
||||
def shrink_bbox(self, bbox):
|
||||
left, top, right, bottom = bbox
|
||||
sh_h = min(max(int((bottom - top) * 0.1), 1), self.shrink_h_max)
|
||||
sh_w = min(max(int((right - left) * 0.1), 1), self.shrink_w_max)
|
||||
left_new = left + sh_w
|
||||
right_new = right - sh_w
|
||||
top_new = top + sh_h
|
||||
bottom_new = bottom - sh_h
|
||||
if left_new >= right_new:
|
||||
left_new = left
|
||||
right_new = right
|
||||
if top_new >= bottom_new:
|
||||
top_new = top
|
||||
bottom_new = bottom
|
||||
return [left_new, top_new, right_new, bottom_new]
|
||||
|
||||
def __call__(self, data):
|
||||
img = data['image']
|
||||
cells = data['cells']
|
||||
height, width = img.shape[0:2]
|
||||
if self.mask_type == 1:
|
||||
mask_img = np.zeros((height, width), dtype=np.float32)
|
||||
else:
|
||||
mask_img = np.zeros((height, width, 3), dtype=np.float32)
|
||||
cell_num = len(cells)
|
||||
for cno in range(cell_num):
|
||||
if "bbox" in cells[cno]:
|
||||
bbox = cells[cno]['bbox']
|
||||
left, top, right, bottom = bbox
|
||||
box_img = img[top:bottom, left:right, :].copy()
|
||||
split_bbox_list = self.projection_cx(box_img)
|
||||
for sno in range(len(split_bbox_list)):
|
||||
split_bbox_list[sno][0] += left
|
||||
split_bbox_list[sno][1] += top
|
||||
split_bbox_list[sno][2] += left
|
||||
split_bbox_list[sno][3] += top
|
||||
|
||||
for sno in range(len(split_bbox_list)):
|
||||
left, top, right, bottom = split_bbox_list[sno]
|
||||
left, top, right, bottom = self.shrink_bbox([left, top, right, bottom])
|
||||
if self.mask_type == 1:
|
||||
mask_img[top:bottom, left:right] = 1.0
|
||||
data['mask_img'] = mask_img
|
||||
else:
|
||||
mask_img[top:bottom, left:right, :] = (255, 255, 255)
|
||||
data['image'] = mask_img
|
||||
return data
|
||||
|
||||
class ResizeTableImage(object):
|
||||
def __init__(self, max_len, **kwargs):
|
||||
super(ResizeTableImage, self).__init__()
|
||||
self.max_len = max_len
|
||||
|
||||
def get_img_bbox(self, cells):
|
||||
bbox_list = []
|
||||
if len(cells) == 0:
|
||||
return bbox_list
|
||||
cell_num = len(cells)
|
||||
for cno in range(cell_num):
|
||||
if "bbox" in cells[cno]:
|
||||
bbox = cells[cno]['bbox']
|
||||
bbox_list.append(bbox)
|
||||
return bbox_list
|
||||
|
||||
def resize_img_table(self, img, bbox_list, max_len):
|
||||
height, width = img.shape[0:2]
|
||||
ratio = max_len / (max(height, width) * 1.0)
|
||||
resize_h = int(height * ratio)
|
||||
resize_w = int(width * ratio)
|
||||
img_new = cv2.resize(img, (resize_w, resize_h))
|
||||
bbox_list_new = []
|
||||
for bno in range(len(bbox_list)):
|
||||
left, top, right, bottom = bbox_list[bno].copy()
|
||||
left = int(left * ratio)
|
||||
top = int(top * ratio)
|
||||
right = int(right * ratio)
|
||||
bottom = int(bottom * ratio)
|
||||
bbox_list_new.append([left, top, right, bottom])
|
||||
return img_new, bbox_list_new
|
||||
|
||||
def __call__(self, data):
|
||||
img = data['image']
|
||||
if 'cells' not in data:
|
||||
cells = []
|
||||
else:
|
||||
cells = data['cells']
|
||||
bbox_list = self.get_img_bbox(cells)
|
||||
img_new, bbox_list_new = self.resize_img_table(img, bbox_list, self.max_len)
|
||||
data['image'] = img_new
|
||||
cell_num = len(cells)
|
||||
bno = 0
|
||||
for cno in range(cell_num):
|
||||
if "bbox" in data['cells'][cno]:
|
||||
data['cells'][cno]['bbox'] = bbox_list_new[bno]
|
||||
bno += 1
|
||||
data['max_len'] = self.max_len
|
||||
return data
|
||||
|
||||
class PaddingTableImage(object):
|
||||
def __init__(self, **kwargs):
|
||||
super(PaddingTableImage, self).__init__()
|
||||
|
||||
def __call__(self, data):
|
||||
img = data['image']
|
||||
max_len = data['max_len']
|
||||
padding_img = np.zeros((max_len, max_len, 3), dtype=np.float32)
|
||||
height, width = img.shape[0:2]
|
||||
padding_img[0:height, 0:width, :] = img.copy()
|
||||
data['image'] = padding_img
|
||||
return data
|
||||
|
|
@ -351,3 +351,162 @@ class SRNLabelEncode(BaseRecLabelEncode):
|
|||
assert False, "Unsupport type %s in get_beg_end_flag_idx" \
|
||||
% beg_or_end
|
||||
return idx
|
||||
|
||||
class TableLabelEncode(object):
|
||||
""" Convert between text-label and text-index """
|
||||
def __init__(self,
|
||||
max_text_length,
|
||||
max_elem_length,
|
||||
max_cell_num,
|
||||
character_dict_path,
|
||||
span_weight = 1.0,
|
||||
**kwargs):
|
||||
self.max_text_length = max_text_length
|
||||
self.max_elem_length = max_elem_length
|
||||
self.max_cell_num = max_cell_num
|
||||
list_character, list_elem = self.load_char_elem_dict(character_dict_path)
|
||||
list_character = self.add_special_char(list_character)
|
||||
list_elem = self.add_special_char(list_elem)
|
||||
self.dict_character = {}
|
||||
for i, char in enumerate(list_character):
|
||||
self.dict_character[char] = i
|
||||
self.dict_elem = {}
|
||||
for i, elem in enumerate(list_elem):
|
||||
self.dict_elem[elem] = i
|
||||
self.span_weight = span_weight
|
||||
|
||||
def load_char_elem_dict(self, character_dict_path):
|
||||
list_character = []
|
||||
list_elem = []
|
||||
with open(character_dict_path, "rb") as fin:
|
||||
lines = fin.readlines()
|
||||
substr = lines[0].decode('utf-8').strip("\n").split("\t")
|
||||
character_num = int(substr[0])
|
||||
elem_num = int(substr[1])
|
||||
for cno in range(1, 1+character_num):
|
||||
character = lines[cno].decode('utf-8').strip("\n")
|
||||
list_character.append(character)
|
||||
for eno in range(1+character_num, 1+character_num+elem_num):
|
||||
elem = lines[eno].decode('utf-8').strip("\n")
|
||||
list_elem.append(elem)
|
||||
return list_character, list_elem
|
||||
|
||||
def add_special_char(self, list_character):
|
||||
self.beg_str = "sos"
|
||||
self.end_str = "eos"
|
||||
list_character = [self.beg_str] + list_character + [self.end_str]
|
||||
return list_character
|
||||
|
||||
def get_span_idx_list(self):
|
||||
span_idx_list = []
|
||||
for elem in self.dict_elem:
|
||||
if 'span' in elem:
|
||||
span_idx_list.append(self.dict_elem[elem])
|
||||
return span_idx_list
|
||||
|
||||
def __call__(self, data):
|
||||
cells = data['cells']
|
||||
structure = data['structure']['tokens']
|
||||
structure = self.encode(structure, 'elem')
|
||||
if structure is None:
|
||||
return None
|
||||
elem_num = len(structure)
|
||||
structure = [0] + structure + [len(self.dict_elem) - 1]
|
||||
structure = structure + [0] * (self.max_elem_length + 2 - len(structure))
|
||||
structure = np.array(structure)
|
||||
data['structure'] = structure
|
||||
elem_char_idx1 = self.dict_elem['<td>']
|
||||
elem_char_idx2 = self.dict_elem['<td']
|
||||
span_idx_list = self.get_span_idx_list()
|
||||
td_idx_list = np.logical_or(structure == elem_char_idx1, structure == elem_char_idx2)
|
||||
td_idx_list = np.where(td_idx_list)[0]
|
||||
|
||||
structure_mask = np.ones((self.max_elem_length + 2, 1), dtype=np.float32)
|
||||
bbox_list = np.zeros((self.max_elem_length + 2, 4), dtype=np.float32)
|
||||
bbox_list_mask = np.zeros((self.max_elem_length + 2, 1), dtype=np.float32)
|
||||
img_height, img_width, img_ch = data['image'].shape
|
||||
if len(span_idx_list) > 0:
|
||||
span_weight = len(td_idx_list) * 1.0 / len(span_idx_list)
|
||||
span_weight = min(max(span_weight, 1.0), self.span_weight)
|
||||
for cno in range(len(cells)):
|
||||
if 'bbox' in cells[cno]:
|
||||
bbox = cells[cno]['bbox'].copy()
|
||||
bbox[0] = bbox[0] * 1.0 / img_width
|
||||
bbox[1] = bbox[1] * 1.0 / img_height
|
||||
bbox[2] = bbox[2] * 1.0 / img_width
|
||||
bbox[3] = bbox[3] * 1.0 / img_height
|
||||
td_idx = td_idx_list[cno]
|
||||
bbox_list[td_idx] = bbox
|
||||
bbox_list_mask[td_idx] = 1.0
|
||||
cand_span_idx = td_idx + 1
|
||||
if cand_span_idx < (self.max_elem_length + 2):
|
||||
if structure[cand_span_idx] in span_idx_list:
|
||||
structure_mask[cand_span_idx] = span_weight
|
||||
|
||||
data['bbox_list'] = bbox_list
|
||||
data['bbox_list_mask'] = bbox_list_mask
|
||||
data['structure_mask'] = structure_mask
|
||||
char_beg_idx = self.get_beg_end_flag_idx('beg', 'char')
|
||||
char_end_idx = self.get_beg_end_flag_idx('end', 'char')
|
||||
elem_beg_idx = self.get_beg_end_flag_idx('beg', 'elem')
|
||||
elem_end_idx = self.get_beg_end_flag_idx('end', 'elem')
|
||||
data['sp_tokens'] = np.array([char_beg_idx, char_end_idx, elem_beg_idx,
|
||||
elem_end_idx, elem_char_idx1, elem_char_idx2, self.max_text_length,
|
||||
self.max_elem_length, self.max_cell_num, elem_num])
|
||||
return data
|
||||
|
||||
def encode(self, text, char_or_elem):
|
||||
"""convert text-label into text-index.
|
||||
"""
|
||||
if char_or_elem == "char":
|
||||
max_len = self.max_text_length
|
||||
current_dict = self.dict_character
|
||||
else:
|
||||
max_len = self.max_elem_length
|
||||
current_dict = self.dict_elem
|
||||
if len(text) > max_len:
|
||||
return None
|
||||
if len(text) == 0:
|
||||
if char_or_elem == "char":
|
||||
return [self.dict_character['space']]
|
||||
else:
|
||||
return None
|
||||
text_list = []
|
||||
for char in text:
|
||||
if char not in current_dict:
|
||||
return None
|
||||
text_list.append(current_dict[char])
|
||||
if len(text_list) == 0:
|
||||
if char_or_elem == "char":
|
||||
return [self.dict_character['space']]
|
||||
else:
|
||||
return None
|
||||
return text_list
|
||||
|
||||
def get_ignored_tokens(self, char_or_elem):
|
||||
beg_idx = self.get_beg_end_flag_idx("beg", char_or_elem)
|
||||
end_idx = self.get_beg_end_flag_idx("end", char_or_elem)
|
||||
return [beg_idx, end_idx]
|
||||
|
||||
def get_beg_end_flag_idx(self, beg_or_end, char_or_elem):
|
||||
if char_or_elem == "char":
|
||||
if beg_or_end == "beg":
|
||||
idx = np.array(self.dict_character[self.beg_str])
|
||||
elif beg_or_end == "end":
|
||||
idx = np.array(self.dict_character[self.end_str])
|
||||
else:
|
||||
assert False, "Unsupport type %s in get_beg_end_flag_idx of char" \
|
||||
% beg_or_end
|
||||
elif char_or_elem == "elem":
|
||||
if beg_or_end == "beg":
|
||||
idx = np.array(self.dict_elem[self.beg_str])
|
||||
elif beg_or_end == "end":
|
||||
idx = np.array(self.dict_elem[self.end_str])
|
||||
else:
|
||||
assert False, "Unsupport type %s in get_beg_end_flag_idx of elem" \
|
||||
% beg_or_end
|
||||
else:
|
||||
assert False, "Unsupport type %s in char_or_elem" \
|
||||
% char_or_elem
|
||||
return idx
|
||||
|
|
@ -81,7 +81,7 @@ class NormalizeImage(object):
|
|||
assert isinstance(img,
|
||||
np.ndarray), "invalid input 'img' in NormalizeImage"
|
||||
data['image'] = (
|
||||
img.astype('float32') * self.scale - self.mean) / self.std
|
||||
img.astype('float32') * self.scale - self.mean) / self.std
|
||||
return data
|
||||
|
||||
|
||||
|
@ -163,7 +163,7 @@ class DetResizeForTest(object):
|
|||
img, (ratio_h, ratio_w)
|
||||
"""
|
||||
limit_side_len = self.limit_side_len
|
||||
h, w, _ = img.shape
|
||||
h, w, c = img.shape
|
||||
|
||||
# limit the max side
|
||||
if self.limit_type == 'max':
|
||||
|
@ -174,7 +174,7 @@ class DetResizeForTest(object):
|
|||
ratio = float(limit_side_len) / w
|
||||
else:
|
||||
ratio = 1.
|
||||
else:
|
||||
elif self.limit_type == 'min':
|
||||
if min(h, w) < limit_side_len:
|
||||
if h < w:
|
||||
ratio = float(limit_side_len) / h
|
||||
|
@ -182,6 +182,10 @@ class DetResizeForTest(object):
|
|||
ratio = float(limit_side_len) / w
|
||||
else:
|
||||
ratio = 1.
|
||||
elif self.limit_type == 'resize_long':
|
||||
ratio = float(limit_side_len) / max(h,w)
|
||||
else:
|
||||
raise Exception('not support limit type, image ')
|
||||
resize_h = int(h * ratio)
|
||||
resize_w = int(w * ratio)
|
||||
|
||||
|
|
|
@ -0,0 +1,107 @@
|
|||
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import numpy as np
|
||||
import os
|
||||
import random
|
||||
from paddle.io import Dataset
|
||||
import json
|
||||
|
||||
from .imaug import transform, create_operators
|
||||
|
||||
|
||||
class PubTabDataSet(Dataset):
|
||||
def __init__(self, config, mode, logger, seed=None):
|
||||
super(PubTabDataSet, self).__init__()
|
||||
self.logger = logger
|
||||
|
||||
global_config = config['Global']
|
||||
dataset_config = config[mode]['dataset']
|
||||
loader_config = config[mode]['loader']
|
||||
|
||||
label_file_path = dataset_config.pop('label_file_path')
|
||||
|
||||
self.data_dir = dataset_config['data_dir']
|
||||
self.do_shuffle = loader_config['shuffle']
|
||||
self.do_hard_select = False
|
||||
if 'hard_select' in loader_config:
|
||||
self.do_hard_select = loader_config['hard_select']
|
||||
self.hard_prob = loader_config['hard_prob']
|
||||
if self.do_hard_select:
|
||||
self.img_select_prob = self.load_hard_select_prob()
|
||||
self.table_select_type = None
|
||||
if 'table_select_type' in loader_config:
|
||||
self.table_select_type = loader_config['table_select_type']
|
||||
self.table_select_prob = loader_config['table_select_prob']
|
||||
|
||||
self.seed = seed
|
||||
logger.info("Initialize indexs of datasets:%s" % label_file_path)
|
||||
with open(label_file_path, "rb") as f:
|
||||
self.data_lines = f.readlines()
|
||||
self.data_idx_order_list = list(range(len(self.data_lines)))
|
||||
if mode.lower() == "train":
|
||||
self.shuffle_data_random()
|
||||
self.ops = create_operators(dataset_config['transforms'], global_config)
|
||||
|
||||
def shuffle_data_random(self):
|
||||
if self.do_shuffle:
|
||||
random.seed(self.seed)
|
||||
random.shuffle(self.data_lines)
|
||||
return
|
||||
|
||||
def __getitem__(self, idx):
|
||||
try:
|
||||
data_line = self.data_lines[idx]
|
||||
data_line = data_line.decode('utf-8').strip("\n")
|
||||
info = json.loads(data_line)
|
||||
file_name = info['filename']
|
||||
select_flag = True
|
||||
if self.do_hard_select:
|
||||
prob = self.img_select_prob[file_name]
|
||||
if prob < random.uniform(0, 1):
|
||||
select_flag = False
|
||||
|
||||
if self.table_select_type:
|
||||
structure = info['html']['structure']['tokens'].copy()
|
||||
structure_str = ''.join(structure)
|
||||
table_type = "simple"
|
||||
if 'colspan' in structure_str or 'rowspan' in structure_str:
|
||||
table_type = "complex"
|
||||
if table_type == "complex":
|
||||
if self.table_select_prob < random.uniform(0, 1):
|
||||
select_flag = False
|
||||
|
||||
if select_flag:
|
||||
cells = info['html']['cells'].copy()
|
||||
structure = info['html']['structure'].copy()
|
||||
img_path = os.path.join(self.data_dir, file_name)
|
||||
data = {'img_path': img_path, 'cells': cells, 'structure':structure}
|
||||
if not os.path.exists(img_path):
|
||||
raise Exception("{} does not exist!".format(img_path))
|
||||
with open(data['img_path'], 'rb') as f:
|
||||
img = f.read()
|
||||
data['image'] = img
|
||||
outs = transform(data, self.ops)
|
||||
else:
|
||||
outs = None
|
||||
except Exception as e:
|
||||
self.logger.error(
|
||||
"When parsing line {}, error happened with msg: {}".format(
|
||||
data_line, e))
|
||||
outs = None
|
||||
if outs is None:
|
||||
return self.__getitem__(np.random.randint(self.__len__()))
|
||||
return outs
|
||||
|
||||
def __len__(self):
|
||||
return len(self.data_idx_order_list)
|
|
@ -38,11 +38,13 @@ from .basic_loss import DistanceLoss
|
|||
# combined loss function
|
||||
from .combined_loss import CombinedLoss
|
||||
|
||||
# table loss
|
||||
from .table_att_loss import TableAttentionLoss
|
||||
|
||||
def build_loss(config):
|
||||
support_dict = [
|
||||
'DBLoss', 'EASTLoss', 'SASTLoss', 'CTCLoss', 'ClsLoss', 'AttentionLoss',
|
||||
'SRNLoss', 'PGLoss', 'CombinedLoss'
|
||||
'SRNLoss', 'PGLoss', 'CombinedLoss', 'TableAttentionLoss'
|
||||
]
|
||||
config = copy.deepcopy(config)
|
||||
module_name = config.pop('name')
|
||||
|
|
|
@ -0,0 +1,109 @@
|
|||
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import paddle
|
||||
from paddle import nn
|
||||
from paddle.nn import functional as F
|
||||
from paddle import fluid
|
||||
|
||||
class TableAttentionLoss(nn.Layer):
|
||||
def __init__(self, structure_weight, loc_weight, use_giou=False, giou_weight=1.0, **kwargs):
|
||||
super(TableAttentionLoss, self).__init__()
|
||||
self.loss_func = nn.CrossEntropyLoss(weight=None, reduction='none')
|
||||
self.structure_weight = structure_weight
|
||||
self.loc_weight = loc_weight
|
||||
self.use_giou = use_giou
|
||||
self.giou_weight = giou_weight
|
||||
|
||||
def giou_loss(self, preds, bbox, eps=1e-7, reduction='mean'):
|
||||
'''
|
||||
:param preds:[[x1,y1,x2,y2], [x1,y1,x2,y2],,,]
|
||||
:param bbox:[[x1,y1,x2,y2], [x1,y1,x2,y2],,,]
|
||||
:return: loss
|
||||
'''
|
||||
ix1 = fluid.layers.elementwise_max(preds[:, 0], bbox[:, 0])
|
||||
iy1 = fluid.layers.elementwise_max(preds[:, 1], bbox[:, 1])
|
||||
ix2 = fluid.layers.elementwise_min(preds[:, 2], bbox[:, 2])
|
||||
iy2 = fluid.layers.elementwise_min(preds[:, 3], bbox[:, 3])
|
||||
|
||||
iw = fluid.layers.clip(ix2 - ix1 + 1e-3, 0., 1e10)
|
||||
ih = fluid.layers.clip(iy2 - iy1 + 1e-3, 0., 1e10)
|
||||
|
||||
# overlap
|
||||
inters = iw * ih
|
||||
|
||||
# union
|
||||
uni = (preds[:, 2] - preds[:, 0] + 1e-3) * (preds[:, 3] - preds[:, 1] + 1e-3
|
||||
) + (bbox[:, 2] - bbox[:, 0] + 1e-3) * (
|
||||
bbox[:, 3] - bbox[:, 1] + 1e-3) - inters + eps
|
||||
|
||||
# ious
|
||||
ious = inters / uni
|
||||
|
||||
ex1 = fluid.layers.elementwise_min(preds[:, 0], bbox[:, 0])
|
||||
ey1 = fluid.layers.elementwise_min(preds[:, 1], bbox[:, 1])
|
||||
ex2 = fluid.layers.elementwise_max(preds[:, 2], bbox[:, 2])
|
||||
ey2 = fluid.layers.elementwise_max(preds[:, 3], bbox[:, 3])
|
||||
ew = fluid.layers.clip(ex2 - ex1 + 1e-3, 0., 1e10)
|
||||
eh = fluid.layers.clip(ey2 - ey1 + 1e-3, 0., 1e10)
|
||||
|
||||
# enclose erea
|
||||
enclose = ew * eh + eps
|
||||
giou = ious - (enclose - uni) / enclose
|
||||
|
||||
loss = 1 - giou
|
||||
|
||||
if reduction == 'mean':
|
||||
loss = paddle.mean(loss)
|
||||
elif reduction == 'sum':
|
||||
loss = paddle.sum(loss)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
return loss
|
||||
|
||||
def forward(self, predicts, batch):
|
||||
structure_probs = predicts['structure_probs']
|
||||
structure_targets = batch[1].astype("int64")
|
||||
structure_targets = structure_targets[:, 1:]
|
||||
if len(batch) == 6:
|
||||
structure_mask = batch[5].astype("int64")
|
||||
structure_mask = structure_mask[:, 1:]
|
||||
structure_mask = paddle.reshape(structure_mask, [-1])
|
||||
structure_probs = paddle.reshape(structure_probs, [-1, structure_probs.shape[-1]])
|
||||
structure_targets = paddle.reshape(structure_targets, [-1])
|
||||
structure_loss = self.loss_func(structure_probs, structure_targets)
|
||||
|
||||
if len(batch) == 6:
|
||||
structure_loss = structure_loss * structure_mask
|
||||
|
||||
# structure_loss = paddle.sum(structure_loss) * self.structure_weight
|
||||
structure_loss = paddle.mean(structure_loss) * self.structure_weight
|
||||
|
||||
loc_preds = predicts['loc_preds']
|
||||
loc_targets = batch[2].astype("float32")
|
||||
loc_targets_mask = batch[4].astype("float32")
|
||||
loc_targets = loc_targets[:, 1:, :]
|
||||
loc_targets_mask = loc_targets_mask[:, 1:, :]
|
||||
loc_loss = F.mse_loss(loc_preds * loc_targets_mask, loc_targets) * self.loc_weight
|
||||
if self.use_giou:
|
||||
loc_loss_giou = self.giou_loss(loc_preds * loc_targets_mask, loc_targets) * self.giou_weight
|
||||
total_loss = structure_loss + loc_loss + loc_loss_giou
|
||||
return {'loss':total_loss, "structure_loss":structure_loss, "loc_loss":loc_loss, "loc_loss_giou":loc_loss_giou}
|
||||
else:
|
||||
total_loss = structure_loss + loc_loss
|
||||
return {'loss':total_loss, "structure_loss":structure_loss, "loc_loss":loc_loss}
|
|
@ -26,11 +26,11 @@ from .rec_metric import RecMetric
|
|||
from .cls_metric import ClsMetric
|
||||
from .e2e_metric import E2EMetric
|
||||
from .distillation_metric import DistillationMetric
|
||||
|
||||
from .table_metric import TableMetric
|
||||
|
||||
def build_metric(config):
|
||||
support_dict = [
|
||||
"DetMetric", "RecMetric", "ClsMetric", "E2EMetric", "DistillationMetric"
|
||||
"DetMetric", "RecMetric", "ClsMetric", "E2EMetric", "DistillationMetric", "TableMetric"
|
||||
]
|
||||
|
||||
config = copy.deepcopy(config)
|
||||
|
|
|
@ -0,0 +1,50 @@
|
|||
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import numpy as np
|
||||
class TableMetric(object):
|
||||
def __init__(self, main_indicator='acc', **kwargs):
|
||||
self.main_indicator = main_indicator
|
||||
self.reset()
|
||||
|
||||
def __call__(self, pred, batch, *args, **kwargs):
|
||||
structure_probs = pred['structure_probs'].numpy()
|
||||
structure_labels = batch[1]
|
||||
correct_num = 0
|
||||
all_num = 0
|
||||
structure_probs = np.argmax(structure_probs, axis=2)
|
||||
structure_labels = structure_labels[:, 1:]
|
||||
batch_size = structure_probs.shape[0]
|
||||
for bno in range(batch_size):
|
||||
all_num += 1
|
||||
if (structure_probs[bno] == structure_labels[bno]).all():
|
||||
correct_num += 1
|
||||
self.correct_num += correct_num
|
||||
self.all_num += all_num
|
||||
return {
|
||||
'acc': correct_num * 1.0 / all_num,
|
||||
}
|
||||
|
||||
def get_metric(self):
|
||||
"""
|
||||
return metrics {
|
||||
'acc': 0,
|
||||
}
|
||||
"""
|
||||
acc = 1.0 * self.correct_num / self.all_num
|
||||
self.reset()
|
||||
return {'acc': acc}
|
||||
|
||||
def reset(self):
|
||||
self.correct_num = 0
|
||||
self.all_num = 0
|
|
@ -1,4 +1,4 @@
|
|||
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -78,10 +78,7 @@ class BaseModel(nn.Layer):
|
|||
if self.use_neck:
|
||||
x = self.neck(x)
|
||||
y["neck_out"] = x
|
||||
if data is None:
|
||||
x = self.head(x)
|
||||
else:
|
||||
x = self.head(x, data)
|
||||
x = self.head(x, targets=data)
|
||||
y["head_out"] = x
|
||||
if self.return_all_feats:
|
||||
return y
|
||||
|
|
|
@ -29,6 +29,10 @@ def build_backbone(config, model_type):
|
|||
elif model_type == 'e2e':
|
||||
from .e2e_resnet_vd_pg import ResNet
|
||||
support_dict = ['ResNet']
|
||||
elif model_type == "table":
|
||||
from .table_resnet_vd import ResNet
|
||||
from .table_mobilenet_v3 import MobileNetV3
|
||||
support_dict = ['ResNet', 'MobileNetV3']
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
|
|
|
@ -0,0 +1,287 @@
|
|||
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import paddle
|
||||
from paddle import nn
|
||||
import paddle.nn.functional as F
|
||||
from paddle import ParamAttr
|
||||
|
||||
__all__ = ['MobileNetV3']
|
||||
|
||||
|
||||
def make_divisible(v, divisor=8, min_value=None):
|
||||
if min_value is None:
|
||||
min_value = divisor
|
||||
new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
|
||||
if new_v < 0.9 * v:
|
||||
new_v += divisor
|
||||
return new_v
|
||||
|
||||
|
||||
class MobileNetV3(nn.Layer):
|
||||
def __init__(self,
|
||||
in_channels=3,
|
||||
model_name='large',
|
||||
scale=0.5,
|
||||
disable_se=False,
|
||||
**kwargs):
|
||||
"""
|
||||
the MobilenetV3 backbone network for detection module.
|
||||
Args:
|
||||
params(dict): the super parameters for build network
|
||||
"""
|
||||
super(MobileNetV3, self).__init__()
|
||||
|
||||
self.disable_se = disable_se
|
||||
|
||||
if model_name == "large":
|
||||
cfg = [
|
||||
# k, exp, c, se, nl, s,
|
||||
[3, 16, 16, False, 'relu', 1],
|
||||
[3, 64, 24, False, 'relu', 2],
|
||||
[3, 72, 24, False, 'relu', 1],
|
||||
[5, 72, 40, True, 'relu', 2],
|
||||
[5, 120, 40, True, 'relu', 1],
|
||||
[5, 120, 40, True, 'relu', 1],
|
||||
[3, 240, 80, False, 'hardswish', 2],
|
||||
[3, 200, 80, False, 'hardswish', 1],
|
||||
[3, 184, 80, False, 'hardswish', 1],
|
||||
[3, 184, 80, False, 'hardswish', 1],
|
||||
[3, 480, 112, True, 'hardswish', 1],
|
||||
[3, 672, 112, True, 'hardswish', 1],
|
||||
[5, 672, 160, True, 'hardswish', 2],
|
||||
[5, 960, 160, True, 'hardswish', 1],
|
||||
[5, 960, 160, True, 'hardswish', 1],
|
||||
]
|
||||
cls_ch_squeeze = 960
|
||||
elif model_name == "small":
|
||||
cfg = [
|
||||
# k, exp, c, se, nl, s,
|
||||
[3, 16, 16, True, 'relu', 2],
|
||||
[3, 72, 24, False, 'relu', 2],
|
||||
[3, 88, 24, False, 'relu', 1],
|
||||
[5, 96, 40, True, 'hardswish', 2],
|
||||
[5, 240, 40, True, 'hardswish', 1],
|
||||
[5, 240, 40, True, 'hardswish', 1],
|
||||
[5, 120, 48, True, 'hardswish', 1],
|
||||
[5, 144, 48, True, 'hardswish', 1],
|
||||
[5, 288, 96, True, 'hardswish', 2],
|
||||
[5, 576, 96, True, 'hardswish', 1],
|
||||
[5, 576, 96, True, 'hardswish', 1],
|
||||
]
|
||||
cls_ch_squeeze = 576
|
||||
else:
|
||||
raise NotImplementedError("mode[" + model_name +
|
||||
"_model] is not implemented!")
|
||||
|
||||
supported_scale = [0.35, 0.5, 0.75, 1.0, 1.25]
|
||||
assert scale in supported_scale, \
|
||||
"supported scale are {} but input scale is {}".format(supported_scale, scale)
|
||||
inplanes = 16
|
||||
# conv1
|
||||
self.conv = ConvBNLayer(
|
||||
in_channels=in_channels,
|
||||
out_channels=make_divisible(inplanes * scale),
|
||||
kernel_size=3,
|
||||
stride=2,
|
||||
padding=1,
|
||||
groups=1,
|
||||
if_act=True,
|
||||
act='hardswish',
|
||||
name='conv1')
|
||||
|
||||
self.stages = []
|
||||
self.out_channels = []
|
||||
block_list = []
|
||||
i = 0
|
||||
inplanes = make_divisible(inplanes * scale)
|
||||
for (k, exp, c, se, nl, s) in cfg:
|
||||
se = se and not self.disable_se
|
||||
start_idx = 2 if model_name == 'large' else 0
|
||||
if s == 2 and i > start_idx:
|
||||
self.out_channels.append(inplanes)
|
||||
self.stages.append(nn.Sequential(*block_list))
|
||||
block_list = []
|
||||
block_list.append(
|
||||
ResidualUnit(
|
||||
in_channels=inplanes,
|
||||
mid_channels=make_divisible(scale * exp),
|
||||
out_channels=make_divisible(scale * c),
|
||||
kernel_size=k,
|
||||
stride=s,
|
||||
use_se=se,
|
||||
act=nl,
|
||||
name="conv" + str(i + 2)))
|
||||
inplanes = make_divisible(scale * c)
|
||||
i += 1
|
||||
block_list.append(
|
||||
ConvBNLayer(
|
||||
in_channels=inplanes,
|
||||
out_channels=make_divisible(scale * cls_ch_squeeze),
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0,
|
||||
groups=1,
|
||||
if_act=True,
|
||||
act='hardswish',
|
||||
name='conv_last'))
|
||||
self.stages.append(nn.Sequential(*block_list))
|
||||
self.out_channels.append(make_divisible(scale * cls_ch_squeeze))
|
||||
for i, stage in enumerate(self.stages):
|
||||
self.add_sublayer(sublayer=stage, name="stage{}".format(i))
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv(x)
|
||||
out_list = []
|
||||
for stage in self.stages:
|
||||
x = stage(x)
|
||||
out_list.append(x)
|
||||
return out_list
|
||||
|
||||
|
||||
class ConvBNLayer(nn.Layer):
|
||||
def __init__(self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size,
|
||||
stride,
|
||||
padding,
|
||||
groups=1,
|
||||
if_act=True,
|
||||
act=None,
|
||||
name=None):
|
||||
super(ConvBNLayer, self).__init__()
|
||||
self.if_act = if_act
|
||||
self.act = act
|
||||
self.conv = nn.Conv2D(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=kernel_size,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
groups=groups,
|
||||
weight_attr=ParamAttr(name=name + '_weights'),
|
||||
bias_attr=False)
|
||||
|
||||
self.bn = nn.BatchNorm(
|
||||
num_channels=out_channels,
|
||||
act=None,
|
||||
param_attr=ParamAttr(name=name + "_bn_scale"),
|
||||
bias_attr=ParamAttr(name=name + "_bn_offset"),
|
||||
moving_mean_name=name + "_bn_mean",
|
||||
moving_variance_name=name + "_bn_variance")
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv(x)
|
||||
x = self.bn(x)
|
||||
if self.if_act:
|
||||
if self.act == "relu":
|
||||
x = F.relu(x)
|
||||
elif self.act == "hardswish":
|
||||
x = F.hardswish(x)
|
||||
else:
|
||||
print("The activation function({}) is selected incorrectly.".
|
||||
format(self.act))
|
||||
exit()
|
||||
return x
|
||||
|
||||
|
||||
class ResidualUnit(nn.Layer):
|
||||
def __init__(self,
|
||||
in_channels,
|
||||
mid_channels,
|
||||
out_channels,
|
||||
kernel_size,
|
||||
stride,
|
||||
use_se,
|
||||
act=None,
|
||||
name=''):
|
||||
super(ResidualUnit, self).__init__()
|
||||
self.if_shortcut = stride == 1 and in_channels == out_channels
|
||||
self.if_se = use_se
|
||||
|
||||
self.expand_conv = ConvBNLayer(
|
||||
in_channels=in_channels,
|
||||
out_channels=mid_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0,
|
||||
if_act=True,
|
||||
act=act,
|
||||
name=name + "_expand")
|
||||
self.bottleneck_conv = ConvBNLayer(
|
||||
in_channels=mid_channels,
|
||||
out_channels=mid_channels,
|
||||
kernel_size=kernel_size,
|
||||
stride=stride,
|
||||
padding=int((kernel_size - 1) // 2),
|
||||
groups=mid_channels,
|
||||
if_act=True,
|
||||
act=act,
|
||||
name=name + "_depthwise")
|
||||
if self.if_se:
|
||||
self.mid_se = SEModule(mid_channels, name=name + "_se")
|
||||
self.linear_conv = ConvBNLayer(
|
||||
in_channels=mid_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0,
|
||||
if_act=False,
|
||||
act=None,
|
||||
name=name + "_linear")
|
||||
|
||||
def forward(self, inputs):
|
||||
x = self.expand_conv(inputs)
|
||||
x = self.bottleneck_conv(x)
|
||||
if self.if_se:
|
||||
x = self.mid_se(x)
|
||||
x = self.linear_conv(x)
|
||||
if self.if_shortcut:
|
||||
x = paddle.add(inputs, x)
|
||||
return x
|
||||
|
||||
|
||||
class SEModule(nn.Layer):
|
||||
def __init__(self, in_channels, reduction=4, name=""):
|
||||
super(SEModule, self).__init__()
|
||||
self.avg_pool = nn.AdaptiveAvgPool2D(1)
|
||||
self.conv1 = nn.Conv2D(
|
||||
in_channels=in_channels,
|
||||
out_channels=in_channels // reduction,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0,
|
||||
weight_attr=ParamAttr(name=name + "_1_weights"),
|
||||
bias_attr=ParamAttr(name=name + "_1_offset"))
|
||||
self.conv2 = nn.Conv2D(
|
||||
in_channels=in_channels // reduction,
|
||||
out_channels=in_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0,
|
||||
weight_attr=ParamAttr(name + "_2_weights"),
|
||||
bias_attr=ParamAttr(name=name + "_2_offset"))
|
||||
|
||||
def forward(self, inputs):
|
||||
outputs = self.avg_pool(inputs)
|
||||
outputs = self.conv1(outputs)
|
||||
outputs = F.relu(outputs)
|
||||
outputs = self.conv2(outputs)
|
||||
outputs = F.hardsigmoid(outputs, slope=0.2, offset=0.5)
|
||||
return inputs * outputs
|
|
@ -0,0 +1,280 @@
|
|||
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import paddle
|
||||
from paddle import ParamAttr
|
||||
import paddle.nn as nn
|
||||
import paddle.nn.functional as F
|
||||
|
||||
__all__ = ["ResNet"]
|
||||
|
||||
|
||||
class ConvBNLayer(nn.Layer):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size,
|
||||
stride=1,
|
||||
groups=1,
|
||||
is_vd_mode=False,
|
||||
act=None,
|
||||
name=None, ):
|
||||
super(ConvBNLayer, self).__init__()
|
||||
|
||||
self.is_vd_mode = is_vd_mode
|
||||
self._pool2d_avg = nn.AvgPool2D(
|
||||
kernel_size=2, stride=2, padding=0, ceil_mode=True)
|
||||
self._conv = nn.Conv2D(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=kernel_size,
|
||||
stride=stride,
|
||||
padding=(kernel_size - 1) // 2,
|
||||
groups=groups,
|
||||
weight_attr=ParamAttr(name=name + "_weights"),
|
||||
bias_attr=False)
|
||||
if name == "conv1":
|
||||
bn_name = "bn_" + name
|
||||
else:
|
||||
bn_name = "bn" + name[3:]
|
||||
self._batch_norm = nn.BatchNorm(
|
||||
out_channels,
|
||||
act=act,
|
||||
param_attr=ParamAttr(name=bn_name + '_scale'),
|
||||
bias_attr=ParamAttr(bn_name + '_offset'),
|
||||
moving_mean_name=bn_name + '_mean',
|
||||
moving_variance_name=bn_name + '_variance')
|
||||
|
||||
def forward(self, inputs):
|
||||
if self.is_vd_mode:
|
||||
inputs = self._pool2d_avg(inputs)
|
||||
y = self._conv(inputs)
|
||||
y = self._batch_norm(y)
|
||||
return y
|
||||
|
||||
|
||||
class BottleneckBlock(nn.Layer):
|
||||
def __init__(self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
stride,
|
||||
shortcut=True,
|
||||
if_first=False,
|
||||
name=None):
|
||||
super(BottleneckBlock, self).__init__()
|
||||
|
||||
self.conv0 = ConvBNLayer(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=1,
|
||||
act='relu',
|
||||
name=name + "_branch2a")
|
||||
self.conv1 = ConvBNLayer(
|
||||
in_channels=out_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=3,
|
||||
stride=stride,
|
||||
act='relu',
|
||||
name=name + "_branch2b")
|
||||
self.conv2 = ConvBNLayer(
|
||||
in_channels=out_channels,
|
||||
out_channels=out_channels * 4,
|
||||
kernel_size=1,
|
||||
act=None,
|
||||
name=name + "_branch2c")
|
||||
|
||||
if not shortcut:
|
||||
self.short = ConvBNLayer(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels * 4,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
is_vd_mode=False if if_first else True,
|
||||
name=name + "_branch1")
|
||||
|
||||
self.shortcut = shortcut
|
||||
|
||||
def forward(self, inputs):
|
||||
y = self.conv0(inputs)
|
||||
conv1 = self.conv1(y)
|
||||
conv2 = self.conv2(conv1)
|
||||
|
||||
if self.shortcut:
|
||||
short = inputs
|
||||
else:
|
||||
short = self.short(inputs)
|
||||
y = paddle.add(x=short, y=conv2)
|
||||
y = F.relu(y)
|
||||
return y
|
||||
|
||||
|
||||
class BasicBlock(nn.Layer):
|
||||
def __init__(self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
stride,
|
||||
shortcut=True,
|
||||
if_first=False,
|
||||
name=None):
|
||||
super(BasicBlock, self).__init__()
|
||||
self.stride = stride
|
||||
self.conv0 = ConvBNLayer(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=3,
|
||||
stride=stride,
|
||||
act='relu',
|
||||
name=name + "_branch2a")
|
||||
self.conv1 = ConvBNLayer(
|
||||
in_channels=out_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=3,
|
||||
act=None,
|
||||
name=name + "_branch2b")
|
||||
|
||||
if not shortcut:
|
||||
self.short = ConvBNLayer(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
is_vd_mode=False if if_first else True,
|
||||
name=name + "_branch1")
|
||||
|
||||
self.shortcut = shortcut
|
||||
|
||||
def forward(self, inputs):
|
||||
y = self.conv0(inputs)
|
||||
conv1 = self.conv1(y)
|
||||
|
||||
if self.shortcut:
|
||||
short = inputs
|
||||
else:
|
||||
short = self.short(inputs)
|
||||
y = paddle.add(x=short, y=conv1)
|
||||
y = F.relu(y)
|
||||
return y
|
||||
|
||||
|
||||
class ResNet(nn.Layer):
|
||||
def __init__(self, in_channels=3, layers=50, **kwargs):
|
||||
super(ResNet, self).__init__()
|
||||
|
||||
self.layers = layers
|
||||
supported_layers = [18, 34, 50, 101, 152, 200]
|
||||
assert layers in supported_layers, \
|
||||
"supported layers are {} but input layer is {}".format(
|
||||
supported_layers, layers)
|
||||
|
||||
if layers == 18:
|
||||
depth = [2, 2, 2, 2]
|
||||
elif layers == 34 or layers == 50:
|
||||
depth = [3, 4, 6, 3]
|
||||
elif layers == 101:
|
||||
depth = [3, 4, 23, 3]
|
||||
elif layers == 152:
|
||||
depth = [3, 8, 36, 3]
|
||||
elif layers == 200:
|
||||
depth = [3, 12, 48, 3]
|
||||
num_channels = [64, 256, 512,
|
||||
1024] if layers >= 50 else [64, 64, 128, 256]
|
||||
num_filters = [64, 128, 256, 512]
|
||||
|
||||
self.conv1_1 = ConvBNLayer(
|
||||
in_channels=in_channels,
|
||||
out_channels=32,
|
||||
kernel_size=3,
|
||||
stride=2,
|
||||
act='relu',
|
||||
name="conv1_1")
|
||||
self.conv1_2 = ConvBNLayer(
|
||||
in_channels=32,
|
||||
out_channels=32,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
act='relu',
|
||||
name="conv1_2")
|
||||
self.conv1_3 = ConvBNLayer(
|
||||
in_channels=32,
|
||||
out_channels=64,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
act='relu',
|
||||
name="conv1_3")
|
||||
self.pool2d_max = nn.MaxPool2D(kernel_size=3, stride=2, padding=1)
|
||||
|
||||
self.stages = []
|
||||
self.out_channels = []
|
||||
if layers >= 50:
|
||||
for block in range(len(depth)):
|
||||
block_list = []
|
||||
shortcut = False
|
||||
for i in range(depth[block]):
|
||||
if layers in [101, 152] and block == 2:
|
||||
if i == 0:
|
||||
conv_name = "res" + str(block + 2) + "a"
|
||||
else:
|
||||
conv_name = "res" + str(block + 2) + "b" + str(i)
|
||||
else:
|
||||
conv_name = "res" + str(block + 2) + chr(97 + i)
|
||||
bottleneck_block = self.add_sublayer(
|
||||
'bb_%d_%d' % (block, i),
|
||||
BottleneckBlock(
|
||||
in_channels=num_channels[block]
|
||||
if i == 0 else num_filters[block] * 4,
|
||||
out_channels=num_filters[block],
|
||||
stride=2 if i == 0 and block != 0 else 1,
|
||||
shortcut=shortcut,
|
||||
if_first=block == i == 0,
|
||||
name=conv_name))
|
||||
shortcut = True
|
||||
block_list.append(bottleneck_block)
|
||||
self.out_channels.append(num_filters[block] * 4)
|
||||
self.stages.append(nn.Sequential(*block_list))
|
||||
else:
|
||||
for block in range(len(depth)):
|
||||
block_list = []
|
||||
shortcut = False
|
||||
for i in range(depth[block]):
|
||||
conv_name = "res" + str(block + 2) + chr(97 + i)
|
||||
basic_block = self.add_sublayer(
|
||||
'bb_%d_%d' % (block, i),
|
||||
BasicBlock(
|
||||
in_channels=num_channels[block]
|
||||
if i == 0 else num_filters[block],
|
||||
out_channels=num_filters[block],
|
||||
stride=2 if i == 0 and block != 0 else 1,
|
||||
shortcut=shortcut,
|
||||
if_first=block == i == 0,
|
||||
name=conv_name))
|
||||
shortcut = True
|
||||
block_list.append(basic_block)
|
||||
self.out_channels.append(num_filters[block])
|
||||
self.stages.append(nn.Sequential(*block_list))
|
||||
|
||||
def forward(self, inputs):
|
||||
y = self.conv1_1(inputs)
|
||||
y = self.conv1_2(y)
|
||||
y = self.conv1_3(y)
|
||||
y = self.pool2d_max(y)
|
||||
out = []
|
||||
for block in self.stages:
|
||||
y = block(y)
|
||||
out.append(y)
|
||||
return out
|
|
@ -31,8 +31,10 @@ def build_head(config):
|
|||
from .cls_head import ClsHead
|
||||
support_dict = [
|
||||
'DBHead', 'EASTHead', 'SASTHead', 'CTCHead', 'ClsHead', 'AttentionHead',
|
||||
'SRNHead', 'PGHead']
|
||||
'SRNHead', 'PGHead', 'TableAttentionHead']
|
||||
|
||||
#table head
|
||||
from .table_att_head import TableAttentionHead
|
||||
|
||||
module_name = config.pop('name')
|
||||
assert module_name in support_dict, Exception('head only support {}'.format(
|
||||
|
|
|
@ -43,7 +43,7 @@ class ClsHead(nn.Layer):
|
|||
initializer=nn.initializer.Uniform(-stdv, stdv)),
|
||||
bias_attr=ParamAttr(name="fc_0.b_0"), )
|
||||
|
||||
def forward(self, x):
|
||||
def forward(self, x, targets=None):
|
||||
x = self.pool(x)
|
||||
x = paddle.reshape(x, shape=[x.shape[0], x.shape[1]])
|
||||
x = self.fc(x)
|
||||
|
|
|
@ -106,7 +106,7 @@ class DBHead(nn.Layer):
|
|||
def step_function(self, x, y):
|
||||
return paddle.reciprocal(1 + paddle.exp(-self.k * (x - y)))
|
||||
|
||||
def forward(self, x):
|
||||
def forward(self, x, targets=None):
|
||||
shrink_maps = self.binarize(x)
|
||||
if not self.training:
|
||||
return {'maps': shrink_maps}
|
||||
|
|
|
@ -109,7 +109,7 @@ class EASTHead(nn.Layer):
|
|||
act=None,
|
||||
name="f_geo")
|
||||
|
||||
def forward(self, x):
|
||||
def forward(self, x, targets=None):
|
||||
f_det = self.det_conv1(x)
|
||||
f_det = self.det_conv2(f_det)
|
||||
f_score = self.score_conv(f_det)
|
||||
|
|
|
@ -116,7 +116,7 @@ class SASTHead(nn.Layer):
|
|||
self.head1 = SAST_Header1(in_channels)
|
||||
self.head2 = SAST_Header2(in_channels)
|
||||
|
||||
def forward(self, x):
|
||||
def forward(self, x, targets=None):
|
||||
f_score, f_border = self.head1(x)
|
||||
f_tvo, f_tco = self.head2(x)
|
||||
|
||||
|
|
|
@ -220,7 +220,7 @@ class PGHead(nn.Layer):
|
|||
weight_attr=ParamAttr(name="conv_f_direc{}".format(4)),
|
||||
bias_attr=False)
|
||||
|
||||
def forward(self, x):
|
||||
def forward(self, x, targets=None):
|
||||
f_score = self.conv_f_score1(x)
|
||||
f_score = self.conv_f_score2(f_score)
|
||||
f_score = self.conv_f_score3(f_score)
|
||||
|
|
|
@ -33,19 +33,47 @@ def get_para_bias_attr(l2_decay, k):
|
|||
|
||||
|
||||
class CTCHead(nn.Layer):
|
||||
def __init__(self, in_channels, out_channels, fc_decay=0.0004, **kwargs):
|
||||
def __init__(self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
fc_decay=0.0004,
|
||||
mid_channels=None,
|
||||
**kwargs):
|
||||
super(CTCHead, self).__init__()
|
||||
weight_attr, bias_attr = get_para_bias_attr(
|
||||
l2_decay=fc_decay, k=in_channels)
|
||||
self.fc = nn.Linear(
|
||||
in_channels,
|
||||
out_channels,
|
||||
weight_attr=weight_attr,
|
||||
bias_attr=bias_attr)
|
||||
self.out_channels = out_channels
|
||||
if mid_channels is None:
|
||||
weight_attr, bias_attr = get_para_bias_attr(
|
||||
l2_decay=fc_decay, k=in_channels)
|
||||
self.fc = nn.Linear(
|
||||
in_channels,
|
||||
out_channels,
|
||||
weight_attr=weight_attr,
|
||||
bias_attr=bias_attr)
|
||||
else:
|
||||
weight_attr1, bias_attr1 = get_para_bias_attr(
|
||||
l2_decay=fc_decay, k=in_channels)
|
||||
self.fc1 = nn.Linear(
|
||||
in_channels,
|
||||
mid_channels,
|
||||
weight_attr=weight_attr1,
|
||||
bias_attr=bias_attr1)
|
||||
|
||||
def forward(self, x, labels=None):
|
||||
predicts = self.fc(x)
|
||||
weight_attr2, bias_attr2 = get_para_bias_attr(
|
||||
l2_decay=fc_decay, k=mid_channels)
|
||||
self.fc2 = nn.Linear(
|
||||
mid_channels,
|
||||
out_channels,
|
||||
weight_attr=weight_attr2,
|
||||
bias_attr=bias_attr2)
|
||||
self.out_channels = out_channels
|
||||
self.mid_channels = mid_channels
|
||||
|
||||
def forward(self, x, targets=None):
|
||||
if self.mid_channels is None:
|
||||
predicts = self.fc(x)
|
||||
else:
|
||||
predicts = self.fc1(x)
|
||||
predicts = self.fc2(predicts)
|
||||
|
||||
if not self.training:
|
||||
predicts = F.softmax(predicts, axis=2)
|
||||
return predicts
|
||||
|
|
|
@ -250,7 +250,8 @@ class SRNHead(nn.Layer):
|
|||
|
||||
self.gsrm.wrap_encoder1.prepare_decoder.emb0 = self.gsrm.wrap_encoder0.prepare_decoder.emb0
|
||||
|
||||
def forward(self, inputs, others):
|
||||
def forward(self, inputs, targets=None):
|
||||
others = targets[-4:]
|
||||
encoder_word_pos = others[0]
|
||||
gsrm_word_pos = others[1]
|
||||
gsrm_slf_attn_bias1 = others[2]
|
||||
|
|
|
@ -0,0 +1,238 @@
|
|||
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import paddle
|
||||
import paddle.nn as nn
|
||||
import paddle.nn.functional as F
|
||||
import numpy as np
|
||||
|
||||
|
||||
class TableAttentionHead(nn.Layer):
|
||||
def __init__(self, in_channels, hidden_size, loc_type, in_max_len=488, **kwargs):
|
||||
super(TableAttentionHead, self).__init__()
|
||||
self.input_size = in_channels[-1]
|
||||
self.hidden_size = hidden_size
|
||||
self.elem_num = 30
|
||||
self.max_text_length = 100
|
||||
self.max_elem_length = 500
|
||||
self.max_cell_num = 500
|
||||
|
||||
self.structure_attention_cell = AttentionGRUCell(
|
||||
self.input_size, hidden_size, self.elem_num, use_gru=False)
|
||||
self.structure_generator = nn.Linear(hidden_size, self.elem_num)
|
||||
self.loc_type = loc_type
|
||||
self.in_max_len = in_max_len
|
||||
|
||||
if self.loc_type == 1:
|
||||
self.loc_generator = nn.Linear(hidden_size, 4)
|
||||
else:
|
||||
if self.in_max_len == 640:
|
||||
self.loc_fea_trans = nn.Linear(400, self.max_elem_length+1)
|
||||
elif self.in_max_len == 800:
|
||||
self.loc_fea_trans = nn.Linear(625, self.max_elem_length+1)
|
||||
else:
|
||||
self.loc_fea_trans = nn.Linear(256, self.max_elem_length+1)
|
||||
self.loc_generator = nn.Linear(self.input_size + hidden_size, 4)
|
||||
|
||||
def _char_to_onehot(self, input_char, onehot_dim):
|
||||
input_ont_hot = F.one_hot(input_char, onehot_dim)
|
||||
return input_ont_hot
|
||||
|
||||
def forward(self, inputs, targets=None):
|
||||
# if and else branch are both needed when you want to assign a variable
|
||||
# if you modify the var in just one branch, then the modification will not work.
|
||||
fea = inputs[-1]
|
||||
if len(fea.shape) == 3:
|
||||
pass
|
||||
else:
|
||||
last_shape = int(np.prod(fea.shape[2:])) # gry added
|
||||
fea = paddle.reshape(fea, [fea.shape[0], fea.shape[1], last_shape])
|
||||
fea = fea.transpose([0, 2, 1]) # (NTC)(batch, width, channels)
|
||||
batch_size = fea.shape[0]
|
||||
|
||||
hidden = paddle.zeros((batch_size, self.hidden_size))
|
||||
output_hiddens = []
|
||||
if self.training and targets is not None:
|
||||
structure = targets[0]
|
||||
for i in range(self.max_elem_length+1):
|
||||
elem_onehots = self._char_to_onehot(
|
||||
structure[:, i], onehot_dim=self.elem_num)
|
||||
(outputs, hidden), alpha = self.structure_attention_cell(
|
||||
hidden, fea, elem_onehots)
|
||||
output_hiddens.append(paddle.unsqueeze(outputs, axis=1))
|
||||
output = paddle.concat(output_hiddens, axis=1)
|
||||
structure_probs = self.structure_generator(output)
|
||||
if self.loc_type == 1:
|
||||
loc_preds = self.loc_generator(output)
|
||||
loc_preds = F.sigmoid(loc_preds)
|
||||
else:
|
||||
loc_fea = fea.transpose([0, 2, 1])
|
||||
loc_fea = self.loc_fea_trans(loc_fea)
|
||||
loc_fea = loc_fea.transpose([0, 2, 1])
|
||||
loc_concat = paddle.concat([output, loc_fea], axis=2)
|
||||
loc_preds = self.loc_generator(loc_concat)
|
||||
loc_preds = F.sigmoid(loc_preds)
|
||||
else:
|
||||
temp_elem = paddle.zeros(shape=[batch_size], dtype="int32")
|
||||
structure_probs = None
|
||||
loc_preds = None
|
||||
elem_onehots = None
|
||||
outputs = None
|
||||
alpha = None
|
||||
max_elem_length = paddle.to_tensor(self.max_elem_length)
|
||||
i = 0
|
||||
while i < max_elem_length+1:
|
||||
elem_onehots = self._char_to_onehot(
|
||||
temp_elem, onehot_dim=self.elem_num)
|
||||
(outputs, hidden), alpha = self.structure_attention_cell(
|
||||
hidden, fea, elem_onehots)
|
||||
output_hiddens.append(paddle.unsqueeze(outputs, axis=1))
|
||||
structure_probs_step = self.structure_generator(outputs)
|
||||
temp_elem = structure_probs_step.argmax(axis=1, dtype="int32")
|
||||
i += 1
|
||||
|
||||
output = paddle.concat(output_hiddens, axis=1)
|
||||
structure_probs = self.structure_generator(output)
|
||||
structure_probs = F.softmax(structure_probs)
|
||||
if self.loc_type == 1:
|
||||
loc_preds = self.loc_generator(output)
|
||||
loc_preds = F.sigmoid(loc_preds)
|
||||
else:
|
||||
loc_fea = fea.transpose([0, 2, 1])
|
||||
loc_fea = self.loc_fea_trans(loc_fea)
|
||||
loc_fea = loc_fea.transpose([0, 2, 1])
|
||||
loc_concat = paddle.concat([output, loc_fea], axis=2)
|
||||
loc_preds = self.loc_generator(loc_concat)
|
||||
loc_preds = F.sigmoid(loc_preds)
|
||||
return {'structure_probs':structure_probs, 'loc_preds':loc_preds}
|
||||
|
||||
|
||||
class AttentionGRUCell(nn.Layer):
|
||||
def __init__(self, input_size, hidden_size, num_embeddings, use_gru=False):
|
||||
super(AttentionGRUCell, self).__init__()
|
||||
self.i2h = nn.Linear(input_size, hidden_size, bias_attr=False)
|
||||
self.h2h = nn.Linear(hidden_size, hidden_size)
|
||||
self.score = nn.Linear(hidden_size, 1, bias_attr=False)
|
||||
self.rnn = nn.GRUCell(
|
||||
input_size=input_size + num_embeddings, hidden_size=hidden_size)
|
||||
self.hidden_size = hidden_size
|
||||
|
||||
def forward(self, prev_hidden, batch_H, char_onehots):
|
||||
batch_H_proj = self.i2h(batch_H)
|
||||
prev_hidden_proj = paddle.unsqueeze(self.h2h(prev_hidden), axis=1)
|
||||
res = paddle.add(batch_H_proj, prev_hidden_proj)
|
||||
res = paddle.tanh(res)
|
||||
e = self.score(res)
|
||||
alpha = F.softmax(e, axis=1)
|
||||
alpha = paddle.transpose(alpha, [0, 2, 1])
|
||||
context = paddle.squeeze(paddle.mm(alpha, batch_H), axis=1)
|
||||
concat_context = paddle.concat([context, char_onehots], 1)
|
||||
cur_hidden = self.rnn(concat_context, prev_hidden)
|
||||
return cur_hidden, alpha
|
||||
|
||||
|
||||
class AttentionLSTM(nn.Layer):
|
||||
def __init__(self, in_channels, out_channels, hidden_size, **kwargs):
|
||||
super(AttentionLSTM, self).__init__()
|
||||
self.input_size = in_channels
|
||||
self.hidden_size = hidden_size
|
||||
self.num_classes = out_channels
|
||||
|
||||
self.attention_cell = AttentionLSTMCell(
|
||||
in_channels, hidden_size, out_channels, use_gru=False)
|
||||
self.generator = nn.Linear(hidden_size, out_channels)
|
||||
|
||||
def _char_to_onehot(self, input_char, onehot_dim):
|
||||
input_ont_hot = F.one_hot(input_char, onehot_dim)
|
||||
return input_ont_hot
|
||||
|
||||
def forward(self, inputs, targets=None, batch_max_length=25):
|
||||
batch_size = inputs.shape[0]
|
||||
num_steps = batch_max_length
|
||||
|
||||
hidden = (paddle.zeros((batch_size, self.hidden_size)), paddle.zeros(
|
||||
(batch_size, self.hidden_size)))
|
||||
output_hiddens = []
|
||||
|
||||
if targets is not None:
|
||||
for i in range(num_steps):
|
||||
# one-hot vectors for a i-th char
|
||||
char_onehots = self._char_to_onehot(
|
||||
targets[:, i], onehot_dim=self.num_classes)
|
||||
hidden, alpha = self.attention_cell(hidden, inputs,
|
||||
char_onehots)
|
||||
|
||||
hidden = (hidden[1][0], hidden[1][1])
|
||||
output_hiddens.append(paddle.unsqueeze(hidden[0], axis=1))
|
||||
output = paddle.concat(output_hiddens, axis=1)
|
||||
probs = self.generator(output)
|
||||
|
||||
else:
|
||||
targets = paddle.zeros(shape=[batch_size], dtype="int32")
|
||||
probs = None
|
||||
|
||||
for i in range(num_steps):
|
||||
char_onehots = self._char_to_onehot(
|
||||
targets, onehot_dim=self.num_classes)
|
||||
hidden, alpha = self.attention_cell(hidden, inputs,
|
||||
char_onehots)
|
||||
probs_step = self.generator(hidden[0])
|
||||
hidden = (hidden[1][0], hidden[1][1])
|
||||
if probs is None:
|
||||
probs = paddle.unsqueeze(probs_step, axis=1)
|
||||
else:
|
||||
probs = paddle.concat(
|
||||
[probs, paddle.unsqueeze(
|
||||
probs_step, axis=1)], axis=1)
|
||||
|
||||
next_input = probs_step.argmax(axis=1)
|
||||
|
||||
targets = next_input
|
||||
|
||||
return probs
|
||||
|
||||
|
||||
class AttentionLSTMCell(nn.Layer):
|
||||
def __init__(self, input_size, hidden_size, num_embeddings, use_gru=False):
|
||||
super(AttentionLSTMCell, self).__init__()
|
||||
self.i2h = nn.Linear(input_size, hidden_size, bias_attr=False)
|
||||
self.h2h = nn.Linear(hidden_size, hidden_size)
|
||||
self.score = nn.Linear(hidden_size, 1, bias_attr=False)
|
||||
if not use_gru:
|
||||
self.rnn = nn.LSTMCell(
|
||||
input_size=input_size + num_embeddings, hidden_size=hidden_size)
|
||||
else:
|
||||
self.rnn = nn.GRUCell(
|
||||
input_size=input_size + num_embeddings, hidden_size=hidden_size)
|
||||
|
||||
self.hidden_size = hidden_size
|
||||
|
||||
def forward(self, prev_hidden, batch_H, char_onehots):
|
||||
batch_H_proj = self.i2h(batch_H)
|
||||
prev_hidden_proj = paddle.unsqueeze(self.h2h(prev_hidden[0]), axis=1)
|
||||
res = paddle.add(batch_H_proj, prev_hidden_proj)
|
||||
res = paddle.tanh(res)
|
||||
e = self.score(res)
|
||||
|
||||
alpha = F.softmax(e, axis=1)
|
||||
alpha = paddle.transpose(alpha, [0, 2, 1])
|
||||
context = paddle.squeeze(paddle.mm(alpha, batch_H), axis=1)
|
||||
concat_context = paddle.concat([context, char_onehots], 1)
|
||||
cur_hidden = self.rnn(concat_context, prev_hidden)
|
||||
|
||||
return cur_hidden, alpha
|
|
@ -21,7 +21,8 @@ def build_neck(config):
|
|||
from .sast_fpn import SASTFPN
|
||||
from .rnn import SequenceEncoder
|
||||
from .pg_fpn import PGFPN
|
||||
support_dict = ['DBFPN', 'EASTFPN', 'SASTFPN', 'SequenceEncoder', 'PGFPN']
|
||||
from .table_fpn import TableFPN
|
||||
support_dict = ['DBFPN', 'EASTFPN', 'SASTFPN', 'SequenceEncoder', 'PGFPN', 'TableFPN']
|
||||
|
||||
module_name = config.pop('name')
|
||||
assert module_name in support_dict, Exception('neck only support {}'.format(
|
||||
|
|
|
@ -0,0 +1,110 @@
|
|||
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import paddle
|
||||
from paddle import nn
|
||||
import paddle.nn.functional as F
|
||||
from paddle import ParamAttr
|
||||
|
||||
|
||||
class TableFPN(nn.Layer):
|
||||
def __init__(self, in_channels, out_channels, **kwargs):
|
||||
super(TableFPN, self).__init__()
|
||||
self.out_channels = 512
|
||||
weight_attr = paddle.nn.initializer.KaimingUniform()
|
||||
self.in2_conv = nn.Conv2D(
|
||||
in_channels=in_channels[0],
|
||||
out_channels=self.out_channels,
|
||||
kernel_size=1,
|
||||
weight_attr=ParamAttr(initializer=weight_attr),
|
||||
bias_attr=False)
|
||||
self.in3_conv = nn.Conv2D(
|
||||
in_channels=in_channels[1],
|
||||
out_channels=self.out_channels,
|
||||
kernel_size=1,
|
||||
stride = 1,
|
||||
weight_attr=ParamAttr(initializer=weight_attr),
|
||||
bias_attr=False)
|
||||
self.in4_conv = nn.Conv2D(
|
||||
in_channels=in_channels[2],
|
||||
out_channels=self.out_channels,
|
||||
kernel_size=1,
|
||||
weight_attr=ParamAttr(initializer=weight_attr),
|
||||
bias_attr=False)
|
||||
self.in5_conv = nn.Conv2D(
|
||||
in_channels=in_channels[3],
|
||||
out_channels=self.out_channels,
|
||||
kernel_size=1,
|
||||
weight_attr=ParamAttr(initializer=weight_attr),
|
||||
bias_attr=False)
|
||||
self.p5_conv = nn.Conv2D(
|
||||
in_channels=self.out_channels,
|
||||
out_channels=self.out_channels // 4,
|
||||
kernel_size=3,
|
||||
padding=1,
|
||||
weight_attr=ParamAttr(initializer=weight_attr),
|
||||
bias_attr=False)
|
||||
self.p4_conv = nn.Conv2D(
|
||||
in_channels=self.out_channels,
|
||||
out_channels=self.out_channels // 4,
|
||||
kernel_size=3,
|
||||
padding=1,
|
||||
weight_attr=ParamAttr(initializer=weight_attr),
|
||||
bias_attr=False)
|
||||
self.p3_conv = nn.Conv2D(
|
||||
in_channels=self.out_channels,
|
||||
out_channels=self.out_channels // 4,
|
||||
kernel_size=3,
|
||||
padding=1,
|
||||
weight_attr=ParamAttr(initializer=weight_attr),
|
||||
bias_attr=False)
|
||||
self.p2_conv = nn.Conv2D(
|
||||
in_channels=self.out_channels,
|
||||
out_channels=self.out_channels // 4,
|
||||
kernel_size=3,
|
||||
padding=1,
|
||||
weight_attr=ParamAttr(initializer=weight_attr),
|
||||
bias_attr=False)
|
||||
self.fuse_conv = nn.Conv2D(
|
||||
in_channels=self.out_channels * 4,
|
||||
out_channels=512,
|
||||
kernel_size=3,
|
||||
padding=1,
|
||||
weight_attr=ParamAttr(initializer=weight_attr), bias_attr=False)
|
||||
|
||||
def forward(self, x):
|
||||
c2, c3, c4, c5 = x
|
||||
|
||||
in5 = self.in5_conv(c5)
|
||||
in4 = self.in4_conv(c4)
|
||||
in3 = self.in3_conv(c3)
|
||||
in2 = self.in2_conv(c2)
|
||||
|
||||
out4 = in4 + F.upsample(
|
||||
in5, size=in4.shape[2:4], mode="nearest", align_mode=1) # 1/16
|
||||
out3 = in3 + F.upsample(
|
||||
out4, size=in3.shape[2:4], mode="nearest", align_mode=1) # 1/8
|
||||
out2 = in2 + F.upsample(
|
||||
out3, size=in2.shape[2:4], mode="nearest", align_mode=1) # 1/4
|
||||
|
||||
p4 = F.upsample(out4, size=in5.shape[2:4], mode="nearest", align_mode=1)
|
||||
p3 = F.upsample(out3, size=in5.shape[2:4], mode="nearest", align_mode=1)
|
||||
p2 = F.upsample(out2, size=in5.shape[2:4], mode="nearest", align_mode=1)
|
||||
fuse = paddle.concat([in5, p4, p3, p2], axis=1)
|
||||
fuse_conv = self.fuse_conv(fuse) * 0.005
|
||||
return [c5 + fuse_conv]
|
|
@ -230,15 +230,8 @@ class GridGenerator(nn.Layer):
|
|||
def build_inv_delta_C_paddle(self, C):
|
||||
""" Return inv_delta_C which is needed to calculate T """
|
||||
F = self.F
|
||||
hat_C = paddle.zeros((F, F), dtype='float64') # F x F
|
||||
for i in range(0, F):
|
||||
for j in range(i, F):
|
||||
if i == j:
|
||||
hat_C[i, j] = 1
|
||||
else:
|
||||
r = paddle.norm(C[i] - C[j])
|
||||
hat_C[i, j] = r
|
||||
hat_C[j, i] = r
|
||||
hat_eye = paddle.eye(F, dtype='float64') # F x F
|
||||
hat_C = paddle.norm(C.reshape([1, F, 2]) - C.reshape([F, 1, 2]), axis=2) + hat_eye
|
||||
hat_C = (hat_C**2) * paddle.log(hat_C)
|
||||
delta_C = paddle.concat( # F+3 x F+3
|
||||
[
|
||||
|
|
|
@ -24,7 +24,8 @@ __all__ = ['build_post_process']
|
|||
from .db_postprocess import DBPostProcess
|
||||
from .east_postprocess import EASTPostProcess
|
||||
from .sast_postprocess import SASTPostProcess
|
||||
from .rec_postprocess import CTCLabelDecode, AttnLabelDecode, SRNLabelDecode, DistillationCTCLabelDecode
|
||||
from .rec_postprocess import CTCLabelDecode, AttnLabelDecode, SRNLabelDecode, DistillationCTCLabelDecode, \
|
||||
TableLabelDecode
|
||||
from .cls_postprocess import ClsPostProcess
|
||||
from .pg_postprocess import PGPostProcess
|
||||
|
||||
|
@ -33,7 +34,7 @@ def build_post_process(config, global_config=None):
|
|||
support_dict = [
|
||||
'DBPostProcess', 'EASTPostProcess', 'SASTPostProcess', 'CTCLabelDecode',
|
||||
'AttnLabelDecode', 'ClsPostProcess', 'SRNLabelDecode', 'PGPostProcess',
|
||||
'DistillationCTCLabelDecode'
|
||||
'DistillationCTCLabelDecode', 'TableLabelDecode'
|
||||
]
|
||||
|
||||
config = copy.deepcopy(config)
|
||||
|
|
|
@ -44,16 +44,16 @@ class BaseRecLabelDecode(object):
|
|||
self.character_str = string.printable[:-6]
|
||||
dict_character = list(self.character_str)
|
||||
elif character_type in support_character_type:
|
||||
self.character_str = ""
|
||||
self.character_str = []
|
||||
assert character_dict_path is not None, "character_dict_path should not be None when character_type is {}".format(
|
||||
character_type)
|
||||
with open(character_dict_path, "rb") as fin:
|
||||
lines = fin.readlines()
|
||||
for line in lines:
|
||||
line = line.decode('utf-8').strip("\n").strip("\r\n")
|
||||
self.character_str += line
|
||||
self.character_str.append(line)
|
||||
if use_space_char:
|
||||
self.character_str += " "
|
||||
self.character_str.append(" ")
|
||||
dict_character = list(self.character_str)
|
||||
|
||||
else:
|
||||
|
@ -319,3 +319,138 @@ class SRNLabelDecode(BaseRecLabelDecode):
|
|||
assert False, "unsupport type %s in get_beg_end_flag_idx" \
|
||||
% beg_or_end
|
||||
return idx
|
||||
|
||||
|
||||
class TableLabelDecode(object):
|
||||
""" """
|
||||
|
||||
def __init__(self,
|
||||
character_dict_path,
|
||||
**kwargs):
|
||||
list_character, list_elem = self.load_char_elem_dict(character_dict_path)
|
||||
list_character = self.add_special_char(list_character)
|
||||
list_elem = self.add_special_char(list_elem)
|
||||
self.dict_character = {}
|
||||
self.dict_idx_character = {}
|
||||
for i, char in enumerate(list_character):
|
||||
self.dict_idx_character[i] = char
|
||||
self.dict_character[char] = i
|
||||
self.dict_elem = {}
|
||||
self.dict_idx_elem = {}
|
||||
for i, elem in enumerate(list_elem):
|
||||
self.dict_idx_elem[i] = elem
|
||||
self.dict_elem[elem] = i
|
||||
|
||||
def load_char_elem_dict(self, character_dict_path):
|
||||
list_character = []
|
||||
list_elem = []
|
||||
with open(character_dict_path, "rb") as fin:
|
||||
lines = fin.readlines()
|
||||
substr = lines[0].decode('utf-8').strip("\n").split("\t")
|
||||
character_num = int(substr[0])
|
||||
elem_num = int(substr[1])
|
||||
for cno in range(1, 1 + character_num):
|
||||
character = lines[cno].decode('utf-8').strip("\n")
|
||||
list_character.append(character)
|
||||
for eno in range(1 + character_num, 1 + character_num + elem_num):
|
||||
elem = lines[eno].decode('utf-8').strip("\n")
|
||||
list_elem.append(elem)
|
||||
return list_character, list_elem
|
||||
|
||||
def add_special_char(self, list_character):
|
||||
self.beg_str = "sos"
|
||||
self.end_str = "eos"
|
||||
list_character = [self.beg_str] + list_character + [self.end_str]
|
||||
return list_character
|
||||
|
||||
def __call__(self, preds):
|
||||
structure_probs = preds['structure_probs']
|
||||
loc_preds = preds['loc_preds']
|
||||
if isinstance(structure_probs,paddle.Tensor):
|
||||
structure_probs = structure_probs.numpy()
|
||||
if isinstance(loc_preds,paddle.Tensor):
|
||||
loc_preds = loc_preds.numpy()
|
||||
structure_idx = structure_probs.argmax(axis=2)
|
||||
structure_probs = structure_probs.max(axis=2)
|
||||
structure_str, structure_pos, result_score_list, result_elem_idx_list = self.decode(structure_idx,
|
||||
structure_probs, 'elem')
|
||||
res_html_code_list = []
|
||||
res_loc_list = []
|
||||
batch_num = len(structure_str)
|
||||
for bno in range(batch_num):
|
||||
res_loc = []
|
||||
for sno in range(len(structure_str[bno])):
|
||||
text = structure_str[bno][sno]
|
||||
if text in ['<td>', '<td']:
|
||||
pos = structure_pos[bno][sno]
|
||||
res_loc.append(loc_preds[bno, pos])
|
||||
res_html_code = ''.join(structure_str[bno])
|
||||
res_loc = np.array(res_loc)
|
||||
res_html_code_list.append(res_html_code)
|
||||
res_loc_list.append(res_loc)
|
||||
return {'res_html_code': res_html_code_list, 'res_loc': res_loc_list, 'res_score_list': result_score_list,
|
||||
'res_elem_idx_list': result_elem_idx_list,'structure_str_list':structure_str}
|
||||
|
||||
def decode(self, text_index, structure_probs, char_or_elem):
|
||||
"""convert text-label into text-index.
|
||||
"""
|
||||
if char_or_elem == "char":
|
||||
current_dict = self.dict_idx_character
|
||||
else:
|
||||
current_dict = self.dict_idx_elem
|
||||
ignored_tokens = self.get_ignored_tokens('elem')
|
||||
beg_idx, end_idx = ignored_tokens
|
||||
|
||||
result_list = []
|
||||
result_pos_list = []
|
||||
result_score_list = []
|
||||
result_elem_idx_list = []
|
||||
batch_size = len(text_index)
|
||||
for batch_idx in range(batch_size):
|
||||
char_list = []
|
||||
elem_pos_list = []
|
||||
elem_idx_list = []
|
||||
score_list = []
|
||||
for idx in range(len(text_index[batch_idx])):
|
||||
tmp_elem_idx = int(text_index[batch_idx][idx])
|
||||
if idx > 0 and tmp_elem_idx == end_idx:
|
||||
break
|
||||
if tmp_elem_idx in ignored_tokens:
|
||||
continue
|
||||
|
||||
char_list.append(current_dict[tmp_elem_idx])
|
||||
elem_pos_list.append(idx)
|
||||
score_list.append(structure_probs[batch_idx, idx])
|
||||
elem_idx_list.append(tmp_elem_idx)
|
||||
result_list.append(char_list)
|
||||
result_pos_list.append(elem_pos_list)
|
||||
result_score_list.append(score_list)
|
||||
result_elem_idx_list.append(elem_idx_list)
|
||||
return result_list, result_pos_list, result_score_list, result_elem_idx_list
|
||||
|
||||
def get_ignored_tokens(self, char_or_elem):
|
||||
beg_idx = self.get_beg_end_flag_idx("beg", char_or_elem)
|
||||
end_idx = self.get_beg_end_flag_idx("end", char_or_elem)
|
||||
return [beg_idx, end_idx]
|
||||
|
||||
def get_beg_end_flag_idx(self, beg_or_end, char_or_elem):
|
||||
if char_or_elem == "char":
|
||||
if beg_or_end == "beg":
|
||||
idx = self.dict_character[self.beg_str]
|
||||
elif beg_or_end == "end":
|
||||
idx = self.dict_character[self.end_str]
|
||||
else:
|
||||
assert False, "Unsupport type %s in get_beg_end_flag_idx of char" \
|
||||
% beg_or_end
|
||||
elif char_or_elem == "elem":
|
||||
if beg_or_end == "beg":
|
||||
idx = self.dict_elem[self.beg_str]
|
||||
elif beg_or_end == "end":
|
||||
idx = self.dict_elem[self.end_str]
|
||||
else:
|
||||
assert False, "Unsupport type %s in get_beg_end_flag_idx of elem" \
|
||||
% beg_or_end
|
||||
else:
|
||||
assert False, "Unsupport type %s in char_or_elem" \
|
||||
% char_or_elem
|
||||
return idx
|
||||
|
|
|
@ -0,0 +1,277 @@
|
|||
←
|
||||
</overline>
|
||||
☆
|
||||
─
|
||||
α
|
||||
|
||||
|
||||
⋅
|
||||
$
|
||||
ω
|
||||
ψ
|
||||
χ
|
||||
(
|
||||
υ
|
||||
≥
|
||||
σ
|
||||
,
|
||||
ρ
|
||||
ε
|
||||
0
|
||||
■
|
||||
4
|
||||
8
|
||||
✗
|
||||
b
|
||||
<
|
||||
✓
|
||||
Ψ
|
||||
Ω
|
||||
€
|
||||
D
|
||||
3
|
||||
Π
|
||||
H
|
||||
║
|
||||
</strike>
|
||||
L
|
||||
Φ
|
||||
Χ
|
||||
θ
|
||||
P
|
||||
κ
|
||||
λ
|
||||
μ
|
||||
T
|
||||
ξ
|
||||
X
|
||||
β
|
||||
γ
|
||||
δ
|
||||
\
|
||||
ζ
|
||||
η
|
||||
`
|
||||
d
|
||||
<strike>
|
||||
h
|
||||
f
|
||||
l
|
||||
Θ
|
||||
p
|
||||
√
|
||||
t
|
||||
</sub>
|
||||
x
|
||||
Β
|
||||
Γ
|
||||
Δ
|
||||
|
|
||||
ǂ
|
||||
ɛ
|
||||
j
|
||||
̧
|
||||
➢
|
||||
|
||||
̌
|
||||
′
|
||||
«
|
||||
△
|
||||
▲
|
||||
#
|
||||
</b>
|
||||
'
|
||||
Ι
|
||||
+
|
||||
¶
|
||||
/
|
||||
▼
|
||||
⇑
|
||||
□
|
||||
·
|
||||
7
|
||||
▪
|
||||
;
|
||||
?
|
||||
➔
|
||||
∩
|
||||
C
|
||||
÷
|
||||
G
|
||||
⇒
|
||||
K
|
||||
<sup>
|
||||
O
|
||||
S
|
||||
С
|
||||
W
|
||||
Α
|
||||
[
|
||||
○
|
||||
_
|
||||
●
|
||||
‡
|
||||
c
|
||||
z
|
||||
g
|
||||
<i>
|
||||
o
|
||||
<sub>
|
||||
〈
|
||||
〉
|
||||
s
|
||||
⩽
|
||||
w
|
||||
φ
|
||||
ʹ
|
||||
{
|
||||
»
|
||||
∣
|
||||
̆
|
||||
e
|
||||
ˆ
|
||||
∈
|
||||
τ
|
||||
◆
|
||||
ι
|
||||
∅
|
||||
∆
|
||||
∙
|
||||
∘
|
||||
Ø
|
||||
ß
|
||||
✔
|
||||
∞
|
||||
∑
|
||||
−
|
||||
×
|
||||
◊
|
||||
∗
|
||||
∖
|
||||
˃
|
||||
˂
|
||||
∫
|
||||
"
|
||||
i
|
||||
&
|
||||
π
|
||||
↔
|
||||
*
|
||||
∥
|
||||
æ
|
||||
∧
|
||||
.
|
||||
⁄
|
||||
ø
|
||||
Q
|
||||
∼
|
||||
6
|
||||
⁎
|
||||
:
|
||||
★
|
||||
>
|
||||
a
|
||||
B
|
||||
≈
|
||||
F
|
||||
J
|
||||
̄
|
||||
N
|
||||
♯
|
||||
R
|
||||
V
|
||||
<overline>
|
||||
―
|
||||
Z
|
||||
♣
|
||||
^
|
||||
¤
|
||||
¥
|
||||
§
|
||||
<underline>
|
||||
¢
|
||||
£
|
||||
≦
|
||||
|
||||
≤
|
||||
‖
|
||||
Λ
|
||||
©
|
||||
n
|
||||
↓
|
||||
→
|
||||
↑
|
||||
r
|
||||
°
|
||||
±
|
||||
v
|
||||
<b>
|
||||
♂
|
||||
k
|
||||
♀
|
||||
~
|
||||
ᅟ
|
||||
̇
|
||||
@
|
||||
”
|
||||
♦
|
||||
ł
|
||||
®
|
||||
⊕
|
||||
„
|
||||
!
|
||||
</sup>
|
||||
%
|
||||
⇓
|
||||
)
|
||||
-
|
||||
1
|
||||
5
|
||||
9
|
||||
=
|
||||
А
|
||||
A
|
||||
‰
|
||||
⋆
|
||||
Σ
|
||||
E
|
||||
◦
|
||||
I
|
||||
※
|
||||
M
|
||||
m
|
||||
̨
|
||||
⩾
|
||||
†
|
||||
</i>
|
||||
•
|
||||
U
|
||||
Y
|
||||
|
||||
]
|
||||
̸
|
||||
2
|
||||
‐
|
||||
–
|
||||
‒
|
||||
̂
|
||||
—
|
||||
̀
|
||||
́
|
||||
’
|
||||
‘
|
||||
⋮
|
||||
⋯
|
||||
̊
|
||||
“
|
||||
̈
|
||||
≧
|
||||
q
|
||||
u
|
||||
ı
|
||||
y
|
||||
</underline>
|
||||
|
||||
̃
|
||||
}
|
||||
ν
|
|
@ -22,7 +22,7 @@ logger_initialized = {}
|
|||
|
||||
|
||||
@functools.lru_cache()
|
||||
def get_logger(name='root', log_file=None, log_level=logging.INFO):
|
||||
def get_logger(name='root', log_file=None, log_level=logging.DEBUG):
|
||||
"""Initialize and get a logger by name.
|
||||
If the logger has not been initialized, this method will initialize the
|
||||
logger by adding one or two handlers, otherwise the initialized logger will
|
||||
|
|
|
@ -0,0 +1,82 @@
|
|||
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
import sys
|
||||
import tarfile
|
||||
import requests
|
||||
from tqdm import tqdm
|
||||
|
||||
from ppocr.utils.logging import get_logger
|
||||
|
||||
|
||||
def download_with_progressbar(url, save_path):
|
||||
logger = get_logger()
|
||||
response = requests.get(url, stream=True)
|
||||
total_size_in_bytes = int(response.headers.get('content-length', 0))
|
||||
block_size = 1024 # 1 Kibibyte
|
||||
progress_bar = tqdm(total=total_size_in_bytes, unit='iB', unit_scale=True)
|
||||
with open(save_path, 'wb') as file:
|
||||
for data in response.iter_content(block_size):
|
||||
progress_bar.update(len(data))
|
||||
file.write(data)
|
||||
progress_bar.close()
|
||||
if total_size_in_bytes == 0 or progress_bar.n != total_size_in_bytes:
|
||||
logger.error("Something went wrong while downloading models")
|
||||
sys.exit(0)
|
||||
|
||||
|
||||
def maybe_download(model_storage_directory, url):
|
||||
# using custom model
|
||||
tar_file_name_list = [
|
||||
'inference.pdiparams', 'inference.pdiparams.info', 'inference.pdmodel'
|
||||
]
|
||||
if not os.path.exists(
|
||||
os.path.join(model_storage_directory, 'inference.pdiparams')
|
||||
) or not os.path.exists(
|
||||
os.path.join(model_storage_directory, 'inference.pdmodel')):
|
||||
assert url.endswith('.tar'), 'Only supports tar compressed package'
|
||||
tmp_path = os.path.join(model_storage_directory, url.split('/')[-1])
|
||||
print('download {} to {}'.format(url, tmp_path))
|
||||
os.makedirs(model_storage_directory, exist_ok=True)
|
||||
download_with_progressbar(url, tmp_path)
|
||||
with tarfile.open(tmp_path, 'r') as tarObj:
|
||||
for member in tarObj.getmembers():
|
||||
filename = None
|
||||
for tar_file_name in tar_file_name_list:
|
||||
if tar_file_name in member.name:
|
||||
filename = tar_file_name
|
||||
if filename is None:
|
||||
continue
|
||||
file = tarObj.extractfile(member)
|
||||
with open(
|
||||
os.path.join(model_storage_directory, filename),
|
||||
'wb') as f:
|
||||
f.write(file.read())
|
||||
os.remove(tmp_path)
|
||||
|
||||
|
||||
def is_link(s):
|
||||
return s is not None and s.startswith('http')
|
||||
|
||||
|
||||
def confirm_model_dir_url(model_dir, default_model_dir, default_url):
|
||||
url = default_url
|
||||
if model_dir is None or is_link(model_dir):
|
||||
if is_link(model_dir):
|
||||
url = model_dir
|
||||
file_name = url.split('/')[-1][:-4]
|
||||
model_dir = default_model_dir
|
||||
model_dir = os.path.join(model_dir, file_name)
|
||||
return model_dir, url
|
|
@ -25,7 +25,7 @@ import paddle
|
|||
|
||||
from ppocr.utils.logging import get_logger
|
||||
|
||||
__all__ = ['init_model', 'save_model', 'load_dygraph_pretrain']
|
||||
__all__ = ['init_model', 'save_model', 'load_dygraph_params']
|
||||
|
||||
|
||||
def _mkdir_if_not_exist(path, logger):
|
||||
|
@ -89,6 +89,34 @@ def init_model(config, model, optimizer=None, lr_scheduler=None):
|
|||
return best_model_dict
|
||||
|
||||
|
||||
def load_dygraph_params(config, model, logger, optimizer):
|
||||
ckp = config['Global']['checkpoints']
|
||||
if ckp and os.path.exists(ckp):
|
||||
pre_best_model_dict = init_model(config, model, optimizer)
|
||||
return pre_best_model_dict
|
||||
else:
|
||||
pm = config['Global']['pretrained_model']
|
||||
if pm is None:
|
||||
return {}
|
||||
if not os.path.exists(pm) or not os.path.exists(pm + ".pdparams"):
|
||||
logger.info(f"The pretrained_model {pm} does not exists!")
|
||||
return {}
|
||||
pm = pm if pm.endswith('.pdparams') else pm + '.pdparams'
|
||||
params = paddle.load(pm)
|
||||
state_dict = model.state_dict()
|
||||
new_state_dict = {}
|
||||
for k1, k2 in zip(state_dict.keys(), params.keys()):
|
||||
if list(state_dict[k1].shape) == list(params[k2].shape):
|
||||
new_state_dict[k1] = params[k2]
|
||||
else:
|
||||
logger.info(
|
||||
f"The shape of model params {k1} {state_dict[k1].shape} not matched with loaded params {k2} {params[k2].shape} !"
|
||||
)
|
||||
model.set_state_dict(new_state_dict)
|
||||
logger.info(f"loaded pretrained_model successful from {pm}")
|
||||
return {}
|
||||
|
||||
|
||||
def save_model(model,
|
||||
optimizer,
|
||||
model_path,
|
||||
|
|
|
@ -0,0 +1,9 @@
|
|||
include LICENSE
|
||||
include README.md
|
||||
|
||||
recursive-include ppocr/utils *.txt utility.py logging.py network.py
|
||||
recursive-include ppocr/data/ *.py
|
||||
recursive-include ppocr/postprocess *.py
|
||||
recursive-include tools/infer *.py
|
||||
recursive-include test1 *.py
|
||||
|
|
@ -0,0 +1,17 @@
|
|||
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .paddlestructure import PaddleStructure, draw_result, to_excel
|
||||
|
||||
__all__ = ['PaddleStructure', 'draw_result', 'to_excel']
|
|
@ -0,0 +1,86 @@
|
|||
# PaddleStructure
|
||||
|
||||
install layoutparser
|
||||
```sh
|
||||
wget https://paddleocr.bj.bcebos.com/whl/layoutparser-0.0.0-py3-none-any.whl
|
||||
pip3 install layoutparser-0.0.0-py3-none-any.whl
|
||||
```
|
||||
|
||||
## 1. Introduction to pipeline
|
||||
|
||||
PaddleStructure is a toolkit for complex layout text OCR, the process is as follows
|
||||
|
||||
![pipeline](../doc/table/pipeline.png)
|
||||
|
||||
In PaddleStructure, the image will be analyzed by layoutparser first. In the layout analysis, the area in the image will be classified, and the OCR process will be carried out according to the category.
|
||||
|
||||
Currently layoutparser will output five categories:
|
||||
1. Text
|
||||
2. Title
|
||||
3. Figure
|
||||
4. List
|
||||
5. Table
|
||||
|
||||
Types 1-4 follow the traditional OCR process, and 5 follow the Table OCR process.
|
||||
|
||||
## 2. LayoutParser
|
||||
|
||||
|
||||
## 3. Table OCR
|
||||
|
||||
[doc](table/README.md)
|
||||
|
||||
## 4. Predictive by inference engine
|
||||
|
||||
Use the following commands to complete the inference
|
||||
```python
|
||||
python3 table/predict_system.py --det_model_dir=path/to/det_model_dir --rec_model_dir=path/to/rec_model_dir --table_model_dir=path/to/table_model_dir --image_dir=../doc/table/1.png --rec_char_dict_path=../ppocr/utils/dict/table_dict.txt --table_char_dict_path=../ppocr/utils/dict/table_structure_dict.txt --rec_char_type=EN --det_limit_side_len=736 --det_limit_type=min --output ../output/table
|
||||
```
|
||||
After running, each image will have a directory with the same name under the directory specified in the output field. Each table in the picture will be stored as an excel, and the excel file name will be the coordinates of the table in the image.
|
||||
|
||||
## 5. PaddleStructure whl package introduction
|
||||
|
||||
### 5.1 Use
|
||||
|
||||
5.1.1 Use by code
|
||||
```python
|
||||
import os
|
||||
import cv2
|
||||
from paddlestructure import PaddleStructure,draw_result,save_res
|
||||
|
||||
table_engine = PaddleStructure(show_log=True)
|
||||
|
||||
save_folder = './output/table'
|
||||
img_path = '../doc/table/1.png'
|
||||
img = cv2.imread(img_path)
|
||||
result = table_engine(img)
|
||||
save_res(result, save_folder,os.path.basename(img_path).split('.')[0])
|
||||
|
||||
for line in result:
|
||||
print(line)
|
||||
|
||||
from PIL import Image
|
||||
|
||||
font_path = 'path/tp/PaddleOCR/doc/fonts/simfang.ttf'
|
||||
image = Image.open(img_path).convert('RGB')
|
||||
im_show = draw_result(image, result,font_path=font_path)
|
||||
im_show = Image.fromarray(im_show)
|
||||
im_show.save('result.jpg')
|
||||
```
|
||||
|
||||
5.1.2 Use by command line
|
||||
```bash
|
||||
paddlestructure --image_dir=../doc/table/1.png
|
||||
```
|
||||
|
||||
### Parameter Description
|
||||
Most of the parameters are consistent with the paddleocr whl package, see [whl package documentation](../doc/doc_ch/whl.md)
|
||||
|
||||
| Parameter | Description | Default |
|
||||
|------------------------|------------------------------------------------------|------------------|
|
||||
| output | The path where excel and recognition results are saved | ./output/table |
|
||||
| structure_max_len | When the table structure model predicts, the long side of the image is resized | 488 |
|
||||
| structure_model_dir | Table structure inference model path | None |
|
||||
| structure_char_type | Dictionary path used by table structure model | ../ppocr/utils/dict/table_structure_dict.tx |
|
||||
|
||||
|
|
@ -0,0 +1,86 @@
|
|||
# PaddleStructure
|
||||
|
||||
安装layoutparser
|
||||
```sh
|
||||
wget https://paddleocr.bj.bcebos.com/whl/layoutparser-0.0.0-py3-none-any.whl
|
||||
pip3 install layoutparser-0.0.0-py3-none-any.whl
|
||||
```
|
||||
|
||||
## 1. pipeline介绍
|
||||
|
||||
PaddleStructure 是一个用于复杂板式文字OCR的工具包,流程如下
|
||||
![pipeline](../doc/table/pipeline.png)
|
||||
|
||||
在PaddleStructure中,图片会先经由layoutparser进行版面分析,在版面分析中,会对图片里的区域进行分类,根据根据类别进行对于的ocr流程。
|
||||
|
||||
目前layoutparser会输出五个类别:
|
||||
1. Text
|
||||
2. Title
|
||||
3. Figure
|
||||
4. List
|
||||
5. Table
|
||||
|
||||
1-4类走传统的OCR流程,5走表格的OCR流程。
|
||||
|
||||
## 2. LayoutParser
|
||||
|
||||
[文档](layout/README.md)
|
||||
|
||||
## 3. Table OCR
|
||||
|
||||
[文档](table/README_ch.md)
|
||||
|
||||
## 4. 预测引擎推理
|
||||
|
||||
使用如下命令即可完成预测引擎的推理
|
||||
```python
|
||||
python3 table/predict_system.py --det_model_dir=path/to/det_model_dir --rec_model_dir=path/to/rec_model_dir --table_model_dir=path/to/table_model_dir --image_dir=../doc/table/1.png --rec_char_dict_path=../ppocr/utils/dict/table_dict.txt --table_char_dict_path=../ppocr/utils/dict/table_structure_dict.txt --rec_char_type=EN --det_limit_side_len=736 --det_limit_type=min --output ../output/table
|
||||
```
|
||||
运行完成后,每张图片会output字段指定的目录下有一个同名目录,图片里的每个表格会存储为一个excel,excel文件名为表格在图片里的坐标。
|
||||
|
||||
## 5. PaddleStructure whl包介绍
|
||||
|
||||
### 5.1 使用
|
||||
|
||||
5.1.1 代码使用
|
||||
```python
|
||||
import os
|
||||
import cv2
|
||||
from paddlestructure import PaddleStructure,draw_result,save_res
|
||||
|
||||
table_engine = PaddleStructure(show_log=True)
|
||||
|
||||
save_folder = './output/table'
|
||||
img_path = '../doc/table/1.png'
|
||||
img = cv2.imread(img_path)
|
||||
result = table_engine(img)
|
||||
save_res(result, save_folder,os.path.basename(img_path).split('.')[0])
|
||||
|
||||
for line in result:
|
||||
print(line)
|
||||
|
||||
from PIL import Image
|
||||
|
||||
font_path = 'path/tp/PaddleOCR/doc/fonts/simfang.ttf'
|
||||
image = Image.open(img_path).convert('RGB')
|
||||
im_show = draw_result(image, result,font_path=font_path)
|
||||
im_show = Image.fromarray(im_show)
|
||||
im_show.save('result.jpg')
|
||||
```
|
||||
|
||||
5.1.2 命令行使用
|
||||
```bash
|
||||
paddlestructure --image_dir=../doc/table/1.png
|
||||
```
|
||||
|
||||
### 参数说明
|
||||
大部分参数和paddleocr whl包保持一致,见 [whl包文档](../doc/doc_ch/whl.md)
|
||||
|
||||
| 字段 | 说明 | 默认值 |
|
||||
|------------------------|------------------------------------------------------|------------------|
|
||||
| output | excel和识别结果保存的地址 | ./output/table |
|
||||
| table_max_len | 表格结构模型预测时,图像的长边resize尺度 | 488 |
|
||||
| table_model_dir | 表格结构模型 inference 模型地址 | None |
|
||||
| table_char_type | 表格结构模型所用字典地址 | ../ppocr/utils/dict/table_structure_dict.tx |
|
||||
|
||||
|
|
@ -0,0 +1,133 @@
|
|||
# 版面分析使用说明
|
||||
|
||||
* [1. 安装whl包](#安装whl包)
|
||||
* [2. 使用](#使用)
|
||||
* [3. 后处理](#后处理)
|
||||
* [4. 指标](#指标)
|
||||
* [5. 训练版面分析模型](#训练版面分析模型)
|
||||
|
||||
<a name="安装whl包"></a>
|
||||
|
||||
## 1. 安装whl包
|
||||
```bash
|
||||
wget https://paddleocr.bj.bcebos.com/whl/layoutparser-0.0.0-py3-none-any.whl
|
||||
pip install -U layoutparser-0.0.0-py3-none-any.whl
|
||||
```
|
||||
|
||||
<a name="使用"></a>
|
||||
|
||||
## 2. 使用
|
||||
|
||||
使用layoutparser识别给定文档的布局:
|
||||
|
||||
```python
|
||||
import layoutparser as lp
|
||||
image = cv2.imread("imags/paper-image.jpg")
|
||||
image = image[..., ::-1]
|
||||
|
||||
# 加载模型
|
||||
model = lp.PaddleDetectionLayoutModel(config_path="lp://PubLayNet/ppyolov2_r50vd_dcn_365e_publaynet/config",
|
||||
threshold=0.5,
|
||||
label_map={0: "Text", 1: "Title", 2: "List", 3:"Table", 4:"Figure"},
|
||||
enforce_cpu=False,
|
||||
enable_mkldnn=True)
|
||||
# 检测
|
||||
layout = model.detect(image)
|
||||
|
||||
# 显示结果
|
||||
lp.draw_box(image, layout, box_width=3, show_element_type=True)
|
||||
```
|
||||
|
||||
下图展示了结果,不同颜色的检测框表示不同的类别,并通过`show_element_type`在框的左上角显示具体类别:
|
||||
|
||||
<div align="center">
|
||||
<img src="../../doc/table/result_all.jpg" width = "600" />
|
||||
</div>
|
||||
|
||||
`PaddleDetectionLayoutModel`函数参数说明如下:
|
||||
|
||||
| 参数 | 含义 | 默认值 | 备注 |
|
||||
| :------------: | :-------------------------: | :---------: | :----------------------------------------------------------: |
|
||||
| config_path | 模型配置路径 | None | 指定config_path会自动下载模型(仅第一次,之后模型存在,不会再下载) |
|
||||
| model_path | 模型路径 | None | 本地模型路径,config_path和model_path必须设置一个,不能同时为None |
|
||||
| threshold | 预测得分的阈值 | 0.5 | \ |
|
||||
| input_shape | reshape之后图片尺寸 | [3,640,640] | \ |
|
||||
| batch_size | 测试batch size | 1 | \ |
|
||||
| label_map | 类别映射表 | None | 设置config_path时,可以为None,根据数据集名称自动获取label_map |
|
||||
| enforce_cpu | 代码是否使用CPU运行 | False | 设置为False表示使用GPU,True表示强制使用CPU |
|
||||
| enforce_mkldnn | CPU预测中是否开启MKLDNN加速 | True | \ |
|
||||
| thread_num | 设置CPU线程数 | 10 | \ |
|
||||
|
||||
目前支持以下几种模型配置和label map,您可以通过修改 `--config_path`和 `--label_map`使用这些模型,从而检测不同类型的内容:
|
||||
|
||||
| dataset | config_path | label_map |
|
||||
| ------------------------------------------------------------ | ------------------------------------------------------------ | --------------------------------------------------------- |
|
||||
| [TableBank](https://doc-analysis.github.io/tablebank-page/index.html) word | lp://TableBank/ppyolov2_r50vd_dcn_365e_tableBank_word/config | {0:"Table"} |
|
||||
| TableBank latex | lp://TableBank/ppyolov2_r50vd_dcn_365e_tableBank_latex/config | {0:"Table"} |
|
||||
| [PubLayNet](https://github.com/ibm-aur-nlp/PubLayNet) | lp://PubLayNet/ppyolov2_r50vd_dcn_365e_publaynet/config | {0: "Text", 1: "Title", 2: "List", 3:"Table", 4:"Figure"} |
|
||||
|
||||
* TableBank word和TableBank latex分别在word文档、latex文档数据集训练;
|
||||
* 下载TableBank数据集同时包含word和latex。
|
||||
|
||||
<a name="后处理"></a>
|
||||
|
||||
## 3. 后处理
|
||||
|
||||
版面分析检测包含多个类别,如果只想获取指定类别(如"Text"类别)的检测框、可以使用下述代码:
|
||||
|
||||
```python
|
||||
# 首先过滤特定文本类型的区域
|
||||
text_blocks = lp.Layout([b for b in layout if b.type=='Text'])
|
||||
figure_blocks = lp.Layout([b for b in layout if b.type=='Figure'])
|
||||
|
||||
# 因为在图像区域内可能检测到文本区域,所以只需要删除它们
|
||||
text_blocks = lp.Layout([b for b in text_blocks \
|
||||
if not any(b.is_in(b_fig) for b_fig in figure_blocks)])
|
||||
|
||||
# 对文本区域排序并分配id
|
||||
h, w = image.shape[:2]
|
||||
|
||||
left_interval = lp.Interval(0, w/2*1.05, axis='x').put_on_canvas(image)
|
||||
|
||||
left_blocks = text_blocks.filter_by(left_interval, center=True)
|
||||
left_blocks.sort(key = lambda b:b.coordinates[1])
|
||||
|
||||
right_blocks = [b for b in text_blocks if b not in left_blocks]
|
||||
right_blocks.sort(key = lambda b:b.coordinates[1])
|
||||
|
||||
# 最终合并两个列表,并按顺序添加索引
|
||||
text_blocks = lp.Layout([b.set(id = idx) for idx, b in enumerate(left_blocks + right_blocks)])
|
||||
|
||||
# 显示结果
|
||||
lp.draw_box(image, text_blocks,
|
||||
box_width=3,
|
||||
show_element_id=True)
|
||||
```
|
||||
|
||||
显示只有"Text"类别的结果:
|
||||
|
||||
<div align="center">
|
||||
<img src="../../doc/table/result_text.jpg" width = "600" />
|
||||
</div>
|
||||
|
||||
<a name="指标"></a>
|
||||
|
||||
## 4. 指标
|
||||
|
||||
| Dataset | mAP | CPU time cost | GPU time cost |
|
||||
| --------- | ---- | ------------- | ------------- |
|
||||
| PubLayNet | 93.6 | 1713.7ms | 66.6ms |
|
||||
| TableBank | 96.2 | 1968.4ms | 65.1ms |
|
||||
|
||||
**Envrionment:**
|
||||
|
||||
**CPU:** Intel(R) Xeon(R) CPU E5-2650 v4 @ 2.20GHz,24core
|
||||
|
||||
**GPU:** a single NVIDIA Tesla P40
|
||||
|
||||
<a name="训练版面分析模型"></a>
|
||||
|
||||
## 5. 训练版面分析模型
|
||||
|
||||
上述模型基于[PaddleDetection](https://github.com/PaddlePaddle/PaddleDetection) 训练,如果您想训练自己的版面分析模型,请参考:[train_layoutparser_model](train_layoutparser_model.md)
|
||||
|
|
@ -0,0 +1,188 @@
|
|||
# 训练版面分析
|
||||
|
||||
* [1. 安装](#安装)
|
||||
* [1.1 环境要求](#环境要求)
|
||||
* [1.2 安装PaddleDetection](#安装PaddleDetection)
|
||||
* [2. 准备数据](#准备数据)
|
||||
* [3. 配置文件改动和说明](#配置文件改动和说明)
|
||||
* [4. PaddleDetection训练](#训练)
|
||||
* [5. PaddleDetection预测](#预测)
|
||||
* [6. 预测部署](#预测部署)
|
||||
* [6.1 模型导出](#模型导出)
|
||||
* [6.2 layout parser预测](#layout_parser预测)
|
||||
|
||||
<a name="安装"></a>
|
||||
|
||||
## 1. 安装
|
||||
|
||||
<a name="环境要求"></a>
|
||||
|
||||
### 1.1 环境要求
|
||||
|
||||
- PaddlePaddle 2.1
|
||||
- OS 64 bit
|
||||
- Python 3(3.5.1+/3.6/3.7/3.8/3.9),64 bit
|
||||
- pip/pip3(9.0.1+), 64 bit
|
||||
- CUDA >= 10.1
|
||||
- cuDNN >= 7.6
|
||||
|
||||
<a name="安装PaddleDetection"></a>
|
||||
|
||||
### 1.2 安装PaddleDetection
|
||||
|
||||
```bash
|
||||
# 克隆PaddleDetection仓库
|
||||
cd <path/to/clone/PaddleDetection>
|
||||
git clone https://github.com/PaddlePaddle/PaddleDetection.git
|
||||
|
||||
cd PaddleDetection
|
||||
# 安装其他依赖
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
更多安装教程,请参考: [Install doc](https://github.com/PaddlePaddle/PaddleDetection/blob/release/2.1/docs/tutorials/INSTALL_cn.md)
|
||||
|
||||
<a name="数据准备"></a>
|
||||
|
||||
## 2. 准备数据
|
||||
|
||||
下载 [PubLayNet](https://github.com/ibm-aur-nlp/PubLayNet) 数据集:
|
||||
|
||||
```bash
|
||||
cd PaddleDetection/dataset/
|
||||
mkdir publaynet
|
||||
# 执行命令,下载
|
||||
wget -O publaynet.tar.gz https://dax-cdn.cdn.appdomain.cloud/dax-publaynet/1.0.0/publaynet.tar.gz?_ga=2.104193024.1076900768.1622560733-649911202.1622560733
|
||||
# 解压
|
||||
tar -xvf publaynet.tar.gz
|
||||
```
|
||||
|
||||
解压之后PubLayNet目录结构:
|
||||
|
||||
| File or Folder | Description | num |
|
||||
| :------------- | :----------------------------------------------- | ------- |
|
||||
| `train/` | Images in the training subset | 335,703 |
|
||||
| `val/` | Images in the validation subset | 11,245 |
|
||||
| `test/` | Images in the testing subset | 11,405 |
|
||||
| `train.json` | Annotations for training images | |
|
||||
| `val.json` | Annotations for validation images | |
|
||||
| `LICENSE.txt` | Plaintext version of the CDLA-Permissive license | |
|
||||
| `README.txt` | Text file with the file names and description | |
|
||||
|
||||
如果使用其它数据集,请参考[准备训练数据](https://github.com/PaddlePaddle/PaddleDetection/blob/release/2.1/docs/tutorials/PrepareDataSet.md)
|
||||
|
||||
<a name="配置文件改动和说明"></a>
|
||||
|
||||
## 3. 配置文件改动和说明
|
||||
|
||||
我们使用 `configs/ppyolo/ppyolov2_r50vd_dcn_365e_coco.yml`配置进行训练,配置文件摘要如下:
|
||||
|
||||
<div align='center'>
|
||||
<img src='../../doc/table/PaddleDetection_config.png' width='600px'/>
|
||||
</div>
|
||||
|
||||
从上图看到 `ppyolov2_r50vd_dcn_365e_coco.yml` 配置需要依赖其他的配置文件,在该例子中需要依赖:
|
||||
|
||||
```
|
||||
coco_detection.yml:主要说明了训练数据和验证数据的路径
|
||||
|
||||
runtime.yml:主要说明了公共的运行参数,比如是否使用GPU、每多少个epoch存储checkpoint等
|
||||
|
||||
optimizer_365e.yml:主要说明了学习率和优化器的配置
|
||||
|
||||
ppyolov2_r50vd_dcn.yml:主要说明模型和主干网络的情况
|
||||
|
||||
ppyolov2_reader.yml:主要说明数据读取器配置,如batch size,并发加载子进程数等,同时包含读取后预处理操作,如resize、数据增强等等
|
||||
```
|
||||
|
||||
根据实际情况,修改上述文件,比如数据集路径、batch size等。
|
||||
|
||||
<a name="训练"></a>
|
||||
|
||||
## 4. PaddleDetection训练
|
||||
|
||||
PaddleDetection提供了单卡/多卡训练模式,满足用户多种训练需求
|
||||
|
||||
* GPU 单卡训练
|
||||
|
||||
```bash
|
||||
export CUDA_VISIBLE_DEVICES=0 #windows和Mac下不需要执行该命令
|
||||
python tools/train.py -c configs/ppyolo/ppyolov2_r50vd_dcn_365e_coco.yml
|
||||
```
|
||||
|
||||
* GPU多卡训练
|
||||
|
||||
```bash
|
||||
export CUDA_VISIBLE_DEVICES=0,1,2,3
|
||||
python -m paddle.distributed.launch --gpus 0,1,2,3 tools/train.py -c configs/ppyolo/ppyolov2_r50vd_dcn_365e_coco.yml --eval
|
||||
```
|
||||
|
||||
--eval:表示边训练边验证
|
||||
|
||||
* 模型恢复训练
|
||||
|
||||
在日常训练过程中,有的用户由于一些原因导致训练中断,用户可以使用-r的命令恢复训练:
|
||||
|
||||
```bash
|
||||
export CUDA_VISIBLE_DEVICES=0,1,2,3
|
||||
python -m paddle.distributed.launch --gpus 0,1,2,3 tools/train.py -c configs/ppyolo/ppyolov2_r50vd_dcn_365e_coco.yml --eval -r output/ppyolov2_r50vd_dcn_365e_coco/10000
|
||||
```
|
||||
|
||||
注意:如果遇到 "`Out of memory error`" 问题, 尝试在 `ppyolov2_reader.yml` 文件中调小`batch_size`
|
||||
|
||||
<a name="预测"></a>
|
||||
|
||||
## 5. PaddleDetection预测
|
||||
|
||||
设置参数,使用PaddleDetection预测:
|
||||
|
||||
```bash
|
||||
export CUDA_VISIBLE_DEVICES=0
|
||||
python tools/infer.py -c configs/ppyolo/ppyolov2_r50vd_dcn_365e_coco.yml --infer_img=images/paper-image.jpg --output_dir=infer_output/ --draw_threshold=0.5 -o weights=output/ppyolov2_r50vd_dcn_365e_coco/model_final --use_vdl=Ture
|
||||
```
|
||||
|
||||
`--draw_threshold` 是个可选参数. 根据 [NMS](https://ieeexplore.ieee.org/document/1699659) 的计算,不同阈值会产生不同的结果 `keep_top_k`表示设置输出目标的最大数量,默认值为100,用户可以根据自己的实际情况进行设定。
|
||||
|
||||
<a name="预测部署"></a>
|
||||
|
||||
## 6. 预测部署
|
||||
|
||||
在layout parser中使用自己训练好的模型,
|
||||
|
||||
<a name="模型导出"></a>
|
||||
|
||||
### 6.1 模型导出
|
||||
|
||||
在模型训练过程中保存的模型文件是包含前向预测和反向传播的过程,在实际的工业部署则不需要反向传播,因此需要将模型进行导成部署需要的模型格式。 在PaddleDetection中提供了 `tools/export_model.py`脚本来导出模型。
|
||||
|
||||
导出模型名称默认是`model.*`,layout parser代码模型名称是`inference.*`, 所以修改[PaddleDetection/ppdet/engine/trainer.py ](https://github.com/PaddlePaddle/PaddleDetection/blob/b87a1ea86fa18ce69e44a17ad1b49c1326f19ff9/ppdet/engine/trainer.py#L512) (点开链接查看详细代码行),将`model`改为`inference`即可。
|
||||
|
||||
执行导出模型脚本:
|
||||
|
||||
```bash
|
||||
python tools/export_model.py -c configs/ppyolo/ppyolov2_r50vd_dcn_365e_coco.yml --output_dir=./inference -o weights=output/ppyolov2_r50vd_dcn_365e_coco/model_final.pdparams
|
||||
```
|
||||
|
||||
预测模型会导出到`inference/ppyolov2_r50vd_dcn_365e_coco`目录下,分别为`infer_cfg.yml`(预测不需要), `inference.pdiparams`, `inference.pdiparams.info`,`inference.pdmodel` 。
|
||||
|
||||
更多模型导出教程,请参考:[EXPORT_MODEL](https://github.com/PaddlePaddle/PaddleDetection/blob/release/2.1/deploy/EXPORT_MODEL.md)
|
||||
|
||||
<a name="layout parser预测"></a>
|
||||
|
||||
### 6.2 layout_parser预测
|
||||
|
||||
`model_path`指定训练好的模型路径,使用layout parser进行预测:
|
||||
|
||||
```bash
|
||||
import layoutparser as lp
|
||||
model = lp.PaddleDetectionLayoutModel(model_path="inference/ppyolov2_r50vd_dcn_365e_coco", threshold=0.5,label_map={0: "Text", 1: "Title", 2: "List", 3:"Table", 4:"Figure"},enforce_cpu=True,enable_mkldnn=True)
|
||||
```
|
||||
|
||||
|
||||
|
||||
***
|
||||
|
||||
更多PaddleDetection训练教程,请参考:[PaddleDetection训练](https://github.com/PaddlePaddle/PaddleDetection/blob/release/2.1/docs/tutorials/GETTING_STARTED_cn.md)
|
||||
|
||||
***
|
||||
|
|
@ -0,0 +1,148 @@
|
|||
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
|
||||
__dir__ = os.path.dirname(__file__)
|
||||
sys.path.append(__dir__)
|
||||
sys.path.append(os.path.join(__dir__, '..'))
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
from pathlib import Path
|
||||
|
||||
from ppocr.utils.logging import get_logger
|
||||
from test1.predict_system import OCRSystem, save_res
|
||||
from test1.table.predict_table import to_excel
|
||||
from test1.utility import init_args, draw_result
|
||||
|
||||
logger = get_logger()
|
||||
from ppocr.utils.utility import check_and_read_gif, get_image_file_list
|
||||
from ppocr.utils.network import maybe_download, download_with_progressbar, confirm_model_dir_url, is_link
|
||||
|
||||
__all__ = ['PaddleStructure', 'draw_result', 'save_res']
|
||||
|
||||
VERSION = '2.1'
|
||||
BASE_DIR = os.path.expanduser("~/.paddlestructure/")
|
||||
|
||||
model_urls = {
|
||||
'det': 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_det_infer.tar',
|
||||
'rec': 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_rec_infer.tar',
|
||||
'table': 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_structure_infer.tar'
|
||||
|
||||
}
|
||||
|
||||
|
||||
def parse_args(mMain=True):
|
||||
import argparse
|
||||
parser = init_args()
|
||||
parser.add_help = mMain
|
||||
|
||||
for action in parser._actions:
|
||||
if action.dest in ['rec_char_dict_path', 'table_char_dict_path']:
|
||||
action.default = None
|
||||
if mMain:
|
||||
return parser.parse_args()
|
||||
else:
|
||||
inference_args_dict = {}
|
||||
for action in parser._actions:
|
||||
inference_args_dict[action.dest] = action.default
|
||||
return argparse.Namespace(**inference_args_dict)
|
||||
|
||||
|
||||
class PaddleStructure(OCRSystem):
|
||||
def __init__(self, **kwargs):
|
||||
params = parse_args(mMain=False)
|
||||
params.__dict__.update(**kwargs)
|
||||
if not params.show_log:
|
||||
logger.setLevel(logging.INFO)
|
||||
params.use_angle_cls = False
|
||||
# init model dir
|
||||
params.det_model_dir, det_url = confirm_model_dir_url(params.det_model_dir,
|
||||
os.path.join(BASE_DIR, VERSION, 'det'),
|
||||
model_urls['det'])
|
||||
params.rec_model_dir, rec_url = confirm_model_dir_url(params.rec_model_dir,
|
||||
os.path.join(BASE_DIR, VERSION, 'rec'),
|
||||
model_urls['rec'])
|
||||
params.table_model_dir, table_url = confirm_model_dir_url(params.table_model_dir,
|
||||
os.path.join(BASE_DIR, VERSION, 'table'),
|
||||
model_urls['table'])
|
||||
# download model
|
||||
maybe_download(params.det_model_dir, det_url)
|
||||
maybe_download(params.rec_model_dir, rec_url)
|
||||
maybe_download(params.table_model_dir, table_url)
|
||||
|
||||
if params.rec_char_dict_path is None:
|
||||
params.rec_char_type = 'EN'
|
||||
if os.path.exists(str(Path(__file__).parent / 'ppocr/utils/dict/table_dict.txt')):
|
||||
params.rec_char_dict_path = str(Path(__file__).parent / 'ppocr/utils/dict/table_dict.txt')
|
||||
else:
|
||||
params.rec_char_dict_path = str(Path(__file__).parent.parent / 'ppocr/utils/dict/table_dict.txt')
|
||||
if params.table_char_dict_path is None:
|
||||
if os.path.exists(str(Path(__file__).parent / 'ppocr/utils/dict/table_structure_dict.txt')):
|
||||
params.table_char_dict_path = str(
|
||||
Path(__file__).parent / 'ppocr/utils/dict/table_structure_dict.txt')
|
||||
else:
|
||||
params.table_char_dict_path = str(
|
||||
Path(__file__).parent.parent / 'ppocr/utils/dict/table_structure_dict.txt')
|
||||
|
||||
print(params)
|
||||
super().__init__(params)
|
||||
|
||||
def __call__(self, img):
|
||||
if isinstance(img, str):
|
||||
# download net image
|
||||
if img.startswith('http'):
|
||||
download_with_progressbar(img, 'tmp.jpg')
|
||||
img = 'tmp.jpg'
|
||||
image_file = img
|
||||
img, flag = check_and_read_gif(image_file)
|
||||
if not flag:
|
||||
with open(image_file, 'rb') as f:
|
||||
np_arr = np.frombuffer(f.read(), dtype=np.uint8)
|
||||
img = cv2.imdecode(np_arr, cv2.IMREAD_COLOR)
|
||||
if img is None:
|
||||
logger.error("error in loading image:{}".format(image_file))
|
||||
return None
|
||||
if isinstance(img, np.ndarray) and len(img.shape) == 2:
|
||||
img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
|
||||
|
||||
res = super().__call__(img)
|
||||
return res
|
||||
|
||||
|
||||
def main():
|
||||
# for cmd
|
||||
args = parse_args(mMain=True)
|
||||
image_dir = args.image_dir
|
||||
save_folder = args.output
|
||||
if image_dir.startswith('http'):
|
||||
download_with_progressbar(image_dir, 'tmp.jpg')
|
||||
image_file_list = ['tmp.jpg']
|
||||
else:
|
||||
image_file_list = get_image_file_list(args.image_dir)
|
||||
if len(image_file_list) == 0:
|
||||
logger.error('no images find in {}'.format(args.image_dir))
|
||||
return
|
||||
|
||||
structure_engine = PaddleStructure(**(args.__dict__))
|
||||
for img_path in image_file_list:
|
||||
img_name = os.path.basename(img_path).split('.')[0]
|
||||
logger.info('{}{}{}'.format('*' * 10, img_path, '*' * 10))
|
||||
result = structure_engine(img_path)
|
||||
for item in result:
|
||||
logger.info(item['res'])
|
||||
save_res(result, save_folder, img_name)
|
||||
logger.info('result save to {}'.format(os.path.join(save_folder, img_name)))
|
|
@ -0,0 +1,134 @@
|
|||
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
import sys
|
||||
import subprocess
|
||||
|
||||
__dir__ = os.path.dirname(os.path.abspath(__file__))
|
||||
sys.path.append(__dir__)
|
||||
sys.path.append(os.path.abspath(os.path.join(__dir__, '..')))
|
||||
|
||||
os.environ["FLAGS_allocator_strategy"] = 'auto_growth'
|
||||
import cv2
|
||||
import numpy as np
|
||||
import time
|
||||
import logging
|
||||
|
||||
import layoutparser as lp
|
||||
|
||||
from ppocr.utils.utility import get_image_file_list, check_and_read_gif
|
||||
from ppocr.utils.logging import get_logger
|
||||
from tools.infer.predict_system import TextSystem
|
||||
from test1.table.predict_table import TableSystem, to_excel
|
||||
from test1.utility import parse_args, draw_result
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
|
||||
class OCRSystem(object):
|
||||
def __init__(self, args):
|
||||
args.det_limit_type = 'resize_long'
|
||||
args.drop_score = 0
|
||||
if not args.show_log:
|
||||
logger.setLevel(logging.INFO)
|
||||
self.text_system = TextSystem(args)
|
||||
self.table_system = TableSystem(args, self.text_system.text_detector, self.text_system.text_recognizer)
|
||||
self.table_layout = lp.PaddleDetectionLayoutModel("lp://PubLayNet/ppyolov2_r50vd_dcn_365e_publaynet/config",
|
||||
threshold=0.5, enable_mkldnn=args.enable_mkldnn,
|
||||
enforce_cpu=not args.use_gpu, thread_num=args.cpu_threads)
|
||||
self.use_angle_cls = args.use_angle_cls
|
||||
self.drop_score = args.drop_score
|
||||
|
||||
def __call__(self, img):
|
||||
ori_im = img.copy()
|
||||
layout_res = self.table_layout.detect(img[..., ::-1])
|
||||
res_list = []
|
||||
for region in layout_res:
|
||||
x1, y1, x2, y2 = region.coordinates
|
||||
x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2)
|
||||
roi_img = ori_im[y1:y2, x1:x2, :]
|
||||
if region.type == 'Table':
|
||||
res = self.table_system(roi_img)
|
||||
else:
|
||||
filter_boxes, filter_rec_res = self.text_system(roi_img)
|
||||
filter_boxes = [x + [x1, y1] for x in filter_boxes]
|
||||
filter_boxes = [x.reshape(-1).tolist() for x in filter_boxes]
|
||||
|
||||
res = (filter_boxes, filter_rec_res)
|
||||
res_list.append({'type': region.type, 'bbox': [x1, y1, x2, y2], 'res': res})
|
||||
return res_list
|
||||
|
||||
|
||||
def save_res(res, save_folder, img_name):
|
||||
excel_save_folder = os.path.join(save_folder, img_name)
|
||||
os.makedirs(excel_save_folder, exist_ok=True)
|
||||
# save res
|
||||
for region in res:
|
||||
if region['type'] == 'Table':
|
||||
excel_path = os.path.join(excel_save_folder, '{}.xlsx'.format(region['bbox']))
|
||||
to_excel(region['res'], excel_path)
|
||||
elif region['type'] == 'Figure':
|
||||
pass
|
||||
else:
|
||||
with open(os.path.join(excel_save_folder, 'res.txt'), 'a', encoding='utf8') as f:
|
||||
for box, rec_res in zip(region['res'][0], region['res'][1]):
|
||||
f.write('{}\t{}\n'.format(np.array(box).reshape(-1).tolist(), rec_res))
|
||||
|
||||
|
||||
def main(args):
|
||||
image_file_list = get_image_file_list(args.image_dir)
|
||||
image_file_list = image_file_list
|
||||
image_file_list = image_file_list[args.process_id::args.total_process_num]
|
||||
save_folder = args.output
|
||||
os.makedirs(save_folder, exist_ok=True)
|
||||
|
||||
structure_sys = OCRSystem(args)
|
||||
img_num = len(image_file_list)
|
||||
for i, image_file in enumerate(image_file_list):
|
||||
logger.info("[{}/{}] {}".format(i, img_num, image_file))
|
||||
img, flag = check_and_read_gif(image_file)
|
||||
img_name = os.path.basename(image_file).split('.')[0]
|
||||
|
||||
if not flag:
|
||||
img = cv2.imread(image_file)
|
||||
if img is None:
|
||||
logger.error("error in loading image:{}".format(image_file))
|
||||
continue
|
||||
starttime = time.time()
|
||||
res = structure_sys(img)
|
||||
save_res(res, save_folder, img_name)
|
||||
draw_img = draw_result(img, res, args.vis_font_path)
|
||||
cv2.imwrite(os.path.join(save_folder, img_name, 'show.jpg'), draw_img)
|
||||
logger.info('result save to {}'.format(os.path.join(save_folder, img_name)))
|
||||
elapse = time.time() - starttime
|
||||
logger.info("Predict time : {:.3f}s".format(elapse))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parse_args()
|
||||
if args.use_mp:
|
||||
p_list = []
|
||||
total_process_num = args.total_process_num
|
||||
for process_id in range(total_process_num):
|
||||
cmd = [sys.executable, "-u"] + sys.argv + [
|
||||
"--process_id={}".format(process_id),
|
||||
"--use_mp={}".format(False)
|
||||
]
|
||||
p = subprocess.Popen(cmd, stdout=sys.stdout, stderr=sys.stdout)
|
||||
p_list.append(p)
|
||||
for p in p_list:
|
||||
p.wait()
|
||||
else:
|
||||
main(args)
|
|
@ -0,0 +1,70 @@
|
|||
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import os
|
||||
|
||||
from setuptools import setup
|
||||
from io import open
|
||||
import shutil
|
||||
|
||||
with open('../requirements.txt', encoding="utf-8-sig") as f:
|
||||
requirements = f.readlines()
|
||||
requirements.append('tqdm')
|
||||
|
||||
|
||||
def readme():
|
||||
with open('api_ch.md', encoding="utf-8-sig") as f:
|
||||
README = f.read()
|
||||
return README
|
||||
|
||||
|
||||
shutil.copytree('./table', './test1/table')
|
||||
shutil.copyfile('./predict_system.py', './test1/predict_system.py')
|
||||
shutil.copyfile('./utility.py', './test1/utility.py')
|
||||
shutil.copytree('../ppocr', './ppocr')
|
||||
shutil.copytree('../tools', './tools')
|
||||
shutil.copyfile('../LICENSE', './LICENSE')
|
||||
|
||||
setup(
|
||||
name='paddlestructure',
|
||||
packages=['paddlestructure'],
|
||||
package_dir={'paddlestructure': ''},
|
||||
include_package_data=True,
|
||||
entry_points={"console_scripts": ["paddlestructure= paddlestructure.paddlestructure:main"]},
|
||||
version='1.0',
|
||||
install_requires=requirements,
|
||||
license='Apache License 2.0',
|
||||
description='Awesome OCR toolkits based on PaddlePaddle (8.6M ultra-lightweight pre-trained model, support training and deployment among server, mobile, embeded and IoT devices',
|
||||
long_description=readme(),
|
||||
long_description_content_type='text/markdown',
|
||||
url='https://github.com/PaddlePaddle/PaddleOCR',
|
||||
download_url='https://github.com/PaddlePaddle/PaddleOCR.git',
|
||||
keywords=[
|
||||
'ocr textdetection textrecognition paddleocr crnn east star-net rosetta ocrlite db chineseocr chinesetextdetection chinesetextrecognition'
|
||||
],
|
||||
classifiers=[
|
||||
'Intended Audience :: Developers', 'Operating System :: OS Independent',
|
||||
'Natural Language :: Chinese (Simplified)',
|
||||
'Programming Language :: Python :: 3',
|
||||
'Programming Language :: Python :: 3.2',
|
||||
'Programming Language :: Python :: 3.3',
|
||||
'Programming Language :: Python :: 3.4',
|
||||
'Programming Language :: Python :: 3.5',
|
||||
'Programming Language :: Python :: 3.6',
|
||||
'Programming Language :: Python :: 3.7', 'Topic :: Utilities'
|
||||
], )
|
||||
|
||||
shutil.rmtree('ppocr')
|
||||
shutil.rmtree('tools')
|
||||
shutil.rmtree('test1')
|
||||
os.remove('LICENSE')
|
|
@ -0,0 +1,49 @@
|
|||
# Table structure and content prediction
|
||||
|
||||
## 1. pipeline
|
||||
The ocr of the table mainly contains three models
|
||||
1. Single line text detection-DB
|
||||
2. Single line text recognition-CRNN
|
||||
3. Table structure and cell coordinate prediction-RARE
|
||||
|
||||
The table ocr flow chart is as follows
|
||||
|
||||
![tableocr_pipeline](../../doc/table/tableocr_pipeline.png)
|
||||
|
||||
1. The coordinates of single-line text is detected by DB model, and then sends it to the recognition model to get the recognition result.
|
||||
2. The table structure and cell coordinates is predicted by RARE model.
|
||||
3. The recognition result of the cell is combined by the coordinates, recognition result of the single line and the coordinates of the cell.
|
||||
4. The cell recognition result and the table structure together construct the html string of the table.
|
||||
|
||||
## 2. How to use
|
||||
|
||||
|
||||
### 2.1 Train
|
||||
TBD
|
||||
|
||||
### 2.2 Eval
|
||||
First cd to the PaddleOCR/ppstructure directory
|
||||
|
||||
The table uses TEDS (Tree-Edit-Distance-based Similarity) as the evaluation metric of the model. Before the model evaluation, the three models in the pipeline need to be exported as inference models (we have provided them), and the gt for evaluation needs to be prepared. Examples of gt are as follows:
|
||||
```json
|
||||
{"PMC4289340_004_00.png": [["<html>", "<body>", "<table>", "<thead>", "<tr>", "<td>", "</td>", "<td>", "</td>", "<td>", "</td>", "</tr>", "</thead>", "<tbody>", "<tr>", "<td>", "</td>", "<td>", "</td>", "<td>", "</td>", "</tr>", "</tbody>", "</table>", "</body>", "</html>"], [[1, 4, 29, 13], [137, 4, 161, 13], [215, 4, 236, 13], [1, 17, 30, 27], [137, 17, 147, 27], [215, 17, 225, 27]], [["<b>", "F", "e", "a", "t", "u", "r", "e", "</b>"], ["<b>", "G", "b", "3", " ", "+", "</b>"], ["<b>", "G", "b", "3", " ", "-", "</b>"], ["<b>", "P", "a", "t", "i", "e", "n", "t", "s", "</b>"], ["6", "2"], ["4", "5"]]]}
|
||||
```
|
||||
In gt json, the key is the image name, the value is the corresponding gt, and gt is a list composed of four items, and each item is
|
||||
1. HTML string list of table structure
|
||||
2. The coordinates of each cell (not including the empty text in the cell)
|
||||
3. The text information in each cell (not including the empty text in the cell)
|
||||
4. The text information in each cell (including the empty text in the cell)
|
||||
|
||||
Use the following command to evaluate. After the evaluation is completed, the teds indicator will be output.
|
||||
```python
|
||||
python3 table/eval_table.py --det_model_dir=path/to/det_model_dir --rec_model_dir=path/to/rec_model_dir --table_model_dir=path/to/table_model_dir --image_dir=../doc/table/1.png --rec_char_dict_path=../ppocr/utils/dict/table_dict.txt --table_char_dict_path=../ppocr/utils/dict/table_structure_dict.txt --rec_char_type=EN --det_limit_side_len=736 --det_limit_type=min --gt_path=path/to/gt.json
|
||||
```
|
||||
|
||||
|
||||
### 2.3 Inference
|
||||
First cd to the PaddleOCR/ppstructure directory
|
||||
|
||||
```python
|
||||
python3 table/predict_table.py --det_model_dir=path/to/det_model_dir --rec_model_dir=path/to/rec_model_dir --table_model_dir=path/to/table_model_dir --image_dir=../doc/table/1.png --rec_char_dict_path=../ppocr/utils/dict/table_dict.txt --table_char_dict_path=../ppocr/utils/dict/table_structure_dict.txt --rec_char_type=EN --det_limit_side_len=736 --det_limit_type=min --output ../output/table
|
||||
```
|
||||
After running, the excel sheet of each picture will be saved in the directory specified by the output field
|
|
@ -0,0 +1,49 @@
|
|||
# 表格结构和内容预测
|
||||
|
||||
## 1. pipeline
|
||||
表格的ocr主要包含三个模型
|
||||
1. 单行文本检测-DB
|
||||
2. 单行文本识别-CRNN
|
||||
3. 表格结构和cell坐标预测-RARE
|
||||
|
||||
具体流程图如下
|
||||
|
||||
![tableocr_pipeline](../../doc/table/tableocr_pipeline.png)
|
||||
|
||||
1. 图片由单行文字检测检测模型到单行文字的坐标,然后送入识别模型拿到识别结果。
|
||||
2. 图片由表格结构和cell坐标预测模型拿到表格的结构信息和单元格的坐标信息。
|
||||
3. 由单行文字的坐标、识别结果和单元格的坐标一起组合出单元格的识别结果。
|
||||
4. 单元格的识别结果和表格结构一起构造表格的html字符串。
|
||||
|
||||
## 2. 使用
|
||||
|
||||
|
||||
### 2.1 训练
|
||||
TBD
|
||||
|
||||
### 2.2 评估
|
||||
先cd到PaddleOCR/ppstructure目录下
|
||||
|
||||
表格使用 TEDS(Tree-Edit-Distance-based Similarity) 作为模型的评估指标。在进行模型评估之前,需要将pipeline中的三个模型分别导出为inference模型(我们已经提供好),还需要准备评估的gt, gt示例如下:
|
||||
```json
|
||||
{"PMC4289340_004_00.png": [["<html>", "<body>", "<table>", "<thead>", "<tr>", "<td>", "</td>", "<td>", "</td>", "<td>", "</td>", "</tr>", "</thead>", "<tbody>", "<tr>", "<td>", "</td>", "<td>", "</td>", "<td>", "</td>", "</tr>", "</tbody>", "</table>", "</body>", "</html>"], [[1, 4, 29, 13], [137, 4, 161, 13], [215, 4, 236, 13], [1, 17, 30, 27], [137, 17, 147, 27], [215, 17, 225, 27]], [["<b>", "F", "e", "a", "t", "u", "r", "e", "</b>"], ["<b>", "G", "b", "3", " ", "+", "</b>"], ["<b>", "G", "b", "3", " ", "-", "</b>"], ["<b>", "P", "a", "t", "i", "e", "n", "t", "s", "</b>"], ["6", "2"], ["4", "5"]]]}
|
||||
```
|
||||
json 中,key为图片名,value为对于的gt,gt是一个由四个item组成的list,每个item分别为
|
||||
1. 表格结构的html字符串list
|
||||
2. 每个cell的坐标 (不包括cell里文字为空的)
|
||||
3. 每个cell里的文字信息 (不包括cell里文字为空的)
|
||||
4. 每个cell里的文字信息 (包括cell里文字为空的)
|
||||
|
||||
准备完成后使用如下命令进行评估,评估完成后会输出teds指标。
|
||||
```python
|
||||
python3 table/eval_table.py --det_model_dir=path/to/det_model_dir --rec_model_dir=path/to/rec_model_dir --table_model_dir=path/to/table_model_dir --image_dir=../doc/table/1.png --rec_char_dict_path=../ppocr/utils/dict/table_dict.txt --table_char_dict_path=../ppocr/utils/dict/table_structure_dict.txt --rec_char_type=EN --det_limit_side_len=736 --det_limit_type=min --gt_path=path/to/gt.json
|
||||
```
|
||||
|
||||
|
||||
### 2.3 预测
|
||||
先cd到PaddleOCR/ppstructure目录下
|
||||
|
||||
```python
|
||||
python3 table/predict_table.py --det_model_dir=path/to/det_model_dir --rec_model_dir=path/to/rec_model_dir --table_model_dir=path/to/table_model_dir --image_dir=../doc/table/1.png --rec_char_dict_path=../ppocr/utils/dict/table_dict.txt --table_char_dict_path=../ppocr/utils/dict/table_structure_dict.txt --rec_char_type=EN --det_limit_side_len=736 --det_limit_type=min --output ../output/table
|
||||
```
|
||||
运行完成后,每张图片的excel表格会保存到output字段指定的目录下
|
|
@ -0,0 +1,13 @@
|
|||
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
|
@ -0,0 +1,72 @@
|
|||
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import os
|
||||
import sys
|
||||
__dir__ = os.path.dirname(os.path.abspath(__file__))
|
||||
sys.path.append(__dir__)
|
||||
sys.path.append(os.path.abspath(os.path.join(__dir__, '../..')))
|
||||
|
||||
import cv2
|
||||
import json
|
||||
from tqdm import tqdm
|
||||
from test1.table.table_metric import TEDS
|
||||
from test1.table.predict_table import TableSystem
|
||||
from test1.utility import init_args
|
||||
from ppocr.utils.logging import get_logger
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = init_args()
|
||||
parser.add_argument("--gt_path", type=str)
|
||||
return parser.parse_args()
|
||||
|
||||
def main(gt_path, img_root, args):
|
||||
teds = TEDS(n_jobs=16)
|
||||
|
||||
text_sys = TableSystem(args)
|
||||
jsons_gt = json.load(open(gt_path)) # gt
|
||||
pred_htmls = []
|
||||
gt_htmls = []
|
||||
for img_name in tqdm(jsons_gt):
|
||||
# read image
|
||||
img = cv2.imread(os.path.join(img_root,img_name))
|
||||
pred_html = text_sys(img)
|
||||
pred_htmls.append(pred_html)
|
||||
|
||||
gt_structures, gt_bboxes, gt_contents, contents_with_block = jsons_gt[img_name]
|
||||
gt_html, gt = get_gt_html(gt_structures, contents_with_block)
|
||||
gt_htmls.append(gt_html)
|
||||
scores = teds.batch_evaluate_html(gt_htmls, pred_htmls)
|
||||
logger.info('teds:', sum(scores) / len(scores))
|
||||
|
||||
|
||||
def get_gt_html(gt_structures, contents_with_block):
|
||||
end_html = []
|
||||
td_index = 0
|
||||
for tag in gt_structures:
|
||||
if '</td>' in tag:
|
||||
if contents_with_block[td_index] != []:
|
||||
end_html.extend(contents_with_block[td_index])
|
||||
end_html.append(tag)
|
||||
td_index += 1
|
||||
else:
|
||||
end_html.append(tag)
|
||||
return ''.join(end_html), end_html
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
args = parse_args()
|
||||
main(args.gt_path,args.image_dir, args)
|
|
@ -0,0 +1,192 @@
|
|||
import json
|
||||
def distance(box_1, box_2):
|
||||
x1, y1, x2, y2 = box_1
|
||||
x3, y3, x4, y4 = box_2
|
||||
dis = abs(x3 - x1) + abs(y3 - y1) + abs(x4- x2) + abs(y4 - y2)
|
||||
dis_2 = abs(x3 - x1) + abs(y3 - y1)
|
||||
dis_3 = abs(x4- x2) + abs(y4 - y2)
|
||||
return dis + min(dis_2, dis_3)
|
||||
|
||||
def compute_iou(rec1, rec2):
|
||||
"""
|
||||
computing IoU
|
||||
:param rec1: (y0, x0, y1, x1), which reflects
|
||||
(top, left, bottom, right)
|
||||
:param rec2: (y0, x0, y1, x1)
|
||||
:return: scala value of IoU
|
||||
"""
|
||||
# computing area of each rectangles
|
||||
S_rec1 = (rec1[2] - rec1[0]) * (rec1[3] - rec1[1])
|
||||
S_rec2 = (rec2[2] - rec2[0]) * (rec2[3] - rec2[1])
|
||||
|
||||
# computing the sum_area
|
||||
sum_area = S_rec1 + S_rec2
|
||||
|
||||
# find the each edge of intersect rectangle
|
||||
left_line = max(rec1[1], rec2[1])
|
||||
right_line = min(rec1[3], rec2[3])
|
||||
top_line = max(rec1[0], rec2[0])
|
||||
bottom_line = min(rec1[2], rec2[2])
|
||||
|
||||
# judge if there is an intersect
|
||||
if left_line >= right_line or top_line >= bottom_line:
|
||||
return 0.0
|
||||
else:
|
||||
intersect = (right_line - left_line) * (bottom_line - top_line)
|
||||
return (intersect / (sum_area - intersect))*1.0
|
||||
|
||||
|
||||
|
||||
def matcher_merge(ocr_bboxes, pred_bboxes):
|
||||
all_dis = []
|
||||
ious = []
|
||||
matched = {}
|
||||
for i, gt_box in enumerate(ocr_bboxes):
|
||||
distances = []
|
||||
for j, pred_box in enumerate(pred_bboxes):
|
||||
# compute l1 distence and IOU between two boxes
|
||||
distances.append((distance(gt_box, pred_box), 1. - compute_iou(gt_box, pred_box)))
|
||||
sorted_distances = distances.copy()
|
||||
# select nearest cell
|
||||
sorted_distances = sorted(sorted_distances, key = lambda item: (item[1], item[0]))
|
||||
if distances.index(sorted_distances[0]) not in matched.keys():
|
||||
matched[distances.index(sorted_distances[0])] = [i]
|
||||
else:
|
||||
matched[distances.index(sorted_distances[0])].append(i)
|
||||
return matched#, sum(ious) / len(ious)
|
||||
|
||||
def complex_num(pred_bboxes):
|
||||
complex_nums = []
|
||||
for bbox in pred_bboxes:
|
||||
distances = []
|
||||
temp_ious = []
|
||||
for pred_bbox in pred_bboxes:
|
||||
if bbox != pred_bbox:
|
||||
distances.append(distance(bbox, pred_bbox))
|
||||
temp_ious.append(compute_iou(bbox, pred_bbox))
|
||||
complex_nums.append(temp_ious[distances.index(min(distances))])
|
||||
return sum(complex_nums) / len(complex_nums)
|
||||
|
||||
def get_rows(pred_bboxes):
|
||||
pre_bbox = pred_bboxes[0]
|
||||
res = []
|
||||
step = 0
|
||||
for i in range(len(pred_bboxes)):
|
||||
bbox = pred_bboxes[i]
|
||||
if bbox[1] - pre_bbox[1] > 2 or bbox[0] - pre_bbox[0] < 0:
|
||||
break
|
||||
else:
|
||||
res.append(bbox)
|
||||
step += 1
|
||||
for i in range(step):
|
||||
pred_bboxes.pop(0)
|
||||
return res, pred_bboxes
|
||||
def refine_rows(pred_bboxes): # 微调整行的框,使在一条水平线上
|
||||
ys_1 = []
|
||||
ys_2 = []
|
||||
for box in pred_bboxes:
|
||||
ys_1.append(box[1])
|
||||
ys_2.append(box[3])
|
||||
min_y_1 = sum(ys_1) / len(ys_1)
|
||||
min_y_2 = sum(ys_2) / len(ys_2)
|
||||
re_boxes = []
|
||||
for box in pred_bboxes:
|
||||
box[1] = min_y_1
|
||||
box[3] = min_y_2
|
||||
re_boxes.append(box)
|
||||
return re_boxes
|
||||
|
||||
def matcher_refine_row(gt_bboxes, pred_bboxes):
|
||||
before_refine_pred_bboxes = pred_bboxes.copy()
|
||||
pred_bboxes = []
|
||||
while(len(before_refine_pred_bboxes) != 0):
|
||||
row_bboxes, before_refine_pred_bboxes = get_rows(before_refine_pred_bboxes)
|
||||
print(row_bboxes)
|
||||
pred_bboxes.extend(refine_rows(row_bboxes))
|
||||
all_dis = []
|
||||
ious = []
|
||||
matched = {}
|
||||
for i, gt_box in enumerate(gt_bboxes):
|
||||
distances = []
|
||||
#temp_ious = []
|
||||
for j, pred_box in enumerate(pred_bboxes):
|
||||
distances.append(distance(gt_box, pred_box))
|
||||
#temp_ious.append(compute_iou(gt_box, pred_box))
|
||||
#all_dis.append(min(distances))
|
||||
#ious.append(temp_ious[distances.index(min(distances))])
|
||||
if distances.index(min(distances)) not in matched.keys():
|
||||
matched[distances.index(min(distances))] = [i]
|
||||
else:
|
||||
matched[distances.index(min(distances))].append(i)
|
||||
return matched#, sum(ious) / len(ious)
|
||||
|
||||
|
||||
|
||||
#先挑选出一行,再进行匹配
|
||||
def matcher_structure_1(gt_bboxes, pred_bboxes_rows, pred_bboxes):
|
||||
gt_box_index = 0
|
||||
delete_gt_bboxes = gt_bboxes.copy()
|
||||
match_bboxes_ready = []
|
||||
matched = {}
|
||||
while(len(delete_gt_bboxes) != 0):
|
||||
row_bboxes, delete_gt_bboxes = get_rows(delete_gt_bboxes)
|
||||
row_bboxes = sorted(row_bboxes, key = lambda key: key[0])
|
||||
if len(pred_bboxes_rows) > 0:
|
||||
match_bboxes_ready.extend(pred_bboxes_rows.pop(0))
|
||||
print(row_bboxes)
|
||||
for i, gt_box in enumerate(row_bboxes):
|
||||
#print(gt_box)
|
||||
pred_distances = []
|
||||
distances = []
|
||||
for pred_bbox in pred_bboxes:
|
||||
pred_distances.append(distance(gt_box, pred_bbox))
|
||||
for j, pred_box in enumerate(match_bboxes_ready):
|
||||
distances.append(distance(gt_box, pred_box))
|
||||
index = pred_distances.index(min(distances))
|
||||
#print('index', index)
|
||||
if index not in matched.keys():
|
||||
matched[index] = [gt_box_index]
|
||||
else:
|
||||
matched[index].append(gt_box_index)
|
||||
gt_box_index += 1
|
||||
return matched
|
||||
|
||||
def matcher_structure(gt_bboxes, pred_bboxes_rows, pred_bboxes):
|
||||
'''
|
||||
gt_bboxes: 排序后
|
||||
pred_bboxes:
|
||||
'''
|
||||
pre_bbox = gt_bboxes[0]
|
||||
matched = {}
|
||||
match_bboxes_ready = []
|
||||
match_bboxes_ready.extend(pred_bboxes_rows.pop(0))
|
||||
for i, gt_box in enumerate(gt_bboxes):
|
||||
|
||||
pred_distances = []
|
||||
for pred_bbox in pred_bboxes:
|
||||
pred_distances.append(distance(gt_box, pred_bbox))
|
||||
distances = []
|
||||
gap_pre = gt_box[1] - pre_bbox[1]
|
||||
gap_pre_1 = gt_box[0] - pre_bbox[2]
|
||||
#print(gap_pre, len(pred_bboxes_rows))
|
||||
if (gap_pre_1 < 0 and len(pred_bboxes_rows) > 0):
|
||||
match_bboxes_ready.extend(pred_bboxes_rows.pop(0))
|
||||
if len(pred_bboxes_rows) == 1:
|
||||
match_bboxes_ready.extend(pred_bboxes_rows.pop(0))
|
||||
if len(match_bboxes_ready) == 0 and len(pred_bboxes_rows) > 0:
|
||||
match_bboxes_ready.extend(pred_bboxes_rows.pop(0))
|
||||
if len(match_bboxes_ready) == 0 and len(pred_bboxes_rows) == 0:
|
||||
break
|
||||
#print(match_bboxes_ready)
|
||||
for j, pred_box in enumerate(match_bboxes_ready):
|
||||
distances.append(distance(gt_box, pred_box))
|
||||
index = pred_distances.index(min(distances))
|
||||
#print(gt_box, index)
|
||||
#match_bboxes_ready.pop(distances.index(min(distances)))
|
||||
print(gt_box, match_bboxes_ready[distances.index(min(distances))])
|
||||
if index not in matched.keys():
|
||||
matched[index] = [i]
|
||||
else:
|
||||
matched[index].append(i)
|
||||
pre_bbox = gt_box
|
||||
return matched
|
|
@ -0,0 +1,139 @@
|
|||
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import os
|
||||
import sys
|
||||
|
||||
__dir__ = os.path.dirname(os.path.abspath(__file__))
|
||||
sys.path.append(__dir__)
|
||||
sys.path.append(os.path.abspath(os.path.join(__dir__, '../..')))
|
||||
|
||||
os.environ["FLAGS_allocator_strategy"] = 'auto_growth'
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import math
|
||||
import time
|
||||
import traceback
|
||||
import paddle
|
||||
|
||||
import tools.infer.utility as utility
|
||||
from ppocr.data import create_operators, transform
|
||||
from ppocr.postprocess import build_post_process
|
||||
from ppocr.utils.logging import get_logger
|
||||
from ppocr.utils.utility import get_image_file_list, check_and_read_gif
|
||||
from test1.utility import parse_args
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
|
||||
class TableStructurer(object):
|
||||
def __init__(self, args):
|
||||
pre_process_list = [{
|
||||
'ResizeTableImage': {
|
||||
'max_len': args.table_max_len
|
||||
}
|
||||
}, {
|
||||
'NormalizeImage': {
|
||||
'std': [0.229, 0.224, 0.225],
|
||||
'mean': [0.485, 0.456, 0.406],
|
||||
'scale': '1./255.',
|
||||
'order': 'hwc'
|
||||
}
|
||||
}, {
|
||||
'PaddingTableImage': None
|
||||
}, {
|
||||
'ToCHWImage': None
|
||||
}, {
|
||||
'KeepKeys': {
|
||||
'keep_keys': ['image']
|
||||
}
|
||||
}]
|
||||
postprocess_params = {
|
||||
'name': 'TableLabelDecode',
|
||||
"character_type": args.table_char_type,
|
||||
"character_dict_path": args.table_char_dict_path,
|
||||
}
|
||||
|
||||
self.preprocess_op = create_operators(pre_process_list)
|
||||
self.postprocess_op = build_post_process(postprocess_params)
|
||||
self.predictor, self.input_tensor, self.output_tensors, self.config = \
|
||||
utility.create_predictor(args, 'table', logger)
|
||||
|
||||
def __call__(self, img):
|
||||
ori_im = img.copy()
|
||||
data = {'image': img}
|
||||
data = transform(data, self.preprocess_op)
|
||||
img = data[0]
|
||||
if img is None:
|
||||
return None, 0
|
||||
img = np.expand_dims(img, axis=0)
|
||||
img = img.copy()
|
||||
starttime = time.time()
|
||||
|
||||
self.input_tensor.copy_from_cpu(img)
|
||||
self.predictor.run()
|
||||
outputs = []
|
||||
for output_tensor in self.output_tensors:
|
||||
output = output_tensor.copy_to_cpu()
|
||||
outputs.append(output)
|
||||
|
||||
preds = {}
|
||||
preds['structure_probs'] = outputs[1]
|
||||
preds['loc_preds'] = outputs[0]
|
||||
|
||||
post_result = self.postprocess_op(preds)
|
||||
|
||||
structure_str_list = post_result['structure_str_list']
|
||||
res_loc = post_result['res_loc']
|
||||
imgh, imgw = ori_im.shape[0:2]
|
||||
res_loc_final = []
|
||||
for rno in range(len(res_loc[0])):
|
||||
x0, y0, x1, y1 = res_loc[0][rno]
|
||||
left = max(int(imgw * x0), 0)
|
||||
top = max(int(imgh * y0), 0)
|
||||
right = min(int(imgw * x1), imgw - 1)
|
||||
bottom = min(int(imgh * y1), imgh - 1)
|
||||
res_loc_final.append([left, top, right, bottom])
|
||||
|
||||
structure_str_list = structure_str_list[0][:-1]
|
||||
structure_str_list = ['<html>', '<body>', '<table>'] + structure_str_list + ['</table>', '</body>', '</html>']
|
||||
|
||||
elapse = time.time() - starttime
|
||||
return (structure_str_list, res_loc_final), elapse
|
||||
|
||||
|
||||
def main(args):
|
||||
image_file_list = get_image_file_list(args.image_dir)
|
||||
table_structurer = TableStructurer(args)
|
||||
count = 0
|
||||
total_time = 0
|
||||
for image_file in image_file_list:
|
||||
img, flag = check_and_read_gif(image_file)
|
||||
if not flag:
|
||||
img = cv2.imread(image_file)
|
||||
if img is None:
|
||||
logger.info("error in loading image:{}".format(image_file))
|
||||
continue
|
||||
structure_res, elapse = table_structurer(img)
|
||||
|
||||
logger.info("result: {}".format(structure_res))
|
||||
|
||||
if count > 0:
|
||||
total_time += elapse
|
||||
count += 1
|
||||
logger.info("Predict time of {}: {}".format(image_file, elapse))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main(parse_args())
|
|
@ -0,0 +1,221 @@
|
|||
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
import sys
|
||||
import subprocess
|
||||
|
||||
__dir__ = os.path.dirname(os.path.abspath(__file__))
|
||||
sys.path.append(__dir__)
|
||||
sys.path.append(os.path.abspath(os.path.join(__dir__, '..')))
|
||||
sys.path.append(os.path.abspath(os.path.join(__dir__, '../..')))
|
||||
|
||||
os.environ["FLAGS_allocator_strategy"] = 'auto_growth'
|
||||
import cv2
|
||||
import copy
|
||||
import numpy as np
|
||||
import time
|
||||
import tools.infer.predict_rec as predict_rec
|
||||
import tools.infer.predict_det as predict_det
|
||||
from ppocr.utils.utility import get_image_file_list, check_and_read_gif
|
||||
from ppocr.utils.logging import get_logger
|
||||
from test1.table.matcher import distance, compute_iou
|
||||
from test1.utility import parse_args
|
||||
import test1.table.predict_structure as predict_strture
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
|
||||
def expand(pix, det_box, shape):
|
||||
x0, y0, x1, y1 = det_box
|
||||
# print(shape)
|
||||
h, w, c = shape
|
||||
tmp_x0 = x0 - pix
|
||||
tmp_x1 = x1 + pix
|
||||
tmp_y0 = y0 - pix
|
||||
tmp_y1 = y1 + pix
|
||||
x0_ = tmp_x0 if tmp_x0 >= 0 else 0
|
||||
x1_ = tmp_x1 if tmp_x1 <= w else w
|
||||
y0_ = tmp_y0 if tmp_y0 >= 0 else 0
|
||||
y1_ = tmp_y1 if tmp_y1 <= h else h
|
||||
return x0_, y0_, x1_, y1_
|
||||
|
||||
|
||||
class TableSystem(object):
|
||||
def __init__(self, args, text_detector=None, text_recognizer=None):
|
||||
self.text_detector = predict_det.TextDetector(args) if text_detector is None else text_detector
|
||||
self.text_recognizer = predict_rec.TextRecognizer(args) if text_recognizer is None else text_recognizer
|
||||
self.table_structurer = predict_strture.TableStructurer(args)
|
||||
|
||||
def __call__(self, img):
|
||||
ori_im = img.copy()
|
||||
structure_res, elapse = self.table_structurer(copy.deepcopy(img))
|
||||
dt_boxes, elapse = self.text_detector(copy.deepcopy(img))
|
||||
dt_boxes = sorted_boxes(dt_boxes)
|
||||
|
||||
r_boxes = []
|
||||
for box in dt_boxes:
|
||||
x_min = box[:, 0].min() - 1
|
||||
x_max = box[:, 0].max() + 1
|
||||
y_min = box[:, 1].min() - 1
|
||||
y_max = box[:, 1].max() + 1
|
||||
box = [x_min, y_min, x_max, y_max]
|
||||
r_boxes.append(box)
|
||||
dt_boxes = np.array(r_boxes)
|
||||
|
||||
logger.debug("dt_boxes num : {}, elapse : {}".format(
|
||||
len(dt_boxes), elapse))
|
||||
if dt_boxes is None:
|
||||
return None, None
|
||||
img_crop_list = []
|
||||
|
||||
for i in range(len(dt_boxes)):
|
||||
det_box = dt_boxes[i]
|
||||
x0, y0, x1, y1 = expand(2, det_box, ori_im.shape)
|
||||
text_rect = ori_im[int(y0):int(y1), int(x0):int(x1), :]
|
||||
img_crop_list.append(text_rect)
|
||||
rec_res, elapse = self.text_recognizer(img_crop_list)
|
||||
logger.debug("rec_res num : {}, elapse : {}".format(
|
||||
len(rec_res), elapse))
|
||||
|
||||
pred_html, pred = self.rebuild_table(structure_res, dt_boxes, rec_res)
|
||||
return pred_html
|
||||
|
||||
def rebuild_table(self, structure_res, dt_boxes, rec_res):
|
||||
pred_structures, pred_bboxes = structure_res
|
||||
matched_index = self.match_result(dt_boxes, pred_bboxes)
|
||||
pred_html, pred = self.get_pred_html(pred_structures, matched_index, rec_res)
|
||||
return pred_html, pred
|
||||
|
||||
def match_result(self, dt_boxes, pred_bboxes):
|
||||
matched = {}
|
||||
for i, gt_box in enumerate(dt_boxes):
|
||||
# gt_box = [np.min(gt_box[:, 0]), np.min(gt_box[:, 1]), np.max(gt_box[:, 0]), np.max(gt_box[:, 1])]
|
||||
distances = []
|
||||
for j, pred_box in enumerate(pred_bboxes):
|
||||
distances.append(
|
||||
(distance(gt_box, pred_box), 1. - compute_iou(gt_box, pred_box))) # 获取两两cell之间的L1距离和 1- IOU
|
||||
sorted_distances = distances.copy()
|
||||
# 根据距离和IOU挑选最"近"的cell
|
||||
sorted_distances = sorted(sorted_distances, key=lambda item: (item[1], item[0]))
|
||||
if distances.index(sorted_distances[0]) not in matched.keys():
|
||||
matched[distances.index(sorted_distances[0])] = [i]
|
||||
else:
|
||||
matched[distances.index(sorted_distances[0])].append(i)
|
||||
return matched
|
||||
|
||||
def get_pred_html(self, pred_structures, matched_index, ocr_contents):
|
||||
end_html = []
|
||||
td_index = 0
|
||||
for tag in pred_structures:
|
||||
if '</td>' in tag:
|
||||
if td_index in matched_index.keys():
|
||||
b_with = False
|
||||
if '<b>' in ocr_contents[matched_index[td_index][0]] and len(matched_index[td_index]) > 1:
|
||||
b_with = True
|
||||
end_html.extend('<b>')
|
||||
for i, td_index_index in enumerate(matched_index[td_index]):
|
||||
content = ocr_contents[td_index_index][0]
|
||||
if len(matched_index[td_index]) > 1:
|
||||
if len(content) == 0:
|
||||
continue
|
||||
if content[0] == ' ':
|
||||
content = content[1:]
|
||||
if '<b>' in content:
|
||||
content = content[3:]
|
||||
if '</b>' in content:
|
||||
content = content[:-4]
|
||||
if len(content) == 0:
|
||||
continue
|
||||
if i != len(matched_index[td_index]) - 1 and ' ' != content[-1]:
|
||||
content += ' '
|
||||
end_html.extend(content)
|
||||
if b_with:
|
||||
end_html.extend('</b>')
|
||||
|
||||
end_html.append(tag)
|
||||
td_index += 1
|
||||
else:
|
||||
end_html.append(tag)
|
||||
return ''.join(end_html), end_html
|
||||
|
||||
|
||||
def sorted_boxes(dt_boxes):
|
||||
"""
|
||||
Sort text boxes in order from top to bottom, left to right
|
||||
args:
|
||||
dt_boxes(array):detected text boxes with shape [4, 2]
|
||||
return:
|
||||
sorted boxes(array) with shape [4, 2]
|
||||
"""
|
||||
num_boxes = dt_boxes.shape[0]
|
||||
sorted_boxes = sorted(dt_boxes, key=lambda x: (x[0][1], x[0][0]))
|
||||
_boxes = list(sorted_boxes)
|
||||
|
||||
for i in range(num_boxes - 1):
|
||||
if abs(_boxes[i + 1][0][1] - _boxes[i][0][1]) < 10 and \
|
||||
(_boxes[i + 1][0][0] < _boxes[i][0][0]):
|
||||
tmp = _boxes[i]
|
||||
_boxes[i] = _boxes[i + 1]
|
||||
_boxes[i + 1] = tmp
|
||||
return _boxes
|
||||
|
||||
|
||||
def to_excel(html_table, excel_path):
|
||||
from tablepyxl import tablepyxl
|
||||
tablepyxl.document_to_xl(html_table, excel_path)
|
||||
|
||||
|
||||
def main(args):
|
||||
image_file_list = get_image_file_list(args.image_dir)
|
||||
image_file_list = image_file_list[args.process_id::args.total_process_num]
|
||||
os.makedirs(args.output, exist_ok=True)
|
||||
|
||||
text_sys = TableSystem(args)
|
||||
img_num = len(image_file_list)
|
||||
for i, image_file in enumerate(image_file_list):
|
||||
logger.info("[{}/{}] {}".format(i, img_num, image_file))
|
||||
img, flag = check_and_read_gif(image_file)
|
||||
excel_path = os.path.join(args.output, os.path.basename(image_file).split('.')[0] + '.xlsx')
|
||||
if not flag:
|
||||
img = cv2.imread(image_file)
|
||||
if img is None:
|
||||
logger.error("error in loading image:{}".format(image_file))
|
||||
continue
|
||||
starttime = time.time()
|
||||
pred_html = text_sys(img)
|
||||
|
||||
to_excel(pred_html, excel_path)
|
||||
logger.info('excel saved to {}'.format(excel_path))
|
||||
logger.info(pred_html)
|
||||
elapse = time.time() - starttime
|
||||
logger.info("Predict time : {:.3f}s".format(elapse))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parse_args()
|
||||
if args.use_mp:
|
||||
p_list = []
|
||||
total_process_num = args.total_process_num
|
||||
for process_id in range(total_process_num):
|
||||
cmd = [sys.executable, "-u"] + sys.argv + [
|
||||
"--process_id={}".format(process_id),
|
||||
"--use_mp={}".format(False)
|
||||
]
|
||||
p = subprocess.Popen(cmd, stdout=sys.stdout, stderr=sys.stdout)
|
||||
p_list.append(p)
|
||||
for p in p_list:
|
||||
p.wait()
|
||||
else:
|
||||
main(args)
|
|
@ -0,0 +1,16 @@
|
|||
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
__all__ = ['TEDS']
|
||||
from .table_metric import TEDS
|
|
@ -0,0 +1,51 @@
|
|||
from tqdm import tqdm
|
||||
from concurrent.futures import ProcessPoolExecutor, as_completed
|
||||
|
||||
|
||||
def parallel_process(array, function, n_jobs=16, use_kwargs=False, front_num=0):
|
||||
"""
|
||||
A parallel version of the map function with a progress bar.
|
||||
Args:
|
||||
array (array-like): An array to iterate over.
|
||||
function (function): A python function to apply to the elements of array
|
||||
n_jobs (int, default=16): The number of cores to use
|
||||
use_kwargs (boolean, default=False): Whether to consider the elements of array as dictionaries of
|
||||
keyword arguments to function
|
||||
front_num (int, default=3): The number of iterations to run serially before kicking off the parallel job.
|
||||
Useful for catching bugs
|
||||
Returns:
|
||||
[function(array[0]), function(array[1]), ...]
|
||||
"""
|
||||
# We run the first few iterations serially to catch bugs
|
||||
if front_num > 0:
|
||||
front = [function(**a) if use_kwargs else function(a)
|
||||
for a in array[:front_num]]
|
||||
else:
|
||||
front = []
|
||||
# If we set n_jobs to 1, just run a list comprehension. This is useful for benchmarking and debugging.
|
||||
if n_jobs == 1:
|
||||
return front + [function(**a) if use_kwargs else function(a) for a in tqdm(array[front_num:])]
|
||||
# Assemble the workers
|
||||
with ProcessPoolExecutor(max_workers=n_jobs) as pool:
|
||||
# Pass the elements of array into function
|
||||
if use_kwargs:
|
||||
futures = [pool.submit(function, **a) for a in array[front_num:]]
|
||||
else:
|
||||
futures = [pool.submit(function, a) for a in array[front_num:]]
|
||||
kwargs = {
|
||||
'total': len(futures),
|
||||
'unit': 'it',
|
||||
'unit_scale': True,
|
||||
'leave': True
|
||||
}
|
||||
# Print out the progress as tasks complete
|
||||
for f in tqdm(as_completed(futures), **kwargs):
|
||||
pass
|
||||
out = []
|
||||
# Get the results from the futures.
|
||||
for i, future in tqdm(enumerate(futures)):
|
||||
try:
|
||||
out.append(future.result())
|
||||
except Exception as e:
|
||||
out.append(e)
|
||||
return front + out
|
|
@ -0,0 +1,247 @@
|
|||
# Copyright 2020 IBM
|
||||
# Author: peter.zhong@au1.ibm.com
|
||||
#
|
||||
# This is free software; you can redistribute it and/or modify
|
||||
# it under the terms of the Apache 2.0 License.
|
||||
#
|
||||
# This software is distributed in the hope that it will be useful,
|
||||
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
# Apache 2.0 License for more details.
|
||||
|
||||
import distance
|
||||
from apted import APTED, Config
|
||||
from apted.helpers import Tree
|
||||
from lxml import etree, html
|
||||
from collections import deque
|
||||
from .parallel import parallel_process
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
class TableTree(Tree):
|
||||
def __init__(self, tag, colspan=None, rowspan=None, content=None, *children):
|
||||
self.tag = tag
|
||||
self.colspan = colspan
|
||||
self.rowspan = rowspan
|
||||
self.content = content
|
||||
self.children = list(children)
|
||||
|
||||
def bracket(self):
|
||||
"""Show tree using brackets notation"""
|
||||
if self.tag == 'td':
|
||||
result = '"tag": %s, "colspan": %d, "rowspan": %d, "text": %s' % \
|
||||
(self.tag, self.colspan, self.rowspan, self.content)
|
||||
else:
|
||||
result = '"tag": %s' % self.tag
|
||||
for child in self.children:
|
||||
result += child.bracket()
|
||||
return "{{{}}}".format(result)
|
||||
|
||||
|
||||
class CustomConfig(Config):
|
||||
@staticmethod
|
||||
def maximum(*sequences):
|
||||
"""Get maximum possible value
|
||||
"""
|
||||
return max(map(len, sequences))
|
||||
|
||||
def normalized_distance(self, *sequences):
|
||||
"""Get distance from 0 to 1
|
||||
"""
|
||||
return float(distance.levenshtein(*sequences)) / self.maximum(*sequences)
|
||||
|
||||
def rename(self, node1, node2):
|
||||
"""Compares attributes of trees"""
|
||||
#print(node1.tag)
|
||||
if (node1.tag != node2.tag) or (node1.colspan != node2.colspan) or (node1.rowspan != node2.rowspan):
|
||||
return 1.
|
||||
if node1.tag == 'td':
|
||||
if node1.content or node2.content:
|
||||
#print(node1.content, )
|
||||
return self.normalized_distance(node1.content, node2.content)
|
||||
return 0.
|
||||
|
||||
|
||||
|
||||
class CustomConfig_del_short(Config):
|
||||
@staticmethod
|
||||
def maximum(*sequences):
|
||||
"""Get maximum possible value
|
||||
"""
|
||||
return max(map(len, sequences))
|
||||
|
||||
def normalized_distance(self, *sequences):
|
||||
"""Get distance from 0 to 1
|
||||
"""
|
||||
return float(distance.levenshtein(*sequences)) / self.maximum(*sequences)
|
||||
|
||||
def rename(self, node1, node2):
|
||||
"""Compares attributes of trees"""
|
||||
if (node1.tag != node2.tag) or (node1.colspan != node2.colspan) or (node1.rowspan != node2.rowspan):
|
||||
return 1.
|
||||
if node1.tag == 'td':
|
||||
if node1.content or node2.content:
|
||||
#print('before')
|
||||
#print(node1.content, node2.content)
|
||||
#print('after')
|
||||
node1_content = node1.content
|
||||
node2_content = node2.content
|
||||
if len(node1_content) < 3:
|
||||
node1_content = ['####']
|
||||
if len(node2_content) < 3:
|
||||
node2_content = ['####']
|
||||
return self.normalized_distance(node1_content, node2_content)
|
||||
return 0.
|
||||
|
||||
class CustomConfig_del_block(Config):
|
||||
@staticmethod
|
||||
def maximum(*sequences):
|
||||
"""Get maximum possible value
|
||||
"""
|
||||
return max(map(len, sequences))
|
||||
|
||||
def normalized_distance(self, *sequences):
|
||||
"""Get distance from 0 to 1
|
||||
"""
|
||||
return float(distance.levenshtein(*sequences)) / self.maximum(*sequences)
|
||||
|
||||
def rename(self, node1, node2):
|
||||
"""Compares attributes of trees"""
|
||||
if (node1.tag != node2.tag) or (node1.colspan != node2.colspan) or (node1.rowspan != node2.rowspan):
|
||||
return 1.
|
||||
if node1.tag == 'td':
|
||||
if node1.content or node2.content:
|
||||
|
||||
node1_content = node1.content
|
||||
node2_content = node2.content
|
||||
while ' ' in node1_content:
|
||||
print(node1_content.index(' '))
|
||||
node1_content.pop(node1_content.index(' '))
|
||||
while ' ' in node2_content:
|
||||
print(node2_content.index(' '))
|
||||
node2_content.pop(node2_content.index(' '))
|
||||
return self.normalized_distance(node1_content, node2_content)
|
||||
return 0.
|
||||
|
||||
class TEDS(object):
|
||||
''' Tree Edit Distance basead Similarity
|
||||
'''
|
||||
|
||||
def __init__(self, structure_only=False, n_jobs=1, ignore_nodes=None):
|
||||
assert isinstance(n_jobs, int) and (
|
||||
n_jobs >= 1), 'n_jobs must be an integer greather than 1'
|
||||
self.structure_only = structure_only
|
||||
self.n_jobs = n_jobs
|
||||
self.ignore_nodes = ignore_nodes
|
||||
self.__tokens__ = []
|
||||
|
||||
def tokenize(self, node):
|
||||
''' Tokenizes table cells
|
||||
'''
|
||||
self.__tokens__.append('<%s>' % node.tag)
|
||||
if node.text is not None:
|
||||
self.__tokens__ += list(node.text)
|
||||
for n in node.getchildren():
|
||||
self.tokenize(n)
|
||||
if node.tag != 'unk':
|
||||
self.__tokens__.append('</%s>' % node.tag)
|
||||
if node.tag != 'td' and node.tail is not None:
|
||||
self.__tokens__ += list(node.tail)
|
||||
|
||||
def load_html_tree(self, node, parent=None):
|
||||
''' Converts HTML tree to the format required by apted
|
||||
'''
|
||||
global __tokens__
|
||||
if node.tag == 'td':
|
||||
if self.structure_only:
|
||||
cell = []
|
||||
else:
|
||||
self.__tokens__ = []
|
||||
self.tokenize(node)
|
||||
cell = self.__tokens__[1:-1].copy()
|
||||
new_node = TableTree(node.tag,
|
||||
int(node.attrib.get('colspan', '1')),
|
||||
int(node.attrib.get('rowspan', '1')),
|
||||
cell, *deque())
|
||||
else:
|
||||
new_node = TableTree(node.tag, None, None, None, *deque())
|
||||
if parent is not None:
|
||||
parent.children.append(new_node)
|
||||
if node.tag != 'td':
|
||||
for n in node.getchildren():
|
||||
self.load_html_tree(n, new_node)
|
||||
if parent is None:
|
||||
return new_node
|
||||
|
||||
def evaluate(self, pred, true):
|
||||
''' Computes TEDS score between the prediction and the ground truth of a
|
||||
given sample
|
||||
'''
|
||||
if (not pred) or (not true):
|
||||
return 0.0
|
||||
parser = html.HTMLParser(remove_comments=True, encoding='utf-8')
|
||||
pred = html.fromstring(pred, parser=parser)
|
||||
true = html.fromstring(true, parser=parser)
|
||||
if pred.xpath('body/table') and true.xpath('body/table'):
|
||||
pred = pred.xpath('body/table')[0]
|
||||
true = true.xpath('body/table')[0]
|
||||
if self.ignore_nodes:
|
||||
etree.strip_tags(pred, *self.ignore_nodes)
|
||||
etree.strip_tags(true, *self.ignore_nodes)
|
||||
n_nodes_pred = len(pred.xpath(".//*"))
|
||||
n_nodes_true = len(true.xpath(".//*"))
|
||||
n_nodes = max(n_nodes_pred, n_nodes_true)
|
||||
tree_pred = self.load_html_tree(pred)
|
||||
tree_true = self.load_html_tree(true)
|
||||
distance = APTED(tree_pred, tree_true,
|
||||
CustomConfig()).compute_edit_distance()
|
||||
return 1.0 - (float(distance) / n_nodes)
|
||||
else:
|
||||
return 0.0
|
||||
|
||||
def batch_evaluate(self, pred_json, true_json):
|
||||
''' Computes TEDS score between the prediction and the ground truth of
|
||||
a batch of samples
|
||||
@params pred_json: {'FILENAME': 'HTML CODE', ...}
|
||||
@params true_json: {'FILENAME': {'html': 'HTML CODE'}, ...}
|
||||
@output: {'FILENAME': 'TEDS SCORE', ...}
|
||||
'''
|
||||
samples = true_json.keys()
|
||||
if self.n_jobs == 1:
|
||||
scores = [self.evaluate(pred_json.get(
|
||||
filename, ''), true_json[filename]['html']) for filename in tqdm(samples)]
|
||||
else:
|
||||
inputs = [{'pred': pred_json.get(
|
||||
filename, ''), 'true': true_json[filename]['html']} for filename in samples]
|
||||
scores = parallel_process(
|
||||
inputs, self.evaluate, use_kwargs=True, n_jobs=self.n_jobs, front_num=1)
|
||||
scores = dict(zip(samples, scores))
|
||||
return scores
|
||||
|
||||
def batch_evaluate_html(self, pred_htmls, true_htmls):
|
||||
''' Computes TEDS score between the prediction and the ground truth of
|
||||
a batch of samples
|
||||
'''
|
||||
if self.n_jobs == 1:
|
||||
scores = [self.evaluate(pred_html, true_html) for (
|
||||
pred_html, true_html) in zip(pred_htmls, true_htmls)]
|
||||
else:
|
||||
inputs = [{"pred": pred_html, "true": true_html} for(
|
||||
pred_html, true_html) in zip(pred_htmls, true_htmls)]
|
||||
|
||||
scores = parallel_process(
|
||||
inputs, self.evaluate, use_kwargs=True, n_jobs=self.n_jobs, front_num=1)
|
||||
return scores
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
import json
|
||||
import pprint
|
||||
with open('sample_pred.json') as fp:
|
||||
pred_json = json.load(fp)
|
||||
with open('sample_gt.json') as fp:
|
||||
true_json = json.load(fp)
|
||||
teds = TEDS(n_jobs=4)
|
||||
scores = teds.batch_evaluate(pred_json, true_json)
|
||||
pp = pprint.PrettyPrinter()
|
||||
pp.pprint(scores)
|
|
@ -0,0 +1,13 @@
|
|||
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
|
@ -0,0 +1,283 @@
|
|||
# This is where we handle translating css styles into openpyxl styles
|
||||
# and cascading those from parent to child in the dom.
|
||||
|
||||
from openpyxl.cell import cell
|
||||
from openpyxl.styles import Font, Alignment, PatternFill, NamedStyle, Border, Side, Color
|
||||
from openpyxl.styles.fills import FILL_SOLID
|
||||
from openpyxl.styles.numbers import FORMAT_CURRENCY_USD_SIMPLE, FORMAT_PERCENTAGE
|
||||
from openpyxl.styles.colors import BLACK
|
||||
|
||||
FORMAT_DATE_MMDDYYYY = 'mm/dd/yyyy'
|
||||
|
||||
|
||||
def colormap(color):
|
||||
"""
|
||||
Convenience for looking up known colors
|
||||
"""
|
||||
cmap = {'black': BLACK}
|
||||
return cmap.get(color, color)
|
||||
|
||||
|
||||
def style_string_to_dict(style):
|
||||
"""
|
||||
Convert css style string to a python dictionary
|
||||
"""
|
||||
def clean_split(string, delim):
|
||||
return (s.strip() for s in string.split(delim))
|
||||
styles = [clean_split(s, ":") for s in style.split(";") if ":" in s]
|
||||
return dict(styles)
|
||||
|
||||
|
||||
def get_side(style, name):
|
||||
return {'border_style': style.get('border-{}-style'.format(name)),
|
||||
'color': colormap(style.get('border-{}-color'.format(name)))}
|
||||
|
||||
known_styles = {}
|
||||
|
||||
|
||||
def style_dict_to_named_style(style_dict, number_format=None):
|
||||
"""
|
||||
Change css style (stored in a python dictionary) to openpyxl NamedStyle
|
||||
"""
|
||||
|
||||
style_and_format_string = str({
|
||||
'style_dict': style_dict,
|
||||
'parent': style_dict.parent,
|
||||
'number_format': number_format,
|
||||
})
|
||||
|
||||
if style_and_format_string not in known_styles:
|
||||
# Font
|
||||
font = Font(bold=style_dict.get('font-weight') == 'bold',
|
||||
color=style_dict.get_color('color', None),
|
||||
size=style_dict.get('font-size'))
|
||||
|
||||
# Alignment
|
||||
alignment = Alignment(horizontal=style_dict.get('text-align', 'general'),
|
||||
vertical=style_dict.get('vertical-align'),
|
||||
wrap_text=style_dict.get('white-space', 'nowrap') == 'normal')
|
||||
|
||||
# Fill
|
||||
bg_color = style_dict.get_color('background-color')
|
||||
fg_color = style_dict.get_color('foreground-color', Color())
|
||||
fill_type = style_dict.get('fill-type')
|
||||
if bg_color and bg_color != 'transparent':
|
||||
fill = PatternFill(fill_type=fill_type or FILL_SOLID,
|
||||
start_color=bg_color,
|
||||
end_color=fg_color)
|
||||
else:
|
||||
fill = PatternFill()
|
||||
|
||||
# Border
|
||||
border = Border(left=Side(**get_side(style_dict, 'left')),
|
||||
right=Side(**get_side(style_dict, 'right')),
|
||||
top=Side(**get_side(style_dict, 'top')),
|
||||
bottom=Side(**get_side(style_dict, 'bottom')),
|
||||
diagonal=Side(**get_side(style_dict, 'diagonal')),
|
||||
diagonal_direction=None,
|
||||
outline=Side(**get_side(style_dict, 'outline')),
|
||||
vertical=None,
|
||||
horizontal=None)
|
||||
|
||||
name = 'Style {}'.format(len(known_styles) + 1)
|
||||
|
||||
pyxl_style = NamedStyle(name=name, font=font, fill=fill, alignment=alignment, border=border,
|
||||
number_format=number_format)
|
||||
|
||||
known_styles[style_and_format_string] = pyxl_style
|
||||
|
||||
return known_styles[style_and_format_string]
|
||||
|
||||
|
||||
class StyleDict(dict):
|
||||
"""
|
||||
It's like a dictionary, but it looks for items in the parent dictionary
|
||||
"""
|
||||
def __init__(self, *args, **kwargs):
|
||||
self.parent = kwargs.pop('parent', None)
|
||||
super(StyleDict, self).__init__(*args, **kwargs)
|
||||
|
||||
def __getitem__(self, item):
|
||||
if item in self:
|
||||
return super(StyleDict, self).__getitem__(item)
|
||||
elif self.parent:
|
||||
return self.parent[item]
|
||||
else:
|
||||
raise KeyError('{} not found'.format(item))
|
||||
|
||||
def __hash__(self):
|
||||
return hash(tuple([(k, self.get(k)) for k in self._keys()]))
|
||||
|
||||
# Yielding the keys avoids creating unnecessary data structures
|
||||
# and happily works with both python2 and python3 where the
|
||||
# .keys() method is a dictionary_view in python3 and a list in python2.
|
||||
def _keys(self):
|
||||
yielded = set()
|
||||
for k in self.keys():
|
||||
yielded.add(k)
|
||||
yield k
|
||||
if self.parent:
|
||||
for k in self.parent._keys():
|
||||
if k not in yielded:
|
||||
yielded.add(k)
|
||||
yield k
|
||||
|
||||
def get(self, k, d=None):
|
||||
try:
|
||||
return self[k]
|
||||
except KeyError:
|
||||
return d
|
||||
|
||||
def get_color(self, k, d=None):
|
||||
"""
|
||||
Strip leading # off colors if necessary
|
||||
"""
|
||||
color = self.get(k, d)
|
||||
if hasattr(color, 'startswith') and color.startswith('#'):
|
||||
color = color[1:]
|
||||
if len(color) == 3: # Premailers reduces colors like #00ff00 to #0f0, openpyxl doesn't like that
|
||||
color = ''.join(2 * c for c in color)
|
||||
return color
|
||||
|
||||
|
||||
class Element(object):
|
||||
"""
|
||||
Our base class for representing an html element along with a cascading style.
|
||||
The element is created along with a parent so that the StyleDict that we store
|
||||
can point to the parent's StyleDict.
|
||||
"""
|
||||
def __init__(self, element, parent=None):
|
||||
self.element = element
|
||||
self.number_format = None
|
||||
parent_style = parent.style_dict if parent else None
|
||||
self.style_dict = StyleDict(style_string_to_dict(element.get('style', '')), parent=parent_style)
|
||||
self._style_cache = None
|
||||
|
||||
def style(self):
|
||||
"""
|
||||
Turn the css styles for this element into an openpyxl NamedStyle.
|
||||
"""
|
||||
if not self._style_cache:
|
||||
self._style_cache = style_dict_to_named_style(self.style_dict, number_format=self.number_format)
|
||||
return self._style_cache
|
||||
|
||||
def get_dimension(self, dimension_key):
|
||||
"""
|
||||
Extracts the dimension from the style dict of the Element and returns it as a float.
|
||||
"""
|
||||
dimension = self.style_dict.get(dimension_key)
|
||||
if dimension:
|
||||
if dimension[-2:] in ['px', 'em', 'pt', 'in', 'cm']:
|
||||
dimension = dimension[:-2]
|
||||
dimension = float(dimension)
|
||||
return dimension
|
||||
|
||||
|
||||
class Table(Element):
|
||||
"""
|
||||
The concrete implementations of Elements are semantically named for the types of elements we are interested in.
|
||||
This defines a very concrete tree structure for html tables that we expect to deal with. I prefer this compared to
|
||||
allowing Element to have an arbitrary number of children and dealing with an abstract element tree.
|
||||
"""
|
||||
def __init__(self, table):
|
||||
"""
|
||||
takes an html table object (from lxml)
|
||||
"""
|
||||
super(Table, self).__init__(table)
|
||||
table_head = table.find('thead')
|
||||
self.head = TableHead(table_head, parent=self) if table_head is not None else None
|
||||
table_body = table.find('tbody')
|
||||
self.body = TableBody(table_body if table_body is not None else table, parent=self)
|
||||
|
||||
|
||||
class TableHead(Element):
|
||||
"""
|
||||
This class maps to the `<th>` element of the html table.
|
||||
"""
|
||||
def __init__(self, head, parent=None):
|
||||
super(TableHead, self).__init__(head, parent=parent)
|
||||
self.rows = [TableRow(tr, parent=self) for tr in head.findall('tr')]
|
||||
|
||||
|
||||
class TableBody(Element):
|
||||
"""
|
||||
This class maps to the `<tbody>` element of the html table.
|
||||
"""
|
||||
def __init__(self, body, parent=None):
|
||||
super(TableBody, self).__init__(body, parent=parent)
|
||||
self.rows = [TableRow(tr, parent=self) for tr in body.findall('tr')]
|
||||
|
||||
|
||||
class TableRow(Element):
|
||||
"""
|
||||
This class maps to the `<tr>` element of the html table.
|
||||
"""
|
||||
def __init__(self, tr, parent=None):
|
||||
super(TableRow, self).__init__(tr, parent=parent)
|
||||
self.cells = [TableCell(cell, parent=self) for cell in tr.findall('th') + tr.findall('td')]
|
||||
|
||||
|
||||
def element_to_string(el):
|
||||
return _element_to_string(el).strip()
|
||||
|
||||
|
||||
def _element_to_string(el):
|
||||
string = ''
|
||||
|
||||
for x in el.iterchildren():
|
||||
string += '\n' + _element_to_string(x)
|
||||
|
||||
text = el.text.strip() if el.text else ''
|
||||
tail = el.tail.strip() if el.tail else ''
|
||||
|
||||
return text + string + '\n' + tail
|
||||
|
||||
|
||||
class TableCell(Element):
|
||||
"""
|
||||
This class maps to the `<td>` element of the html table.
|
||||
"""
|
||||
CELL_TYPES = {'TYPE_STRING', 'TYPE_FORMULA', 'TYPE_NUMERIC', 'TYPE_BOOL', 'TYPE_CURRENCY', 'TYPE_PERCENTAGE',
|
||||
'TYPE_NULL', 'TYPE_INLINE', 'TYPE_ERROR', 'TYPE_FORMULA_CACHE_STRING', 'TYPE_INTEGER'}
|
||||
|
||||
def __init__(self, cell, parent=None):
|
||||
super(TableCell, self).__init__(cell, parent=parent)
|
||||
self.value = element_to_string(cell)
|
||||
self.number_format = self.get_number_format()
|
||||
|
||||
def data_type(self):
|
||||
cell_types = self.CELL_TYPES & set(self.element.get('class', '').split())
|
||||
if cell_types:
|
||||
if 'TYPE_FORMULA' in cell_types:
|
||||
# Make sure TYPE_FORMULA takes precedence over the other classes in the set.
|
||||
cell_type = 'TYPE_FORMULA'
|
||||
elif cell_types & {'TYPE_CURRENCY', 'TYPE_INTEGER', 'TYPE_PERCENTAGE'}:
|
||||
cell_type = 'TYPE_NUMERIC'
|
||||
else:
|
||||
cell_type = cell_types.pop()
|
||||
else:
|
||||
cell_type = 'TYPE_STRING'
|
||||
return getattr(cell, cell_type)
|
||||
|
||||
def get_number_format(self):
|
||||
if 'TYPE_CURRENCY' in self.element.get('class', '').split():
|
||||
return FORMAT_CURRENCY_USD_SIMPLE
|
||||
if 'TYPE_INTEGER' in self.element.get('class', '').split():
|
||||
return '#,##0'
|
||||
if 'TYPE_PERCENTAGE' in self.element.get('class', '').split():
|
||||
return FORMAT_PERCENTAGE
|
||||
if 'TYPE_DATE' in self.element.get('class', '').split():
|
||||
return FORMAT_DATE_MMDDYYYY
|
||||
if self.data_type() == cell.TYPE_NUMERIC:
|
||||
try:
|
||||
int(self.value)
|
||||
except ValueError:
|
||||
return '#,##0.##'
|
||||
else:
|
||||
return '#,##0'
|
||||
|
||||
def format(self, cell):
|
||||
cell.style = self.style()
|
||||
data_type = self.data_type()
|
||||
if data_type:
|
||||
cell.data_type = data_type
|
|
@ -0,0 +1,118 @@
|
|||
# Do imports like python3 so our package works for 2 and 3
|
||||
from __future__ import absolute_import
|
||||
|
||||
from lxml import html
|
||||
from openpyxl import Workbook
|
||||
from openpyxl.utils import get_column_letter
|
||||
from premailer import Premailer
|
||||
from tablepyxl.style import Table
|
||||
|
||||
|
||||
def string_to_int(s):
|
||||
if s.isdigit():
|
||||
return int(s)
|
||||
return 0
|
||||
|
||||
|
||||
def get_Tables(doc):
|
||||
tree = html.fromstring(doc)
|
||||
comments = tree.xpath('//comment()')
|
||||
for comment in comments:
|
||||
comment.drop_tag()
|
||||
return [Table(table) for table in tree.xpath('//table')]
|
||||
|
||||
|
||||
def write_rows(worksheet, elem, row, column=1):
|
||||
"""
|
||||
Writes every tr child element of elem to a row in the worksheet
|
||||
returns the next row after all rows are written
|
||||
"""
|
||||
from openpyxl.cell.cell import MergedCell
|
||||
|
||||
initial_column = column
|
||||
for table_row in elem.rows:
|
||||
for table_cell in table_row.cells:
|
||||
cell = worksheet.cell(row=row, column=column)
|
||||
while isinstance(cell, MergedCell):
|
||||
column += 1
|
||||
cell = worksheet.cell(row=row, column=column)
|
||||
|
||||
colspan = string_to_int(table_cell.element.get("colspan", "1"))
|
||||
rowspan = string_to_int(table_cell.element.get("rowspan", "1"))
|
||||
if rowspan > 1 or colspan > 1:
|
||||
worksheet.merge_cells(start_row=row, start_column=column,
|
||||
end_row=row + rowspan - 1, end_column=column + colspan - 1)
|
||||
|
||||
cell.value = table_cell.value
|
||||
table_cell.format(cell)
|
||||
min_width = table_cell.get_dimension('min-width')
|
||||
max_width = table_cell.get_dimension('max-width')
|
||||
|
||||
if colspan == 1:
|
||||
# Initially, when iterating for the first time through the loop, the width of all the cells is None.
|
||||
# As we start filling in contents, the initial width of the cell (which can be retrieved by:
|
||||
# worksheet.column_dimensions[get_column_letter(column)].width) is equal to the width of the previous
|
||||
# cell in the same column (i.e. width of A2 = width of A1)
|
||||
width = max(worksheet.column_dimensions[get_column_letter(column)].width or 0, len(table_cell.value) + 2)
|
||||
if max_width and width > max_width:
|
||||
width = max_width
|
||||
elif min_width and width < min_width:
|
||||
width = min_width
|
||||
worksheet.column_dimensions[get_column_letter(column)].width = width
|
||||
column += colspan
|
||||
row += 1
|
||||
column = initial_column
|
||||
return row
|
||||
|
||||
|
||||
def table_to_sheet(table, wb):
|
||||
"""
|
||||
Takes a table and workbook and writes the table to a new sheet.
|
||||
The sheet title will be the same as the table attribute name.
|
||||
"""
|
||||
ws = wb.create_sheet(title=table.element.get('name'))
|
||||
insert_table(table, ws, 1, 1)
|
||||
|
||||
|
||||
def document_to_workbook(doc, wb=None, base_url=None):
|
||||
"""
|
||||
Takes a string representation of an html document and writes one sheet for
|
||||
every table in the document.
|
||||
The workbook is returned
|
||||
"""
|
||||
if not wb:
|
||||
wb = Workbook()
|
||||
wb.remove(wb.active)
|
||||
|
||||
inline_styles_doc = Premailer(doc, base_url=base_url, remove_classes=False).transform()
|
||||
tables = get_Tables(inline_styles_doc)
|
||||
|
||||
for table in tables:
|
||||
table_to_sheet(table, wb)
|
||||
|
||||
return wb
|
||||
|
||||
|
||||
def document_to_xl(doc, filename, base_url=None):
|
||||
"""
|
||||
Takes a string representation of an html document and writes one sheet for
|
||||
every table in the document. The workbook is written out to a file called filename
|
||||
"""
|
||||
wb = document_to_workbook(doc, base_url=base_url)
|
||||
wb.save(filename)
|
||||
|
||||
|
||||
def insert_table(table, worksheet, column, row):
|
||||
if table.head:
|
||||
row = write_rows(worksheet, table.head, row, column)
|
||||
if table.body:
|
||||
row = write_rows(worksheet, table.body, row, column)
|
||||
|
||||
|
||||
def insert_table_at_cell(table, cell):
|
||||
"""
|
||||
Inserts a table at the location of an openpyxl Cell object.
|
||||
"""
|
||||
ws = cell.parent
|
||||
column, row = cell.column, cell.row
|
||||
insert_table(table, ws, column, row)
|
|
@ -0,0 +1,54 @@
|
|||
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
from tools.infer.utility import draw_ocr_box_txt, init_args as infer_args
|
||||
|
||||
|
||||
def init_args():
|
||||
parser = infer_args()
|
||||
|
||||
# params for output
|
||||
parser.add_argument("--output", type=str, default='./output/table')
|
||||
# params for table structure
|
||||
parser.add_argument("--table_max_len", type=int, default=488)
|
||||
parser.add_argument("--table_model_dir", type=str)
|
||||
parser.add_argument("--table_char_type", type=str, default='en')
|
||||
parser.add_argument("--table_char_dict_path", type=str, default="../ppocr/utils/dict/table_structure_dict.txt")
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = init_args()
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def draw_result(image, result, font_path):
|
||||
if isinstance(image, np.ndarray):
|
||||
image = Image.fromarray(image)
|
||||
boxes, txts, scores = [], [], []
|
||||
for region in result:
|
||||
if region['type'] == 'Table':
|
||||
pass
|
||||
elif region['type'] == 'Figure':
|
||||
pass
|
||||
else:
|
||||
for box, rec_res in zip(region['res'][0], region['res'][1]):
|
||||
boxes.append(np.array(box).reshape(-1, 2))
|
||||
txts.append(rec_res[0])
|
||||
scores.append(rec_res[1])
|
||||
im_show = draw_ocr_box_txt(image, boxes, txts, scores, font_path=font_path,drop_score=0)
|
||||
return im_show
|
|
@ -44,10 +44,18 @@ def main():
|
|||
# build model
|
||||
# for rec algorithm
|
||||
if hasattr(post_process_class, 'character'):
|
||||
config['Architecture']["Head"]['out_channels'] = len(
|
||||
getattr(post_process_class, 'character'))
|
||||
char_num = len(getattr(post_process_class, 'character'))
|
||||
if config['Architecture']["algorithm"] in ["Distillation",
|
||||
]: # distillation model
|
||||
for key in config['Architecture']["Models"]:
|
||||
config['Architecture']["Models"][key]["Head"][
|
||||
'out_channels'] = char_num
|
||||
else: # base rec model
|
||||
config['Architecture']["Head"]['out_channels'] = char_num
|
||||
|
||||
model = build_model(config['Architecture'])
|
||||
use_srn = config['Architecture']['algorithm'] == "SRN"
|
||||
model_type = config['Architecture']['model_type']
|
||||
|
||||
best_model_dict = init_model(config, model)
|
||||
if len(best_model_dict):
|
||||
|
@ -60,7 +68,7 @@ def main():
|
|||
|
||||
# start eval
|
||||
metric = program.eval(model, valid_dataloader, post_process_class,
|
||||
eval_class, use_srn)
|
||||
eval_class, model_type, use_srn)
|
||||
logger.info('metric eval ***************')
|
||||
for k, v in metric.items():
|
||||
logger.info('{}:{}'.format(k, v))
|
||||
|
|
|
@ -60,7 +60,8 @@ def export_single_model(model, arch_config, save_path, logger):
|
|||
"When there is tps in the network, variable length input is not supported, and the input size needs to be the same as during training"
|
||||
)
|
||||
infer_shape[-1] = 100
|
||||
|
||||
elif arch_config["model_type"] == "table":
|
||||
infer_shape = [3, 488, 488]
|
||||
model = to_static(
|
||||
model,
|
||||
input_spec=[
|
||||
|
|
|
@ -31,7 +31,7 @@ from ppocr.utils.utility import get_image_file_list, check_and_read_gif
|
|||
from ppocr.data import create_operators, transform
|
||||
from ppocr.postprocess import build_post_process
|
||||
|
||||
import tools.infer.benchmark_utils as benchmark_utils
|
||||
# import tools.infer.benchmark_utils as benchmark_utils
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
|
@ -43,7 +43,7 @@ class TextDetector(object):
|
|||
pre_process_list = [{
|
||||
'DetResizeForTest': {
|
||||
'limit_side_len': args.det_limit_side_len,
|
||||
'limit_type': args.det_limit_type
|
||||
'limit_type': args.det_limit_type,
|
||||
}
|
||||
}, {
|
||||
'NormalizeImage': {
|
||||
|
@ -100,8 +100,6 @@ class TextDetector(object):
|
|||
self.predictor, self.input_tensor, self.output_tensors, self.config = utility.create_predictor(
|
||||
args, 'det', logger)
|
||||
|
||||
self.det_times = utility.Timer()
|
||||
|
||||
def order_points_clockwise(self, pts):
|
||||
"""
|
||||
reference from: https://github.com/jrosebr1/imutils/blob/master/imutils/perspective.py
|
||||
|
@ -158,8 +156,8 @@ class TextDetector(object):
|
|||
def __call__(self, img):
|
||||
ori_im = img.copy()
|
||||
data = {'image': img}
|
||||
self.det_times.total_time.start()
|
||||
self.det_times.preprocess_time.start()
|
||||
|
||||
st = time.time()
|
||||
data = transform(data, self.preprocess_op)
|
||||
img, shape_list = data
|
||||
if img is None:
|
||||
|
@ -168,16 +166,12 @@ class TextDetector(object):
|
|||
shape_list = np.expand_dims(shape_list, axis=0)
|
||||
img = img.copy()
|
||||
|
||||
self.det_times.preprocess_time.end()
|
||||
self.det_times.inference_time.start()
|
||||
|
||||
self.input_tensor.copy_from_cpu(img)
|
||||
self.predictor.run()
|
||||
outputs = []
|
||||
for output_tensor in self.output_tensors:
|
||||
output = output_tensor.copy_to_cpu()
|
||||
outputs.append(output)
|
||||
self.det_times.inference_time.end()
|
||||
|
||||
preds = {}
|
||||
if self.det_algorithm == "EAST":
|
||||
|
@ -193,8 +187,6 @@ class TextDetector(object):
|
|||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
self.det_times.postprocess_time.start()
|
||||
|
||||
self.predictor.try_shrink_memory()
|
||||
post_result = self.postprocess_op(preds, shape_list)
|
||||
dt_boxes = post_result[0]['points']
|
||||
|
@ -203,10 +195,8 @@ class TextDetector(object):
|
|||
else:
|
||||
dt_boxes = self.filter_tag_det_res(dt_boxes, ori_im.shape)
|
||||
|
||||
self.det_times.postprocess_time.end()
|
||||
self.det_times.total_time.end()
|
||||
self.det_times.img_num += 1
|
||||
return dt_boxes, self.det_times.total_time.value()
|
||||
et = time.time()
|
||||
return dt_boxes, et - st
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -216,12 +206,13 @@ if __name__ == "__main__":
|
|||
count = 0
|
||||
total_time = 0
|
||||
draw_img_save = "./inference_results"
|
||||
cpu_mem, gpu_mem, gpu_util = 0, 0, 0
|
||||
|
||||
# warmup 10 times
|
||||
fake_img = np.random.uniform(-1, 1, [640, 640, 3]).astype(np.float32)
|
||||
for i in range(10):
|
||||
dt_boxes, _ = text_detector(fake_img)
|
||||
if args.warmup:
|
||||
img = np.random.uniform(0, 255, [640, 640, 3]).astype(np.uint8)
|
||||
for i in range(10):
|
||||
res = text_detector(img)
|
||||
|
||||
cpu_mem, gpu_mem, gpu_util = 0, 0, 0
|
||||
|
||||
if not os.path.exists(draw_img_save):
|
||||
os.makedirs(draw_img_save)
|
||||
|
@ -239,49 +230,11 @@ if __name__ == "__main__":
|
|||
total_time += elapse
|
||||
count += 1
|
||||
|
||||
if args.benchmark:
|
||||
cm, gm, gu = utility.get_current_memory_mb(0)
|
||||
cpu_mem += cm
|
||||
gpu_mem += gm
|
||||
gpu_util += gu
|
||||
|
||||
logger.info("Predict time of {}: {}".format(image_file, elapse))
|
||||
src_im = utility.draw_text_det_res(dt_boxes, image_file)
|
||||
img_name_pure = os.path.split(image_file)[-1]
|
||||
img_path = os.path.join(draw_img_save,
|
||||
"det_res_{}".format(img_name_pure))
|
||||
|
||||
cv2.imwrite(img_path, src_im)
|
||||
logger.info("The visualized image saved in {}".format(img_path))
|
||||
# print the information about memory and time-spent
|
||||
if args.benchmark:
|
||||
mems = {
|
||||
'cpu_rss_mb': cpu_mem / count,
|
||||
'gpu_rss_mb': gpu_mem / count,
|
||||
'gpu_util': gpu_util * 100 / count
|
||||
}
|
||||
else:
|
||||
mems = None
|
||||
logger.info("The predict time about detection module is as follows: ")
|
||||
det_time_dict = text_detector.det_times.report(average=True)
|
||||
det_model_name = args.det_model_dir
|
||||
|
||||
if args.benchmark:
|
||||
# construct log information
|
||||
model_info = {
|
||||
'model_name': args.det_model_dir.split('/')[-1],
|
||||
'precision': args.precision
|
||||
}
|
||||
data_info = {
|
||||
'batch_size': 1,
|
||||
'shape': 'dynamic_shape',
|
||||
'data_num': det_time_dict['img_num']
|
||||
}
|
||||
perf_info = {
|
||||
'preprocess_time_s': det_time_dict['preprocess_time'],
|
||||
'inference_time_s': det_time_dict['inference_time'],
|
||||
'postprocess_time_s': det_time_dict['postprocess_time'],
|
||||
'total_time_s': det_time_dict['total_time']
|
||||
}
|
||||
benchmark_log = benchmark_utils.PaddleInferBenchmark(
|
||||
text_detector.config, model_info, data_info, perf_info, mems)
|
||||
benchmark_log("Det")
|
||||
|
|
|
@ -257,13 +257,15 @@ def main(args):
|
|||
text_recognizer = TextRecognizer(args)
|
||||
valid_image_file_list = []
|
||||
img_list = []
|
||||
cpu_mem, gpu_mem, gpu_util = 0, 0, 0
|
||||
count = 0
|
||||
|
||||
# warmup 10 times
|
||||
fake_img = np.random.uniform(-1, 1, [1, 32, 320, 3]).astype(np.float32)
|
||||
for i in range(10):
|
||||
dt_boxes, _ = text_recognizer(fake_img)
|
||||
if args.warmup:
|
||||
img = np.random.uniform(0, 255, [32, 320, 3]).astype(np.uint8)
|
||||
for i in range(10):
|
||||
res = text_recognizer([img])
|
||||
|
||||
cpu_mem, gpu_mem, gpu_util = 0, 0, 0
|
||||
count = 0
|
||||
|
||||
for image_file in image_file_list:
|
||||
img, flag = check_and_read_gif(image_file)
|
||||
|
@ -320,7 +322,8 @@ def main(args):
|
|||
'total_time_s': rec_time_dict['total_time']
|
||||
}
|
||||
benchmark_log = benchmark_utils.PaddleInferBenchmark(
|
||||
text_recognizer.config, model_info, data_info, perf_info, mems)
|
||||
text_recognizer.config, model_info, data_info, perf_info, mems,
|
||||
args.save_log_path)
|
||||
benchmark_log("Rec")
|
||||
|
||||
|
||||
|
|
|
@ -13,6 +13,7 @@
|
|||
# limitations under the License.
|
||||
import os
|
||||
import sys
|
||||
import subprocess
|
||||
|
||||
__dir__ = os.path.dirname(os.path.abspath(__file__))
|
||||
sys.path.append(__dir__)
|
||||
|
@ -24,6 +25,7 @@ import cv2
|
|||
import copy
|
||||
import numpy as np
|
||||
import time
|
||||
import logging
|
||||
from PIL import Image
|
||||
import tools.infer.utility as utility
|
||||
import tools.infer.predict_rec as predict_rec
|
||||
|
@ -38,6 +40,9 @@ logger = get_logger()
|
|||
|
||||
class TextSystem(object):
|
||||
def __init__(self, args):
|
||||
if not args.show_log:
|
||||
logger.setLevel(logging.INFO)
|
||||
|
||||
self.text_detector = predict_det.TextDetector(args)
|
||||
self.text_recognizer = predict_rec.TextRecognizer(args)
|
||||
self.use_angle_cls = args.use_angle_cls
|
||||
|
@ -55,7 +60,7 @@ class TextSystem(object):
|
|||
ori_im = img.copy()
|
||||
dt_boxes, elapse = self.text_detector(img)
|
||||
|
||||
logger.info("dt_boxes num : {}, elapse : {}".format(
|
||||
logger.debug("dt_boxes num : {}, elapse : {}".format(
|
||||
len(dt_boxes), elapse))
|
||||
if dt_boxes is None:
|
||||
return None, None
|
||||
|
@ -70,11 +75,11 @@ class TextSystem(object):
|
|||
if self.use_angle_cls and cls:
|
||||
img_crop_list, angle_list, elapse = self.text_classifier(
|
||||
img_crop_list)
|
||||
logger.info("cls num : {}, elapse : {}".format(
|
||||
logger.debug("cls num : {}, elapse : {}".format(
|
||||
len(img_crop_list), elapse))
|
||||
|
||||
rec_res, elapse = self.text_recognizer(img_crop_list)
|
||||
logger.info("rec_res num : {}, elapse : {}".format(
|
||||
logger.debug("rec_res num : {}, elapse : {}".format(
|
||||
len(rec_res), elapse))
|
||||
# self.print_draw_crop_rec_res(img_crop_list, rec_res)
|
||||
filter_boxes, filter_rec_res = [], []
|
||||
|
@ -109,15 +114,24 @@ def sorted_boxes(dt_boxes):
|
|||
|
||||
def main(args):
|
||||
image_file_list = get_image_file_list(args.image_dir)
|
||||
image_file_list = image_file_list[args.process_id::args.total_process_num]
|
||||
text_sys = TextSystem(args)
|
||||
is_visualize = True
|
||||
font_path = args.vis_font_path
|
||||
drop_score = args.drop_score
|
||||
|
||||
# warm up 10 times
|
||||
if args.warmup:
|
||||
img = np.random.uniform(0, 255, [640, 640, 3]).astype(np.uint8)
|
||||
for i in range(10):
|
||||
res = text_sys(img)
|
||||
|
||||
total_time = 0
|
||||
cpu_mem, gpu_mem, gpu_util = 0, 0, 0
|
||||
_st = time.time()
|
||||
count = 0
|
||||
for idx, image_file in enumerate(image_file_list):
|
||||
|
||||
img, flag = check_and_read_gif(image_file)
|
||||
if not flag:
|
||||
img = cv2.imread(image_file)
|
||||
|
@ -226,4 +240,18 @@ def main(args):
|
|||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main(utility.parse_args())
|
||||
args = utility.parse_args()
|
||||
if args.use_mp:
|
||||
p_list = []
|
||||
total_process_num = args.total_process_num
|
||||
for process_id in range(total_process_num):
|
||||
cmd = [sys.executable, "-u"] + sys.argv + [
|
||||
"--process_id={}".format(process_id),
|
||||
"--use_mp={}".format(False)
|
||||
]
|
||||
p = subprocess.Popen(cmd, stdout=sys.stdout, stderr=sys.stdout)
|
||||
p_list.append(p)
|
||||
for p in p_list:
|
||||
p.wait()
|
||||
else:
|
||||
main(args)
|
||||
|
|
|
@ -37,6 +37,7 @@ def init_args():
|
|||
parser.add_argument("--use_gpu", type=str2bool, default=True)
|
||||
parser.add_argument("--ir_optim", type=str2bool, default=True)
|
||||
parser.add_argument("--use_tensorrt", type=str2bool, default=False)
|
||||
parser.add_argument("--min_subgraph_size", type=int, default=3)
|
||||
parser.add_argument("--precision", type=str, default="fp32")
|
||||
parser.add_argument("--gpu_mem", type=int, default=500)
|
||||
|
||||
|
@ -105,7 +106,9 @@ def init_args():
|
|||
parser.add_argument("--enable_mkldnn", type=str2bool, default=False)
|
||||
parser.add_argument("--cpu_threads", type=int, default=10)
|
||||
parser.add_argument("--use_pdserving", type=str2bool, default=False)
|
||||
parser.add_argument("--warmup", type=str2bool, default=True)
|
||||
|
||||
# multi-process
|
||||
parser.add_argument("--use_mp", type=str2bool, default=False)
|
||||
parser.add_argument("--total_process_num", type=int, default=1)
|
||||
parser.add_argument("--process_id", type=int, default=0)
|
||||
|
@ -113,6 +116,7 @@ def init_args():
|
|||
parser.add_argument("--benchmark", type=bool, default=False)
|
||||
parser.add_argument("--save_log_path", type=str, default="./log_output/")
|
||||
|
||||
parser.add_argument("--show_log", type=str2bool, default=True)
|
||||
return parser
|
||||
|
||||
|
||||
|
@ -198,6 +202,8 @@ def create_predictor(args, mode, logger):
|
|||
model_dir = args.cls_model_dir
|
||||
elif mode == 'rec':
|
||||
model_dir = args.rec_model_dir
|
||||
elif mode == 'table':
|
||||
model_dir = args.table_model_dir
|
||||
else:
|
||||
model_dir = args.e2e_model_dir
|
||||
|
||||
|
@ -231,12 +237,14 @@ def create_predictor(args, mode, logger):
|
|||
config.enable_tensorrt_engine(
|
||||
precision_mode=inference.PrecisionType.Float32,
|
||||
max_batch_size=args.max_batch_size,
|
||||
min_subgraph_size=3) # skip the minmum trt subgraph
|
||||
if mode == "det" and "mobile" in model_file_path:
|
||||
min_subgraph_size=args.min_subgraph_size)
|
||||
# skip the minmum trt subgraph
|
||||
if mode == "det":
|
||||
min_input_shape = {
|
||||
"x": [1, 3, 50, 50],
|
||||
"conv2d_92.tmp_0": [1, 96, 20, 20],
|
||||
"conv2d_91.tmp_0": [1, 96, 10, 10],
|
||||
"conv2d_59.tmp_0": [1, 96, 20, 20],
|
||||
"nearest_interp_v2_1.tmp_0": [1, 96, 10, 10],
|
||||
"nearest_interp_v2_2.tmp_0": [1, 96, 20, 20],
|
||||
"nearest_interp_v2_3.tmp_0": [1, 24, 20, 20],
|
||||
|
@ -249,6 +257,7 @@ def create_predictor(args, mode, logger):
|
|||
"x": [1, 3, 2000, 2000],
|
||||
"conv2d_92.tmp_0": [1, 96, 400, 400],
|
||||
"conv2d_91.tmp_0": [1, 96, 200, 200],
|
||||
"conv2d_59.tmp_0": [1, 96, 400, 400],
|
||||
"nearest_interp_v2_1.tmp_0": [1, 96, 200, 200],
|
||||
"nearest_interp_v2_2.tmp_0": [1, 96, 400, 400],
|
||||
"nearest_interp_v2_3.tmp_0": [1, 24, 400, 400],
|
||||
|
@ -261,6 +270,7 @@ def create_predictor(args, mode, logger):
|
|||
"x": [1, 3, 640, 640],
|
||||
"conv2d_92.tmp_0": [1, 96, 160, 160],
|
||||
"conv2d_91.tmp_0": [1, 96, 80, 80],
|
||||
"conv2d_59.tmp_0": [1, 96, 160, 160],
|
||||
"nearest_interp_v2_1.tmp_0": [1, 96, 80, 80],
|
||||
"nearest_interp_v2_2.tmp_0": [1, 96, 160, 160],
|
||||
"nearest_interp_v2_3.tmp_0": [1, 24, 160, 160],
|
||||
|
@ -269,31 +279,6 @@ def create_predictor(args, mode, logger):
|
|||
"elementwise_add_7": [1, 56, 40, 40],
|
||||
"nearest_interp_v2_0.tmp_0": [1, 96, 40, 40]
|
||||
}
|
||||
if mode == "det" and "server" in model_file_path:
|
||||
min_input_shape = {
|
||||
"x": [1, 3, 50, 50],
|
||||
"conv2d_59.tmp_0": [1, 96, 20, 20],
|
||||
"nearest_interp_v2_2.tmp_0": [1, 96, 20, 20],
|
||||
"nearest_interp_v2_3.tmp_0": [1, 24, 20, 20],
|
||||
"nearest_interp_v2_4.tmp_0": [1, 24, 20, 20],
|
||||
"nearest_interp_v2_5.tmp_0": [1, 24, 20, 20]
|
||||
}
|
||||
max_input_shape = {
|
||||
"x": [1, 3, 2000, 2000],
|
||||
"conv2d_59.tmp_0": [1, 96, 400, 400],
|
||||
"nearest_interp_v2_2.tmp_0": [1, 96, 400, 400],
|
||||
"nearest_interp_v2_3.tmp_0": [1, 24, 400, 400],
|
||||
"nearest_interp_v2_4.tmp_0": [1, 24, 400, 400],
|
||||
"nearest_interp_v2_5.tmp_0": [1, 24, 400, 400]
|
||||
}
|
||||
opt_input_shape = {
|
||||
"x": [1, 3, 640, 640],
|
||||
"conv2d_59.tmp_0": [1, 96, 160, 160],
|
||||
"nearest_interp_v2_2.tmp_0": [1, 96, 160, 160],
|
||||
"nearest_interp_v2_3.tmp_0": [1, 24, 160, 160],
|
||||
"nearest_interp_v2_4.tmp_0": [1, 24, 160, 160],
|
||||
"nearest_interp_v2_5.tmp_0": [1, 24, 160, 160]
|
||||
}
|
||||
elif mode == "rec":
|
||||
min_input_shape = {"x": [args.rec_batch_num, 3, 32, 10]}
|
||||
max_input_shape = {"x": [args.rec_batch_num, 3, 32, 2000]}
|
||||
|
@ -326,7 +311,10 @@ def create_predictor(args, mode, logger):
|
|||
config.disable_glog_info()
|
||||
|
||||
config.delete_pass("conv_transpose_eltwiseadd_bn_fuse_pass")
|
||||
if mode == 'table':
|
||||
config.delete_pass("fc_fuse_pass") # not supported for table
|
||||
config.switch_use_feed_fetch_ops(False)
|
||||
config.switch_ir_optim(True)
|
||||
|
||||
# create predictor
|
||||
predictor = inference.create_predictor(config)
|
||||
|
|
|
@ -112,4 +112,4 @@ def main():
|
|||
|
||||
if __name__ == '__main__':
|
||||
config, device, logger, vdl_writer = program.preprocess()
|
||||
main()
|
||||
main()
|
||||
|
|
|
@ -0,0 +1,107 @@
|
|||
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
|
||||
import os
|
||||
import sys
|
||||
import json
|
||||
|
||||
__dir__ = os.path.dirname(os.path.abspath(__file__))
|
||||
sys.path.append(__dir__)
|
||||
sys.path.append(os.path.abspath(os.path.join(__dir__, '..')))
|
||||
|
||||
os.environ["FLAGS_allocator_strategy"] = 'auto_growth'
|
||||
|
||||
import paddle
|
||||
from paddle.jit import to_static
|
||||
|
||||
from ppocr.data import create_operators, transform
|
||||
from ppocr.modeling.architectures import build_model
|
||||
from ppocr.postprocess import build_post_process
|
||||
from ppocr.utils.save_load import init_model
|
||||
from ppocr.utils.utility import get_image_file_list
|
||||
import tools.program as program
|
||||
import cv2
|
||||
|
||||
def main(config, device, logger, vdl_writer):
|
||||
global_config = config['Global']
|
||||
|
||||
# build post process
|
||||
post_process_class = build_post_process(config['PostProcess'],
|
||||
global_config)
|
||||
|
||||
# build model
|
||||
if hasattr(post_process_class, 'character'):
|
||||
config['Architecture']["Head"]['out_channels'] = len(
|
||||
getattr(post_process_class, 'character'))
|
||||
|
||||
model = build_model(config['Architecture'])
|
||||
|
||||
init_model(config, model, logger)
|
||||
|
||||
# create data ops
|
||||
transforms = []
|
||||
use_padding = False
|
||||
for op in config['Eval']['dataset']['transforms']:
|
||||
op_name = list(op)[0]
|
||||
if 'Label' in op_name:
|
||||
continue
|
||||
if op_name == 'KeepKeys':
|
||||
op[op_name]['keep_keys'] = ['image']
|
||||
if op_name == "ResizeTableImage":
|
||||
use_padding = True
|
||||
padding_max_len = op['ResizeTableImage']['max_len']
|
||||
transforms.append(op)
|
||||
|
||||
global_config['infer_mode'] = True
|
||||
ops = create_operators(transforms, global_config)
|
||||
|
||||
model.eval()
|
||||
for file in get_image_file_list(config['Global']['infer_img']):
|
||||
logger.info("infer_img: {}".format(file))
|
||||
with open(file, 'rb') as f:
|
||||
img = f.read()
|
||||
data = {'image': img}
|
||||
batch = transform(data, ops)
|
||||
images = np.expand_dims(batch[0], axis=0)
|
||||
images = paddle.to_tensor(images)
|
||||
preds = model(images)
|
||||
post_result = post_process_class(preds)
|
||||
res_html_code = post_result['res_html_code']
|
||||
res_loc = post_result['res_loc']
|
||||
img = cv2.imread(file)
|
||||
imgh, imgw = img.shape[0:2]
|
||||
res_loc_final = []
|
||||
for rno in range(len(res_loc[0])):
|
||||
x0, y0, x1, y1 = res_loc[0][rno]
|
||||
left = max(int(imgw * x0), 0)
|
||||
top = max(int(imgh * y0), 0)
|
||||
right = min(int(imgw * x1), imgw - 1)
|
||||
bottom = min(int(imgh * y1), imgh - 1)
|
||||
cv2.rectangle(img, (left, top), (right, bottom), (0, 0, 255), 2)
|
||||
res_loc_final.append([left, top, right, bottom])
|
||||
res_loc_str = json.dumps(res_loc_final)
|
||||
logger.info("result: {}, {}".format(res_html_code, res_loc_final))
|
||||
logger.info("success!")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
config, device, logger, vdl_writer = program.preprocess()
|
||||
main(config, device, logger, vdl_writer)
|
||||
|
|
@ -1,4 +1,4 @@
|
|||
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -186,6 +186,7 @@ def train(config,
|
|||
model.train()
|
||||
|
||||
use_srn = config['Architecture']['algorithm'] == "SRN"
|
||||
model_type = config['Architecture']['model_type']
|
||||
|
||||
if 'start_epoch' in best_model_dict:
|
||||
start_epoch = best_model_dict['start_epoch']
|
||||
|
@ -208,9 +209,9 @@ def train(config,
|
|||
lr = optimizer.get_lr()
|
||||
images = batch[0]
|
||||
if use_srn:
|
||||
others = batch[-4:]
|
||||
preds = model(images, others)
|
||||
model_average = True
|
||||
if use_srn or model_type == 'table':
|
||||
preds = model(images, data=batch[1:])
|
||||
else:
|
||||
preds = model(images)
|
||||
loss = loss_class(preds, batch)
|
||||
|
@ -232,8 +233,11 @@ def train(config,
|
|||
|
||||
if cal_metric_during_train: # only rec and cls need
|
||||
batch = [item.numpy() for item in batch]
|
||||
post_result = post_process_class(preds, batch[1])
|
||||
eval_class(post_result, batch)
|
||||
if model_type == 'table':
|
||||
eval_class(preds, batch)
|
||||
else:
|
||||
post_result = post_process_class(preds, batch[1])
|
||||
eval_class(post_result, batch)
|
||||
metric = eval_class.get_metric()
|
||||
train_stats.update(metric)
|
||||
|
||||
|
@ -269,6 +273,7 @@ def train(config,
|
|||
valid_dataloader,
|
||||
post_process_class,
|
||||
eval_class,
|
||||
model_type,
|
||||
use_srn=use_srn)
|
||||
cur_metric_str = 'cur metric, {}'.format(', '.join(
|
||||
['{}: {}'.format(k, v) for k, v in cur_metric.items()]))
|
||||
|
@ -336,7 +341,11 @@ def train(config,
|
|||
return
|
||||
|
||||
|
||||
def eval(model, valid_dataloader, post_process_class, eval_class,
|
||||
def eval(model,
|
||||
valid_dataloader,
|
||||
post_process_class,
|
||||
eval_class,
|
||||
model_type,
|
||||
use_srn=False):
|
||||
model.eval()
|
||||
with paddle.no_grad():
|
||||
|
@ -350,19 +359,19 @@ def eval(model, valid_dataloader, post_process_class, eval_class,
|
|||
break
|
||||
images = batch[0]
|
||||
start = time.time()
|
||||
|
||||
if use_srn:
|
||||
others = batch[-4:]
|
||||
preds = model(images, others)
|
||||
if use_srn or model_type == 'table':
|
||||
preds = model(images, data=batch[1:])
|
||||
else:
|
||||
preds = model(images)
|
||||
|
||||
batch = [item.numpy() for item in batch]
|
||||
# Obtain usable results from post-processing methods
|
||||
post_result = post_process_class(preds, batch[1])
|
||||
total_time += time.time() - start
|
||||
# Evaluate the results of the current batch
|
||||
eval_class(post_result, batch)
|
||||
if model_type == 'table':
|
||||
eval_class(preds, batch)
|
||||
else:
|
||||
post_result = post_process_class(preds, batch[1])
|
||||
eval_class(post_result, batch)
|
||||
pbar.update(1)
|
||||
total_frame += len(images)
|
||||
# Get final metric,eg. acc or hmean
|
||||
|
@ -386,7 +395,7 @@ def preprocess(is_train=False):
|
|||
alg = config['Architecture']['algorithm']
|
||||
assert alg in [
|
||||
'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN',
|
||||
'CLS', 'PGNet', 'Distillation'
|
||||
'CLS', 'PGNet', 'Distillation', 'TableAttn'
|
||||
]
|
||||
|
||||
device = 'gpu:{}'.format(dist.ParallelEnv().dev_id) if use_gpu else 'cpu'
|
||||
|
|
|
@ -35,7 +35,7 @@ from ppocr.losses import build_loss
|
|||
from ppocr.optimizer import build_optimizer
|
||||
from ppocr.postprocess import build_post_process
|
||||
from ppocr.metrics import build_metric
|
||||
from ppocr.utils.save_load import init_model
|
||||
from ppocr.utils.save_load import init_model, load_dygraph_params
|
||||
import tools.program as program
|
||||
|
||||
dist.get_world_size()
|
||||
|
@ -97,7 +97,7 @@ def main(config, device, logger, vdl_writer):
|
|||
# build metric
|
||||
eval_class = build_metric(config['Metric'])
|
||||
# load pretrain model
|
||||
pre_best_model_dict = init_model(config, model, optimizer)
|
||||
pre_best_model_dict = load_dygraph_params(config, model, logger, optimizer)
|
||||
|
||||
logger.info('train dataloader has {} iters'.format(len(train_dataloader)))
|
||||
if valid_dataloader is not None:
|
||||
|
|