fix error
This commit is contained in:
parent
051fe64a0d
commit
9711111270
|
@ -3,7 +3,7 @@ Global:
|
|||
epoch_num: 600
|
||||
log_smooth_window: 20
|
||||
print_batch_step: 10
|
||||
save_model_dir: ./output/pg_r50_vd_tt/
|
||||
save_model_dir: ./output/pgnet_r50_vd_totaltext/
|
||||
save_epoch_step: 10
|
||||
# evaluation is run every 0 iterationss after the 1000th iteration
|
||||
eval_batch_step: [ 0, 1000 ]
|
||||
|
@ -18,7 +18,11 @@ Global:
|
|||
save_inference_dir:
|
||||
use_visualdl: False
|
||||
infer_img:
|
||||
save_res_path: ./output/pg_r50_vd_tt/predicts_pg.txt
|
||||
valid_set: totaltext #two mode: totaltext valid curved words, partvgg valid non-curved words
|
||||
save_res_path: ./output/pgnet_r50_vd_totaltext/predicts_pgnet.txt
|
||||
character_dict_path: ppocr/utils/pgnet_dict.txt
|
||||
character_type: EN
|
||||
max_text_length: 50
|
||||
|
||||
Architecture:
|
||||
model_type: e2e
|
||||
|
@ -51,30 +55,26 @@ Optimizer:
|
|||
PostProcess:
|
||||
name: PGPostProcess
|
||||
score_thresh: 0.8
|
||||
cover_thresh: 0.1
|
||||
nms_thresh: 0.2
|
||||
|
||||
Metric:
|
||||
name: E2EMetric
|
||||
Lexicon_Table: [ '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z' ]
|
||||
character_dict_path: ppocr/utils/pgnet_dict.txt
|
||||
main_indicator: f_score_e2e
|
||||
|
||||
Train:
|
||||
dataset:
|
||||
name: PGDateSet
|
||||
label_file_list: [./train_data/total_text/train/]
|
||||
name: PGDataSet
|
||||
label_file_list: [.././train_data/total_text/train/]
|
||||
ratio_list: [1.0]
|
||||
data_format: icdar
|
||||
data_format: icdar #two data format: icdar/textnet
|
||||
transforms:
|
||||
- DecodeImage: # load image
|
||||
img_mode: BGR
|
||||
channel_first: False
|
||||
- PGProcessTrain:
|
||||
batch_size: 14
|
||||
batch_size: 14 # same as loader: batch_size_per_card
|
||||
min_crop_size: 24
|
||||
min_text_size: 4
|
||||
max_text_size: 512
|
||||
Lexicon_Table: [ '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z' ]
|
||||
- KeepKeys:
|
||||
keep_keys: [ 'images', 'tcl_maps', 'tcl_label_maps', 'border_maps','direction_maps', 'training_masks', 'label_list', 'pos_list', 'pos_mask' ] # dataloader will return list in this order
|
||||
loader:
|
||||
|
@ -93,10 +93,7 @@ Eval:
|
|||
img_mode: BGR
|
||||
channel_first: False
|
||||
- E2ELabelEncode:
|
||||
Lexicon_Table: [ '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z' ]
|
||||
max_len: 50
|
||||
- E2EResizeForTest:
|
||||
valid_set: totaltext
|
||||
max_side_len: 768
|
||||
- NormalizeImage:
|
||||
scale: 1./255.
|
||||
|
|
|
@ -12,7 +12,8 @@ inference 模型(`paddle.jit.save`保存的模型)
|
|||
- [一、训练模型转inference模型](#训练模型转inference模型)
|
||||
- [检测模型转inference模型](#检测模型转inference模型)
|
||||
- [识别模型转inference模型](#识别模型转inference模型)
|
||||
- [方向分类模型转inference模型](#方向分类模型转inference模型)
|
||||
- [方向分类模型转inference模型](#方向分类模型转inference模型)
|
||||
- [端到端模型转inference模型](#端到端模型转inference模型)
|
||||
|
||||
- [二、文本检测模型推理](#文本检测模型推理)
|
||||
- [1. 超轻量中文检测模型推理](#超轻量中文检测模型推理)
|
||||
|
@ -27,10 +28,13 @@ inference 模型(`paddle.jit.save`保存的模型)
|
|||
- [4. 自定义文本识别字典的推理](#自定义文本识别字典的推理)
|
||||
- [5. 多语言模型的推理](#多语言模型的推理)
|
||||
|
||||
- [四、方向分类模型推理](#方向识别模型推理)
|
||||
- [四、端到端模型推理](#端到端模型推理)
|
||||
- [1. PGNet端到端模型推理](#SAST文本检测模型推理)
|
||||
|
||||
- [五、方向分类模型推理](#方向识别模型推理)
|
||||
- [1. 方向分类模型推理](#方向分类模型推理)
|
||||
|
||||
- [五、文本检测、方向分类和文字识别串联推理](#文本检测、方向分类和文字识别串联推理)
|
||||
- [六、文本检测、方向分类和文字识别串联推理](#文本检测、方向分类和文字识别串联推理)
|
||||
- [1. 超轻量中文OCR模型推理](#超轻量中文OCR模型推理)
|
||||
- [2. 其他模型推理](#其他模型推理)
|
||||
|
||||
|
@ -118,6 +122,32 @@ python3 tools/export_model.py -c configs/cls/cls_mv3.yml -o Global.pretrained_mo
|
|||
├── inference.pdiparams.info # 分类inference模型的参数信息,可忽略
|
||||
└── inference.pdmodel # 分类inference模型的program文件
|
||||
```
|
||||
<a name="端到端模型转inference模型"></a>
|
||||
### 端到端模型转inference模型
|
||||
|
||||
下载端到端模型:
|
||||
```
|
||||
wget -P ./ch_lite/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_cls_train.tar && tar xf ./ch_lite/ch_ppocr_mobile_v2.0_cls_train.tar -C ./ch_lite/
|
||||
```
|
||||
|
||||
端到端模型转inference模型与检测的方式相同,如下:
|
||||
```
|
||||
# -c 后面设置训练算法的yml配置文件
|
||||
# -o 配置可选参数
|
||||
# Global.pretrained_model 参数设置待转换的训练模型地址,不用添加文件后缀 .pdmodel,.pdopt或.pdparams。
|
||||
# Global.load_static_weights 参数需要设置为 False。
|
||||
# Global.save_inference_dir参数设置转换的模型将保存的地址。
|
||||
|
||||
python3 tools/export_model.py -c configs/e2e/e2e_r50_vd_pg.yml -o Global.pretrained_model=./ch_lite/ch_ppocr_mobile_v2.0_cls_train/best_accuracy Global.load_static_weights=False Global.save_inference_dir=./inference/e2e/
|
||||
```
|
||||
|
||||
转换成功后,在目录下有三个文件:
|
||||
```
|
||||
/inference/e2e/
|
||||
├── inference.pdiparams # 分类inference模型的参数文件
|
||||
├── inference.pdiparams.info # 分类inference模型的参数信息,可忽略
|
||||
└── inference.pdmodel # 分类inference模型的program文件
|
||||
```
|
||||
|
||||
<a name="文本检测模型推理"></a>
|
||||
## 二、文本检测模型推理
|
||||
|
@ -332,8 +362,45 @@ python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words/korean/1.jpg" -
|
|||
Predicts of ./doc/imgs_words/korean/1.jpg:('바탕으로', 0.9948904)
|
||||
```
|
||||
|
||||
<a name="端到端模型推理"></a>
|
||||
## 四、端到端模型推理
|
||||
|
||||
端到端模型推理,默认使用PGNet模型的配置参数。当不使用PGNet模型时,在推理时,需要通过传入相应的参数进行算法适配,细节参考下文。
|
||||
<a name="SAST文本检测模型推理"></a>
|
||||
### 1. PGNet端到端模型推理
|
||||
#### (1). 四边形文本检测模型(ICDAR2015)
|
||||
首先将PGNet端到端训练过程中保存的模型,转换成inference model。以基于Resnet50_vd骨干网络,在ICDAR2015英文数据集训练的模型为例([模型下载地址](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_sast_icdar15_v2.0_train.tar)),可以使用如下命令进行转换:
|
||||
```
|
||||
python3 tools/export_model.py -c configs/e2e/e2e_r50_vd_pg.yml -o Global.pretrained_model=./det_r50_vd_sast_icdar15_v2.0_train/best_accuracy Global.load_static_weights=False Global.save_inference_dir=./inference/e2e
|
||||
```
|
||||
**PGNet端到端模型推理,需要设置参数`--e2e_algorithm="PGNet"`**,可以执行如下命令:
|
||||
```
|
||||
python3 tools/infer/predict_e2e.py --e2e_algorithm="PGNet" --image_dir="./doc/imgs_en/img_10.jpg" --e2e_model_dir="./inference/e2e_pgnet_ic15/"
|
||||
```
|
||||
可视化文本检测结果默认保存到`./inference_results`文件夹里面,结果文件的名称前缀为'e2e_res'。结果示例如下:
|
||||
|
||||
![](../imgs_results/det_res_img_10_sast.jpg)
|
||||
|
||||
#### (2). 弯曲文本检测模型(Total-Text)
|
||||
首先将PGNet端到端训练过程中保存的模型,转换成inference model。以基于Resnet50_vd骨干网络,在Total-Text英文数据集训练的模型为例([模型下载地址](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_sast_totaltext_v2.0_train.tar)),可以使用如下命令进行转换:
|
||||
|
||||
```
|
||||
python3 tools/export_model.py -c configs/e2e/e2e_r50_vd_pg.yml -o Global.pretrained_model=./det_r50_vd_sast_totaltext_v2.0_train/best_accuracy Global.load_static_weights=False Global.save_inference_dir=./inference/e2e_pgnet_tt
|
||||
```
|
||||
|
||||
**PGNet端到端模型推理,需要设置参数`--e2e_algorithm="PGNet"`,同时,还需要增加参数`--e2e_pgnet_polygon=True`,**可以执行如下命令:
|
||||
```
|
||||
python3 tools/infer/predict_e2e.py --e2e_algorithm="PGNet" --image_dir="./doc/imgs_en/img623.jpg" --e2e_model_dir="./inference/e2e_pgnet_tt/" --e2e_pgnet_polygon=True
|
||||
```
|
||||
可视化文本端到端结果默认保存到`./inference_results`文件夹里面,结果文件的名称前缀为'e2e_res'。结果示例如下:
|
||||
|
||||
![](../imgs_results/e2e_res_img623_pg.jpg)
|
||||
|
||||
**注意**:本代码库中,SAST后处理Locality-Aware NMS有python和c++两种版本,c++版速度明显快于python版。由于c++版本nms编译版本问题,只有python3.5环境下会调用c++版nms,其他情况将调用python版nms。
|
||||
|
||||
|
||||
<a name="方向分类模型推理"></a>
|
||||
## 四、方向分类模型推理
|
||||
## 五、方向分类模型推理
|
||||
|
||||
下面将介绍方向分类模型推理。
|
||||
|
||||
|
@ -358,7 +425,7 @@ Predicts of ./doc/imgs_words/ch/word_4.jpg:['0', 0.9999982]
|
|||
```
|
||||
|
||||
<a name="文本检测、方向分类和文字识别串联推理"></a>
|
||||
## 五、文本检测、方向分类和文字识别串联推理
|
||||
## 六、文本检测、方向分类和文字识别串联推理
|
||||
<a name="超轻量中文OCR模型推理"></a>
|
||||
### 1. 超轻量中文OCR模型推理
|
||||
|
||||
|
|
|
@ -73,14 +73,14 @@ def build_dataloader(config, mode, device, logger, seed=None):
|
|||
else:
|
||||
use_shared_memory = True
|
||||
if mode == "Train":
|
||||
#Distribute data to multiple cards
|
||||
# Distribute data to multiple cards
|
||||
batch_sampler = DistributedBatchSampler(
|
||||
dataset=dataset,
|
||||
batch_size=batch_size,
|
||||
shuffle=shuffle,
|
||||
drop_last=drop_last)
|
||||
else:
|
||||
#Distribute data to single card
|
||||
# Distribute data to single card
|
||||
batch_sampler = BatchSampler(
|
||||
dataset=dataset,
|
||||
batch_size=batch_size,
|
||||
|
|
|
@ -34,28 +34,6 @@ class ClsLabelEncode(object):
|
|||
return data
|
||||
|
||||
|
||||
class E2ELabelEncode(object):
|
||||
def __init__(self, Lexicon_Table, max_len, **kwargs):
|
||||
self.Lexicon_Table = Lexicon_Table
|
||||
self.max_len = max_len
|
||||
self.pad_num = len(self.Lexicon_Table)
|
||||
|
||||
def __call__(self, data):
|
||||
text_label_index_list, temp_text = [], []
|
||||
texts = data['strs']
|
||||
for text in texts:
|
||||
text = text.upper()
|
||||
temp_text = []
|
||||
for c_ in text:
|
||||
if c_ in self.Lexicon_Table:
|
||||
temp_text.append(self.Lexicon_Table.index(c_))
|
||||
temp_text = temp_text + [self.pad_num] * (self.max_len -
|
||||
len(temp_text))
|
||||
text_label_index_list.append(temp_text)
|
||||
data['strs'] = np.array(text_label_index_list)
|
||||
return data
|
||||
|
||||
|
||||
class DetLabelEncode(object):
|
||||
def __init__(self, **kwargs):
|
||||
pass
|
||||
|
@ -209,6 +187,32 @@ class CTCLabelEncode(BaseRecLabelEncode):
|
|||
return dict_character
|
||||
|
||||
|
||||
class E2ELabelEncode(BaseRecLabelEncode):
|
||||
def __init__(self,
|
||||
max_text_length,
|
||||
character_dict_path=None,
|
||||
character_type='EN',
|
||||
use_space_char=False,
|
||||
**kwargs):
|
||||
super(E2ELabelEncode,
|
||||
self).__init__(max_text_length, character_dict_path,
|
||||
character_type, use_space_char)
|
||||
|
||||
def __call__(self, data):
|
||||
texts = data['strs']
|
||||
temp_texts = []
|
||||
for text in texts:
|
||||
text = text.upper()
|
||||
text = self.encode(text)
|
||||
if text is None:
|
||||
return None
|
||||
text = text + [36] * (self.max_text_len - len(text)
|
||||
) # use 36 to pad
|
||||
temp_texts.append(text)
|
||||
data['strs'] = np.array(temp_texts)
|
||||
return data
|
||||
|
||||
|
||||
class AttnLabelEncode(BaseRecLabelEncode):
|
||||
""" Convert between text-label and text-index """
|
||||
|
||||
|
|
|
@ -21,6 +21,7 @@ __all__ = ['PGProcessTrain']
|
|||
|
||||
class PGProcessTrain(object):
|
||||
def __init__(self,
|
||||
character_dict_path,
|
||||
batch_size=14,
|
||||
min_crop_size=24,
|
||||
min_text_size=10,
|
||||
|
@ -30,13 +31,19 @@ class PGProcessTrain(object):
|
|||
self.min_crop_size = min_crop_size
|
||||
self.min_text_size = min_text_size
|
||||
self.max_text_size = max_text_size
|
||||
self.Lexicon_Table = [
|
||||
'0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'A', 'B', 'C',
|
||||
'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P',
|
||||
'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z'
|
||||
]
|
||||
self.Lexicon_Table = self.get_dict(character_dict_path)
|
||||
self.img_id = 0
|
||||
|
||||
def get_dict(self, character_dict_path):
|
||||
character_str = ""
|
||||
with open(character_dict_path, "rb") as fin:
|
||||
lines = fin.readlines()
|
||||
for line in lines:
|
||||
line = line.decode('utf-8').strip("\n").strip("\r\n")
|
||||
character_str += line
|
||||
dict_character = list(character_str)
|
||||
return dict_character
|
||||
|
||||
def quad_area(self, poly):
|
||||
"""
|
||||
compute area of a polygon
|
||||
|
@ -853,7 +860,7 @@ class PGProcessTrain(object):
|
|||
for i in range(len(label_list)):
|
||||
label_list[i] = np.array(label_list[i])
|
||||
|
||||
if len(pos_list) <= 0 or len(pos_list) > 30:
|
||||
if len(pos_list) <= 0 or len(pos_list) > 30: #一张图片中最多存在30行文本
|
||||
return None
|
||||
for __ in range(30 - len(pos_list), 0, -1):
|
||||
pos_list.append(pos_list_temp)
|
||||
|
|
|
@ -19,11 +19,15 @@ from __future__ import print_function
|
|||
__all__ = ['E2EMetric']
|
||||
|
||||
from ppocr.utils.e2e_metric.Deteval import *
|
||||
from ppocr.utils.e2e_utils.extract_textpoint import *
|
||||
|
||||
|
||||
class E2EMetric(object):
|
||||
def __init__(self, Lexicon_Table, main_indicator='f_score_e2e', **kwargs):
|
||||
self.label_list = Lexicon_Table
|
||||
def __init__(self,
|
||||
character_dict_path,
|
||||
main_indicator='f_score_e2e',
|
||||
**kwargs):
|
||||
self.label_list = get_dict(character_dict_path)
|
||||
self.max_index = len(self.label_list)
|
||||
self.main_indicator = main_indicator
|
||||
self.reset()
|
||||
|
|
|
@ -228,11 +228,11 @@ class PGHead(nn.Layer):
|
|||
f_score = self.conv1(f_score)
|
||||
f_score = F.sigmoid(f_score)
|
||||
|
||||
# f_boder
|
||||
f_boder = self.conv_f_boder1(x)
|
||||
f_boder = self.conv_f_boder2(f_boder)
|
||||
f_boder = self.conv_f_boder3(f_boder)
|
||||
f_boder = self.conv2(f_boder)
|
||||
# f_border
|
||||
f_border = self.conv_f_boder1(x)
|
||||
f_border = self.conv_f_boder2(f_border)
|
||||
f_border = self.conv_f_boder3(f_border)
|
||||
f_border = self.conv2(f_border)
|
||||
|
||||
f_char = self.conv_f_char1(x)
|
||||
f_char = self.conv_f_char2(f_char)
|
||||
|
@ -246,4 +246,9 @@ class PGHead(nn.Layer):
|
|||
f_direction = self.conv_f_direc3(f_direction)
|
||||
f_direction = self.conv4(f_direction)
|
||||
|
||||
return f_score, f_boder, f_direction, f_char
|
||||
predicts = {}
|
||||
predicts['f_score'] = f_score
|
||||
predicts['f_border'] = f_border
|
||||
predicts['f_char'] = f_char
|
||||
predicts['f_direction'] = f_direction
|
||||
return predicts
|
||||
|
|
|
@ -30,30 +30,14 @@ import paddle
|
|||
|
||||
class PGPostProcess(object):
|
||||
"""
|
||||
The post process for SAST.
|
||||
The post process for PGNet.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
score_thresh=0.5,
|
||||
nms_thresh=0.2,
|
||||
sample_pts_num=2,
|
||||
shrink_ratio_of_width=0.3,
|
||||
expand_scale=1.0,
|
||||
tcl_map_thresh=0.5,
|
||||
**kwargs):
|
||||
self.result_path = ""
|
||||
self.valid_set = 'totaltext'
|
||||
self.Lexicon_Table = [
|
||||
'0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'A', 'B', 'C',
|
||||
'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P',
|
||||
'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z'
|
||||
]
|
||||
def __init__(self, character_dict_path, valid_set, score_thresh, **kwargs):
|
||||
|
||||
self.Lexicon_Table = get_dict(character_dict_path)
|
||||
self.valid_set = valid_set
|
||||
self.score_thresh = score_thresh
|
||||
self.nms_thresh = nms_thresh
|
||||
self.sample_pts_num = sample_pts_num
|
||||
self.shrink_ratio_of_width = shrink_ratio_of_width
|
||||
self.expand_scale = expand_scale
|
||||
self.tcl_map_thresh = tcl_map_thresh
|
||||
|
||||
# c++ la-nms is faster, but only support python 3.5
|
||||
self.is_python35 = False
|
||||
|
@ -61,16 +45,23 @@ class PGPostProcess(object):
|
|||
self.is_python35 = True
|
||||
|
||||
def __call__(self, outs_dict, shape_list):
|
||||
p_score, p_border, p_direction, p_char = outs_dict[:4]
|
||||
p_score = p_score[0].numpy()
|
||||
p_border = p_border[0].numpy()
|
||||
p_direction = p_direction[0].numpy()
|
||||
p_char = p_char[0].numpy()
|
||||
src_h, src_w, ratio_h, ratio_w = shape_list[0]
|
||||
if self.valid_set != 'totaltext':
|
||||
is_curved = False
|
||||
p_score = outs_dict['f_score']
|
||||
p_border = outs_dict['f_border']
|
||||
p_char = outs_dict['f_char']
|
||||
p_direction = outs_dict['f_direction']
|
||||
if isinstance(p_score, paddle.Tensor):
|
||||
p_score = p_score[0].numpy()
|
||||
p_border = p_border[0].numpy()
|
||||
p_direction = p_direction[0].numpy()
|
||||
p_char = p_char[0].numpy()
|
||||
else:
|
||||
is_curved = True
|
||||
p_score = p_score[0]
|
||||
p_border = p_border[0]
|
||||
p_direction = p_direction[0]
|
||||
p_char = p_char[0]
|
||||
|
||||
src_h, src_w, ratio_h, ratio_w = shape_list[0]
|
||||
is_curved = self.valid_set == "totaltext"
|
||||
instance_yxs_list = generate_pivot_list(
|
||||
p_score,
|
||||
p_char,
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
|
|
@ -24,6 +24,17 @@ from itertools import groupby
|
|||
from skimage.morphology._skeletonize import thin
|
||||
|
||||
|
||||
def get_dict(character_dict_path):
|
||||
character_str = ""
|
||||
with open(character_dict_path, "rb") as fin:
|
||||
lines = fin.readlines()
|
||||
for line in lines:
|
||||
line = line.decode('utf-8').strip("\n").strip("\r\n")
|
||||
character_str += line
|
||||
dict_character = list(character_str)
|
||||
return dict_character
|
||||
|
||||
|
||||
def softmax(logits):
|
||||
"""
|
||||
logits: N x d
|
||||
|
@ -164,7 +175,6 @@ def sort_and_expand_with_direction(pos_list, f_direction):
|
|||
h, w, _ = f_direction.shape
|
||||
sorted_list, point_direction = sort_with_direction(pos_list, f_direction)
|
||||
|
||||
# expand along
|
||||
point_num = len(sorted_list)
|
||||
sub_direction_len = max(point_num // 3, 2)
|
||||
left_direction = point_direction[:sub_direction_len, :]
|
||||
|
@ -207,7 +217,6 @@ def sort_and_expand_with_direction_v2(pos_list, f_direction, binary_tcl_map):
|
|||
h, w, _ = f_direction.shape
|
||||
sorted_list, point_direction = sort_with_direction(pos_list, f_direction)
|
||||
|
||||
# expand along
|
||||
point_num = len(sorted_list)
|
||||
sub_direction_len = max(point_num // 3, 2)
|
||||
left_direction = point_direction[:sub_direction_len, :]
|
||||
|
@ -268,7 +277,6 @@ def generate_pivot_list_curved(p_score,
|
|||
instance_count, instance_label_map = cv2.connectedComponents(
|
||||
skeleton_map.astype(np.uint8), connectivity=8)
|
||||
|
||||
# get TCL Instance
|
||||
all_pos_yxs = []
|
||||
center_pos_yxs = []
|
||||
end_points_yxs = []
|
||||
|
@ -279,7 +287,6 @@ def generate_pivot_list_curved(p_score,
|
|||
ys, xs = np.where(instance_label_map == instance_id)
|
||||
pos_list = list(zip(ys, xs))
|
||||
|
||||
### FIX-ME, eliminate outlier
|
||||
if len(pos_list) < 3:
|
||||
continue
|
||||
|
||||
|
@ -290,7 +297,6 @@ def generate_pivot_list_curved(p_score,
|
|||
pos_list_sorted, _ = sort_with_direction(pos_list, f_direction)
|
||||
all_pos_yxs.append(pos_list_sorted)
|
||||
|
||||
# use decoder to filter backgroud points.
|
||||
p_char_maps = p_char_maps.transpose([1, 2, 0])
|
||||
decode_res = ctc_decoder_for_image(
|
||||
all_pos_yxs, logits_map=p_char_maps, keep_blank_in_idxs=True)
|
||||
|
@ -335,11 +341,9 @@ def generate_pivot_list_horizontal(p_score,
|
|||
ys, xs = np.where(instance_label_map == instance_id)
|
||||
pos_list = list(zip(ys, xs))
|
||||
|
||||
### FIX-ME, eliminate outlier
|
||||
if len(pos_list) < 5:
|
||||
continue
|
||||
|
||||
# add rule here
|
||||
main_direction = extract_main_direction(pos_list,
|
||||
f_direction) # y x
|
||||
reference_directin = np.array([0, 1]).reshape([-1, 2]) # y x
|
||||
|
@ -370,7 +374,6 @@ def generate_pivot_list_horizontal(p_score,
|
|||
f_direction)
|
||||
all_pos_yxs.append(pos_list_sorted)
|
||||
|
||||
# use decoder to filter backgroud points.
|
||||
p_char_maps = p_char_maps.transpose([1, 2, 0])
|
||||
decode_res = ctc_decoder_for_image(
|
||||
all_pos_yxs, logits_map=p_char_maps, keep_blank_in_idxs=True)
|
||||
|
@ -417,7 +420,6 @@ def generate_pivot_list(p_score,
|
|||
image_id=image_id)
|
||||
|
||||
|
||||
# for refine module
|
||||
def extract_main_direction(pos_list, f_direction):
|
||||
"""
|
||||
f_direction: h x w x 2
|
||||
|
@ -504,14 +506,12 @@ def generate_pivot_list_tt_inference(p_score,
|
|||
instance_count, instance_label_map = cv2.connectedComponents(
|
||||
skeleton_map.astype(np.uint8), connectivity=8)
|
||||
|
||||
# get TCL Instance
|
||||
all_pos_yxs = []
|
||||
if instance_count > 0:
|
||||
for instance_id in range(1, instance_count):
|
||||
pos_list = []
|
||||
ys, xs = np.where(instance_label_map == instance_id)
|
||||
pos_list = list(zip(ys, xs))
|
||||
### FIX-ME, eliminate outlier
|
||||
if len(pos_list) < 3:
|
||||
continue
|
||||
pos_list_sorted = sort_and_expand_with_direction_v2(
|
||||
|
|
|
@ -28,7 +28,6 @@ def resize_image(im, max_side_len=512):
|
|||
resize_w = w
|
||||
resize_h = h
|
||||
|
||||
# Fix the longer side
|
||||
if resize_h > resize_w:
|
||||
ratio = float(max_side_len) / resize_h
|
||||
else:
|
||||
|
@ -50,13 +49,11 @@ def resize_image(im, max_side_len=512):
|
|||
def resize_image_min(im, max_side_len=512):
|
||||
"""
|
||||
"""
|
||||
# print('--> Using resize_image_min')
|
||||
h, w, _ = im.shape
|
||||
|
||||
resize_w = w
|
||||
resize_h = h
|
||||
|
||||
# Fix the longer side
|
||||
if resize_h < resize_w:
|
||||
ratio = float(max_side_len) / resize_h
|
||||
else:
|
||||
|
@ -84,12 +81,7 @@ def resize_image_for_totaltext(im, max_side_len=512):
|
|||
ratio = 1.25
|
||||
if h * ratio > max_side_len:
|
||||
ratio = float(max_side_len) / resize_h
|
||||
# Fix the longer side
|
||||
# if resize_h > resize_w:
|
||||
# ratio = float(max_side_len) / resize_h
|
||||
# else:
|
||||
# ratio = float(max_side_len) / resize_w
|
||||
###
|
||||
|
||||
resize_h = int(resize_h * ratio)
|
||||
resize_w = int(resize_w * ratio)
|
||||
|
||||
|
@ -114,7 +106,6 @@ def point_pair2poly(point_pair_list):
|
|||
pair_info = (pair_length_list.max(), pair_length_list.min(),
|
||||
pair_length_list.mean())
|
||||
|
||||
# constract poly
|
||||
point_num = len(point_pair_list) * 2
|
||||
point_list = [0] * point_num
|
||||
for idx, point_pair in enumerate(point_pair_list):
|
||||
|
|
|
@ -0,0 +1,36 @@
|
|||
0
|
||||
1
|
||||
2
|
||||
3
|
||||
4
|
||||
5
|
||||
6
|
||||
7
|
||||
8
|
||||
9
|
||||
A
|
||||
B
|
||||
C
|
||||
D
|
||||
E
|
||||
F
|
||||
G
|
||||
H
|
||||
I
|
||||
J
|
||||
K
|
||||
L
|
||||
M
|
||||
N
|
||||
O
|
||||
P
|
||||
Q
|
||||
R
|
||||
S
|
||||
T
|
||||
U
|
||||
V
|
||||
W
|
||||
X
|
||||
Y
|
||||
Z
|
|
@ -0,0 +1,168 @@
|
|||
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import os
|
||||
import sys
|
||||
|
||||
__dir__ = os.path.dirname(os.path.abspath(__file__))
|
||||
sys.path.append(__dir__)
|
||||
sys.path.append(os.path.abspath(os.path.join(__dir__, '../..')))
|
||||
|
||||
os.environ["FLAGS_allocator_strategy"] = 'auto_growth'
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import time
|
||||
import sys
|
||||
|
||||
import tools.infer.utility as utility
|
||||
from ppocr.utils.logging import get_logger
|
||||
from ppocr.utils.utility import get_image_file_list, check_and_read_gif
|
||||
from ppocr.data import create_operators, transform
|
||||
from ppocr.postprocess import build_post_process
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
|
||||
class TextE2e(object):
|
||||
def __init__(self, args):
|
||||
self.args = args
|
||||
self.e2e_algorithm = args.e2e_algorithm
|
||||
pre_process_list = [{
|
||||
'E2EResizeForTest': {
|
||||
'max_side_len': 768,
|
||||
'valid_set': 'totaltext'
|
||||
}
|
||||
}, {
|
||||
'NormalizeImage': {
|
||||
'std': [0.229, 0.224, 0.225],
|
||||
'mean': [0.485, 0.456, 0.406],
|
||||
'scale': '1./255.',
|
||||
'order': 'hwc'
|
||||
}
|
||||
}, {
|
||||
'ToCHWImage': None
|
||||
}, {
|
||||
'KeepKeys': {
|
||||
'keep_keys': ['image', 'shape']
|
||||
}
|
||||
}]
|
||||
postprocess_params = {}
|
||||
if self.e2e_algorithm == "PGNet":
|
||||
pre_process_list[0] = {
|
||||
'E2EResizeForTest': {
|
||||
'max_side_len': args.e2e_limit_side_len,
|
||||
'valid_set': 'totaltext'
|
||||
}
|
||||
}
|
||||
postprocess_params['name'] = 'PGPostProcess'
|
||||
postprocess_params["score_thresh"] = args.e2e_pgnet_score_thresh
|
||||
postprocess_params["character_dict_path"] = args.e2e_char_dict_path
|
||||
postprocess_params["valid_set"] = args.e2e_pgnet_valid_set
|
||||
self.e2e_pgnet_polygon = args.e2e_pgnet_polygon
|
||||
if self.e2e_pgnet_polygon:
|
||||
postprocess_params["expand_scale"] = 1.2
|
||||
postprocess_params["shrink_ratio_of_width"] = 0.2
|
||||
else:
|
||||
postprocess_params["expand_scale"] = 1.0
|
||||
postprocess_params["shrink_ratio_of_width"] = 0.3
|
||||
else:
|
||||
logger.info("unknown e2e_algorithm:{}".format(self.e2e_algorithm))
|
||||
sys.exit(0)
|
||||
|
||||
self.preprocess_op = create_operators(pre_process_list)
|
||||
self.postprocess_op = build_post_process(postprocess_params)
|
||||
self.predictor, self.input_tensor, self.output_tensors = utility.create_predictor(
|
||||
args, 'e2e', logger) # paddle.jit.load(args.det_model_dir)
|
||||
# self.predictor.eval()
|
||||
|
||||
def clip_det_res(self, points, img_height, img_width):
|
||||
for pno in range(points.shape[0]):
|
||||
points[pno, 0] = int(min(max(points[pno, 0], 0), img_width - 1))
|
||||
points[pno, 1] = int(min(max(points[pno, 1], 0), img_height - 1))
|
||||
return points
|
||||
|
||||
def filter_tag_det_res_only_clip(self, dt_boxes, image_shape):
|
||||
img_height, img_width = image_shape[0:2]
|
||||
dt_boxes_new = []
|
||||
for box in dt_boxes:
|
||||
box = self.clip_det_res(box, img_height, img_width)
|
||||
dt_boxes_new.append(box)
|
||||
dt_boxes = np.array(dt_boxes_new)
|
||||
return dt_boxes
|
||||
|
||||
def __call__(self, img):
|
||||
ori_im = img.copy()
|
||||
data = {'image': img}
|
||||
data = transform(data, self.preprocess_op)
|
||||
img, shape_list = data
|
||||
if img is None:
|
||||
return None, 0
|
||||
img = np.expand_dims(img, axis=0)
|
||||
print(img.shape)
|
||||
shape_list = np.expand_dims(shape_list, axis=0)
|
||||
img = img.copy()
|
||||
starttime = time.time()
|
||||
|
||||
self.input_tensor.copy_from_cpu(img)
|
||||
self.predictor.run()
|
||||
outputs = []
|
||||
for output_tensor in self.output_tensors:
|
||||
output = output_tensor.copy_to_cpu()
|
||||
outputs.append(output)
|
||||
|
||||
preds = {}
|
||||
if self.e2e_algorithm == 'PGNet':
|
||||
preds['f_score'] = outputs[0]
|
||||
preds['f_border'] = outputs[1]
|
||||
preds['f_direction'] = outputs[2]
|
||||
preds['f_char'] = outputs[3]
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
post_result = self.postprocess_op(preds, shape_list)
|
||||
points, strs = post_result['points'], post_result['strs']
|
||||
dt_boxes = self.filter_tag_det_res_only_clip(points, ori_im.shape)
|
||||
elapse = time.time() - starttime
|
||||
return dt_boxes, strs, elapse
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = utility.parse_args()
|
||||
image_file_list = get_image_file_list(args.image_dir)
|
||||
text_detector = TextE2e(args)
|
||||
count = 0
|
||||
total_time = 0
|
||||
draw_img_save = "./inference_results"
|
||||
if not os.path.exists(draw_img_save):
|
||||
os.makedirs(draw_img_save)
|
||||
for image_file in image_file_list:
|
||||
img, flag = check_and_read_gif(image_file)
|
||||
if not flag:
|
||||
img = cv2.imread(image_file)
|
||||
if img is None:
|
||||
logger.info("error in loading image:{}".format(image_file))
|
||||
continue
|
||||
points, strs, elapse = text_detector(img)
|
||||
if count > 0:
|
||||
total_time += elapse
|
||||
count += 1
|
||||
logger.info("Predict time of {}: {}".format(image_file, elapse))
|
||||
src_im = utility.draw_e2e_res(points, strs, image_file)
|
||||
img_name_pure = os.path.split(image_file)[-1]
|
||||
img_path = os.path.join(draw_img_save,
|
||||
"e2e_res_{}".format(img_name_pure))
|
||||
cv2.imwrite(img_path, src_im)
|
||||
logger.info("The visualized image saved in {}".format(img_path))
|
||||
if count > 1:
|
||||
logger.info("Avg Time: {}".format(total_time / (count - 1)))
|
|
@ -74,6 +74,21 @@ def parse_args():
|
|||
"--vis_font_path", type=str, default="./doc/fonts/simfang.ttf")
|
||||
parser.add_argument("--drop_score", type=float, default=0.5)
|
||||
|
||||
# params for e2e
|
||||
parser.add_argument("--e2e_algorithm", type=str, default='PGNet')
|
||||
parser.add_argument("--e2e_model_dir", type=str)
|
||||
parser.add_argument("--e2e_limit_side_len", type=float, default=768)
|
||||
parser.add_argument("--e2e_limit_type", type=str, default='max')
|
||||
|
||||
# PGNet parmas
|
||||
parser.add_argument("--e2e_pgnet_score_thresh", type=float, default=0.5)
|
||||
parser.add_argument(
|
||||
"--e2e_char_dict_path",
|
||||
type=str,
|
||||
default="./ppocr/utils/pgnet_dict.txt")
|
||||
parser.add_argument("--e2e_pgnet_valid_set", type=str, default='totaltext')
|
||||
parser.add_argument("--e2e_pgnet_polygon", type=bool, default=False)
|
||||
|
||||
# params for text classifier
|
||||
parser.add_argument("--use_angle_cls", type=str2bool, default=False)
|
||||
parser.add_argument("--cls_model_dir", type=str)
|
||||
|
@ -93,8 +108,10 @@ def create_predictor(args, mode, logger):
|
|||
model_dir = args.det_model_dir
|
||||
elif mode == 'cls':
|
||||
model_dir = args.cls_model_dir
|
||||
else:
|
||||
elif mode == 'rec':
|
||||
model_dir = args.rec_model_dir
|
||||
else:
|
||||
model_dir = args.e2e_model_dir
|
||||
|
||||
if model_dir is None:
|
||||
logger.info("not find {} model file path {}".format(mode, model_dir))
|
||||
|
@ -147,6 +164,22 @@ def create_predictor(args, mode, logger):
|
|||
return predictor, input_tensor, output_tensors
|
||||
|
||||
|
||||
def draw_e2e_res(dt_boxes, strs, img_path):
|
||||
src_im = cv2.imread(img_path)
|
||||
for box, str in zip(dt_boxes, strs):
|
||||
box = box.astype(np.int32).reshape((-1, 1, 2))
|
||||
cv2.polylines(src_im, [box], True, color=(255, 255, 0), thickness=2)
|
||||
cv2.putText(
|
||||
src_im,
|
||||
str,
|
||||
org=(int(box[0, 0, 0]), int(box[0, 0, 1])),
|
||||
fontFace=cv2.FONT_HERSHEY_COMPLEX,
|
||||
fontScale=0.7,
|
||||
color=(0, 255, 0),
|
||||
thickness=1)
|
||||
return src_im
|
||||
|
||||
|
||||
def draw_text_det_res(dt_boxes, img_path):
|
||||
src_im = cv2.imread(img_path)
|
||||
for box in dt_boxes:
|
||||
|
|
|
@ -71,7 +71,8 @@ def main():
|
|||
init_model(config, model, logger)
|
||||
|
||||
# build post process
|
||||
post_process_class = build_post_process(config['PostProcess'])
|
||||
post_process_class = build_post_process(config['PostProcess'],
|
||||
global_config)
|
||||
|
||||
# create data ops
|
||||
transforms = []
|
||||
|
|
|
@ -0,0 +1,2 @@
|
|||
2.0,165.0,20.0,167.0,39.0,170.0,57.0,173.0,76.0,176.0,94.0,179.0,113.0,182.0,109.0,218.0,90.0,215.0,72.0,213.0,54.0,210.0,36.0,208.0,18.0,205.0,0.0,203.0 izza
|
||||
2.0,411.0,30.0,412.0,58.0,414.0,87.0,416.0,115.0,418.0,143.0,420.0,172.0,422.0,172.0,476.0,143.0,474.0,114.0,472.0,86.0,471.0,57.0,469.0,28.0,467.0,0.0,466.0 ISA
|
Binary file not shown.
After Width: | Height: | Size: 41 KiB |
Loading…
Reference in New Issue