add east & sast
This commit is contained in:
parent
8a5566c974
commit
021c1132a9
|
@ -1,8 +1,7 @@
|
|||
include LICENSE.txt
|
||||
include README.md
|
||||
|
||||
recursive-include ppocr/utils *.txt utility.py character.py check.py
|
||||
recursive-include ppocr/data/det *.py
|
||||
recursive-include ppocr/utils *.txt utility.py logging.py
|
||||
recursive-include ppocr/data/ *.py
|
||||
recursive-include ppocr/postprocess *.py
|
||||
recursive-include ppocr/postprocess/lanms *.*
|
||||
recursive-include tools/infer *.py
|
||||
recursive-include tools/infer *.py
|
|
@ -0,0 +1,111 @@
|
|||
Global:
|
||||
use_gpu: true
|
||||
epoch_num: 10000
|
||||
log_smooth_window: 20
|
||||
print_batch_step: 2
|
||||
save_model_dir: ./output/east_mv3/
|
||||
save_epoch_step: 1000
|
||||
# evaluation is run every 5000 iterations after the 4000th iteration
|
||||
eval_batch_step: [4000, 5000]
|
||||
# if pretrained_model is saved in static mode, load_static_weights must set to True
|
||||
load_static_weights: True
|
||||
cal_metric_during_train: False
|
||||
pretrained_model: ./pretrain_models/MobileNetV3_large_x0_5_pretrained
|
||||
checkpoints:
|
||||
save_inference_dir:
|
||||
use_visualdl: False
|
||||
infer_img:
|
||||
save_res_path: ./output/det_east/predicts_east.txt
|
||||
|
||||
Architecture:
|
||||
model_type: det
|
||||
algorithm: EAST
|
||||
Transform:
|
||||
Backbone:
|
||||
name: MobileNetV3
|
||||
scale: 0.5
|
||||
model_name: large
|
||||
Neck:
|
||||
name: EASTFPN
|
||||
model_name: small
|
||||
Head:
|
||||
name: EASTHead
|
||||
model_name: small
|
||||
|
||||
Loss:
|
||||
name: EASTLoss
|
||||
|
||||
Optimizer:
|
||||
name: Adam
|
||||
beta1: 0.9
|
||||
beta2: 0.999
|
||||
lr:
|
||||
# name: Cosine
|
||||
learning_rate: 0.001
|
||||
# warmup_epoch: 0
|
||||
regularizer:
|
||||
name: 'L2'
|
||||
factor: 0
|
||||
|
||||
PostProcess:
|
||||
name: EASTPostProcess
|
||||
score_thresh: 0.8
|
||||
cover_thresh: 0.1
|
||||
nms_thresh: 0.2
|
||||
|
||||
Metric:
|
||||
name: DetMetric
|
||||
main_indicator: hmean
|
||||
|
||||
Train:
|
||||
dataset:
|
||||
name: SimpleDataSet
|
||||
data_dir: ./train_data/icdar2015/text_localization/
|
||||
label_file_list:
|
||||
- ./train_data/icdar2015/text_localization/train_icdar2015_label.txt
|
||||
ratio_list: [1.0]
|
||||
transforms:
|
||||
- DecodeImage: # load image
|
||||
img_mode: BGR
|
||||
channel_first: False
|
||||
- DetLabelEncode: # Class handling label
|
||||
- EASTProcessTrain:
|
||||
image_shape: [512, 512]
|
||||
background_ratio: 0.125
|
||||
min_crop_side_ratio: 0.1
|
||||
min_text_size: 10
|
||||
- KeepKeys:
|
||||
keep_keys: ['image', 'score_map', 'geo_map', 'training_mask'] # dataloader will return list in this order
|
||||
loader:
|
||||
shuffle: True
|
||||
drop_last: False
|
||||
batch_size_per_card: 16
|
||||
num_workers: 8
|
||||
|
||||
Eval:
|
||||
dataset:
|
||||
name: SimpleDataSet
|
||||
data_dir: ./train_data/icdar2015/text_localization/
|
||||
label_file_list:
|
||||
- ./train_data/icdar2015/text_localization/test_icdar2015_label.txt
|
||||
transforms:
|
||||
- DecodeImage: # load image
|
||||
img_mode: BGR
|
||||
channel_first: False
|
||||
- DetLabelEncode: # Class handling label
|
||||
- DetResizeForTest:
|
||||
limit_side_len: 2400
|
||||
limit_type: max
|
||||
- NormalizeImage:
|
||||
scale: 1./255.
|
||||
mean: [0.485, 0.456, 0.406]
|
||||
std: [0.229, 0.224, 0.225]
|
||||
order: 'hwc'
|
||||
- ToCHWImage:
|
||||
- KeepKeys:
|
||||
keep_keys: ['image', 'shape', 'polys', 'ignore_tags']
|
||||
loader:
|
||||
shuffle: False
|
||||
drop_last: False
|
||||
batch_size_per_card: 1 # must be 1
|
||||
num_workers: 2
|
|
@ -0,0 +1,110 @@
|
|||
Global:
|
||||
use_gpu: true
|
||||
epoch_num: 10000
|
||||
log_smooth_window: 20
|
||||
print_batch_step: 2
|
||||
save_model_dir: ./output/east_r50_vd/
|
||||
save_epoch_step: 1000
|
||||
# evaluation is run every 5000 iterations after the 4000th iteration
|
||||
eval_batch_step: [4000, 5000]
|
||||
# if pretrained_model is saved in static mode, load_static_weights must set to True
|
||||
load_static_weights: True
|
||||
cal_metric_during_train: False
|
||||
pretrained_model: ./pretrain_models/ResNet50_vd_pretrained/
|
||||
checkpoints:
|
||||
save_inference_dir:
|
||||
use_visualdl: False
|
||||
infer_img:
|
||||
save_res_path: ./output/det_east/predicts_east.txt
|
||||
|
||||
Architecture:
|
||||
model_type: det
|
||||
algorithm: EAST
|
||||
Transform:
|
||||
Backbone:
|
||||
name: ResNet
|
||||
layers: 50
|
||||
Neck:
|
||||
name: EASTFPN
|
||||
model_name: large
|
||||
Head:
|
||||
name: EASTHead
|
||||
model_name: large
|
||||
|
||||
Loss:
|
||||
name: EASTLoss
|
||||
|
||||
Optimizer:
|
||||
name: Adam
|
||||
beta1: 0.9
|
||||
beta2: 0.999
|
||||
lr:
|
||||
# name: Cosine
|
||||
learning_rate: 0.001
|
||||
# warmup_epoch: 0
|
||||
regularizer:
|
||||
name: 'L2'
|
||||
factor: 0
|
||||
|
||||
PostProcess:
|
||||
name: EASTPostProcess
|
||||
score_thresh: 0.8
|
||||
cover_thresh: 0.1
|
||||
nms_thresh: 0.2
|
||||
|
||||
Metric:
|
||||
name: DetMetric
|
||||
main_indicator: hmean
|
||||
|
||||
Train:
|
||||
dataset:
|
||||
name: SimpleDataSet
|
||||
data_dir: ./train_data/icdar2015/text_localization/
|
||||
label_file_list:
|
||||
- ./train_data/icdar2015/text_localization/train_icdar2015_label.txt
|
||||
ratio_list: [1.0]
|
||||
transforms:
|
||||
- DecodeImage: # load image
|
||||
img_mode: BGR
|
||||
channel_first: False
|
||||
- DetLabelEncode: # Class handling label
|
||||
- EASTProcessTrain:
|
||||
image_shape: [512, 512]
|
||||
background_ratio: 0.125
|
||||
min_crop_side_ratio: 0.1
|
||||
min_text_size: 10
|
||||
- KeepKeys:
|
||||
keep_keys: ['image', 'score_map', 'geo_map', 'training_mask'] # dataloader will return list in this order
|
||||
loader:
|
||||
shuffle: True
|
||||
drop_last: False
|
||||
batch_size_per_card: 8
|
||||
num_workers: 8
|
||||
|
||||
Eval:
|
||||
dataset:
|
||||
name: SimpleDataSet
|
||||
data_dir: ./train_data/icdar2015/text_localization/
|
||||
label_file_list:
|
||||
- ./train_data/icdar2015/text_localization/test_icdar2015_label.txt
|
||||
transforms:
|
||||
- DecodeImage: # load image
|
||||
img_mode: BGR
|
||||
channel_first: False
|
||||
- DetLabelEncode: # Class handling label
|
||||
- DetResizeForTest:
|
||||
limit_side_len: 2400
|
||||
limit_type: max
|
||||
- NormalizeImage:
|
||||
scale: 1./255.
|
||||
mean: [0.485, 0.456, 0.406]
|
||||
std: [0.229, 0.224, 0.225]
|
||||
order: 'hwc'
|
||||
- ToCHWImage:
|
||||
- KeepKeys:
|
||||
keep_keys: ['image', 'shape', 'polys', 'ignore_tags']
|
||||
loader:
|
||||
shuffle: False
|
||||
drop_last: False
|
||||
batch_size_per_card: 1 # must be 1
|
||||
num_workers: 2
|
|
@ -0,0 +1,110 @@
|
|||
Global:
|
||||
use_gpu: true
|
||||
epoch_num: 5000
|
||||
log_smooth_window: 20
|
||||
print_batch_step: 2
|
||||
save_model_dir: ./output/sast_r50_vd_ic15/
|
||||
save_epoch_step: 1000
|
||||
# evaluation is run every 5000 iterations after the 4000th iteration
|
||||
eval_batch_step: [4000, 5000]
|
||||
# if pretrained_model is saved in static mode, load_static_weights must set to True
|
||||
load_static_weights: True
|
||||
cal_metric_during_train: False
|
||||
pretrained_model: ./pretrain_models/ResNet50_vd_ssld_pretrained/
|
||||
checkpoints:
|
||||
save_inference_dir:
|
||||
use_visualdl: False
|
||||
infer_img:
|
||||
save_res_path: ./output/sast_r50_vd_ic15/predicts_sast.txt
|
||||
|
||||
Architecture:
|
||||
model_type: det
|
||||
algorithm: SAST
|
||||
Transform:
|
||||
Backbone:
|
||||
name: ResNet_SAST
|
||||
layers: 50
|
||||
Neck:
|
||||
name: SASTFPN
|
||||
with_cab: True
|
||||
Head:
|
||||
name: SASTHead
|
||||
|
||||
Loss:
|
||||
name: SASTLoss
|
||||
|
||||
Optimizer:
|
||||
name: Adam
|
||||
beta1: 0.9
|
||||
beta2: 0.999
|
||||
lr:
|
||||
# name: Cosine
|
||||
learning_rate: 0.001
|
||||
# warmup_epoch: 0
|
||||
regularizer:
|
||||
name: 'L2'
|
||||
factor: 0
|
||||
|
||||
PostProcess:
|
||||
name: SASTPostProcess
|
||||
score_thresh: 0.5
|
||||
sample_pts_num: 2
|
||||
nms_thresh: 0.2
|
||||
expand_scale: 1.0
|
||||
shrink_ratio_of_width: 0.3
|
||||
|
||||
Metric:
|
||||
name: DetMetric
|
||||
main_indicator: hmean
|
||||
|
||||
Train:
|
||||
dataset:
|
||||
name: SimpleDataSet
|
||||
data_dir: ./train_data/
|
||||
label_file_path: [./train_data/art_latin_icdar_14pt/train_no_tt_test/train_label_json.txt, ./train_data/total_text_icdar_14pt/train_label_json.txt]
|
||||
data_ratio_list: [0.5, 0.5]
|
||||
transforms:
|
||||
- DecodeImage: # load image
|
||||
img_mode: BGR
|
||||
channel_first: False
|
||||
- DetLabelEncode: # Class handling label
|
||||
- SASTProcessTrain:
|
||||
image_shape: [512, 512]
|
||||
min_crop_side_ratio: 0.3
|
||||
min_crop_size: 24
|
||||
min_text_size: 4
|
||||
max_text_size: 512
|
||||
- KeepKeys:
|
||||
keep_keys: ['image', 'score_map', 'border_map', 'training_mask', 'tvo_map', 'tco_map'] # dataloader will return list in this order
|
||||
loader:
|
||||
shuffle: True
|
||||
drop_last: False
|
||||
batch_size_per_card: 4
|
||||
num_workers: 4
|
||||
|
||||
Eval:
|
||||
dataset:
|
||||
name: SimpleDataSet
|
||||
data_dir: ./train_data/icdar2015/text_localization/
|
||||
label_file_list:
|
||||
- ./train_data/icdar2015/text_localization/test_icdar2015_label.txt
|
||||
transforms:
|
||||
- DecodeImage: # load image
|
||||
img_mode: BGR
|
||||
channel_first: False
|
||||
- DetLabelEncode: # Class handling label
|
||||
- DetResizeForTest:
|
||||
resize_long: 1536
|
||||
- NormalizeImage:
|
||||
scale: 1./255.
|
||||
mean: [0.485, 0.456, 0.406]
|
||||
std: [0.229, 0.224, 0.225]
|
||||
order: 'hwc'
|
||||
- ToCHWImage:
|
||||
- KeepKeys:
|
||||
keep_keys: ['image', 'shape', 'polys', 'ignore_tags']
|
||||
loader:
|
||||
shuffle: False
|
||||
drop_last: False
|
||||
batch_size_per_card: 1 # must be 1
|
||||
num_workers: 2
|
|
@ -0,0 +1,109 @@
|
|||
Global:
|
||||
use_gpu: true
|
||||
epoch_num: 5000
|
||||
log_smooth_window: 20
|
||||
print_batch_step: 2
|
||||
save_model_dir: ./output/sast_r50_vd_tt/
|
||||
save_epoch_step: 1000
|
||||
# evaluation is run every 5000 iterations after the 4000th iteration
|
||||
eval_batch_step: [4000, 5000]
|
||||
# if pretrained_model is saved in static mode, load_static_weights must set to True
|
||||
load_static_weights: True
|
||||
cal_metric_during_train: False
|
||||
pretrained_model: ./pretrain_models/ResNet50_vd_ssld_pretrained/
|
||||
checkpoints:
|
||||
save_inference_dir:
|
||||
use_visualdl: False
|
||||
infer_img:
|
||||
save_res_path: ./output/sast_r50_vd_tt/predicts_sast.txt
|
||||
|
||||
Architecture:
|
||||
model_type: det
|
||||
algorithm: SAST
|
||||
Transform:
|
||||
Backbone:
|
||||
name: ResNet_SAST
|
||||
layers: 50
|
||||
Neck:
|
||||
name: SASTFPN
|
||||
with_cab: True
|
||||
Head:
|
||||
name: SASTHead
|
||||
|
||||
Loss:
|
||||
name: SASTLoss
|
||||
|
||||
Optimizer:
|
||||
name: Adam
|
||||
beta1: 0.9
|
||||
beta2: 0.999
|
||||
lr:
|
||||
# name: Cosine
|
||||
learning_rate: 0.001
|
||||
# warmup_epoch: 0
|
||||
regularizer:
|
||||
name: 'L2'
|
||||
factor: 0
|
||||
|
||||
PostProcess:
|
||||
name: SASTPostProcess
|
||||
score_thresh: 0.5
|
||||
sample_pts_num: 6
|
||||
nms_thresh: 0.2
|
||||
expand_scale: 1.2
|
||||
shrink_ratio_of_width: 0.2
|
||||
|
||||
Metric:
|
||||
name: DetMetric
|
||||
main_indicator: hmean
|
||||
|
||||
Train:
|
||||
dataset:
|
||||
name: SimpleDataSet
|
||||
label_file_list: [./train_data/icdar2013/train_label_json.txt, ./train_data/icdar2015/train_label_json.txt, ./train_data/icdar17_mlt_latin/train_label_json.txt, ./train_data/coco_text_icdar_4pts/train_label_json.txt]
|
||||
ratio_list: [0.1, 0.45, 0.3, 0.15]
|
||||
transforms:
|
||||
- DecodeImage: # load image
|
||||
img_mode: BGR
|
||||
channel_first: False
|
||||
- DetLabelEncode: # Class handling label
|
||||
- SASTProcessTrain:
|
||||
image_shape: [512, 512]
|
||||
min_crop_side_ratio: 0.3
|
||||
min_crop_size: 24
|
||||
min_text_size: 4
|
||||
max_text_size: 512
|
||||
- KeepKeys:
|
||||
keep_keys: ['image', 'score_map', 'border_map', 'training_mask', 'tvo_map', 'tco_map'] # dataloader will return list in this order
|
||||
loader:
|
||||
shuffle: True
|
||||
drop_last: False
|
||||
batch_size_per_card: 4
|
||||
num_workers: 4
|
||||
|
||||
Eval:
|
||||
dataset:
|
||||
name: SimpleDataSet
|
||||
data_dir: ./train_data/
|
||||
label_file_list:
|
||||
- ./train_data/total_text_icdar_14pt/test_label_json.txt
|
||||
transforms:
|
||||
- DecodeImage: # load image
|
||||
img_mode: BGR
|
||||
channel_first: False
|
||||
- DetLabelEncode: # Class handling label
|
||||
- DetResizeForTest:
|
||||
resize_long: 768
|
||||
- NormalizeImage:
|
||||
scale: 1./255.
|
||||
mean: [0.485, 0.456, 0.406]
|
||||
std: [0.229, 0.224, 0.225]
|
||||
order: 'hwc'
|
||||
- ToCHWImage:
|
||||
- KeepKeys:
|
||||
keep_keys: ['image', 'shape', 'polys', 'ignore_tags']
|
||||
loader:
|
||||
shuffle: False
|
||||
drop_last: False
|
||||
batch_size_per_card: 1 # must be 1
|
||||
num_workers: 2
|
|
@ -15,7 +15,7 @@ Global:
|
|||
use_visualdl: False
|
||||
infer_img:
|
||||
# for data or label process
|
||||
character_dict_path: ppocr/utils/ic15_dict.txt
|
||||
character_dict_path: ppocr/utils/dict/ic15_dict.txt
|
||||
character_type: ch
|
||||
max_text_length: 25
|
||||
infer_mode: False
|
||||
|
|
|
@ -15,7 +15,7 @@ Global:
|
|||
use_visualdl: False
|
||||
infer_img:
|
||||
# for data or label process
|
||||
character_dict_path: ppocr/utils/french_dict.txt
|
||||
character_dict_path: ppocr/utils/dict/french_dict.txt
|
||||
character_type: french
|
||||
max_text_length: 25
|
||||
infer_mode: False
|
||||
|
|
|
@ -15,7 +15,7 @@ Global:
|
|||
use_visualdl: False
|
||||
infer_img:
|
||||
# for data or label process
|
||||
character_dict_path: ppocr/utils/german_dict.txt
|
||||
character_dict_path: ppocr/utils/dict/german_dict.txt
|
||||
character_type: german
|
||||
max_text_length: 25
|
||||
infer_mode: False
|
||||
|
|
|
@ -15,7 +15,7 @@ Global:
|
|||
use_visualdl: False
|
||||
infer_img:
|
||||
# for data or label process
|
||||
character_dict_path: ppocr/utils/japan_dict.txt
|
||||
character_dict_path: ppocr/utils/dict/japan_dict.txt
|
||||
character_type: japan
|
||||
max_text_length: 25
|
||||
infer_mode: False
|
||||
|
|
|
@ -15,7 +15,7 @@ Global:
|
|||
use_visualdl: False
|
||||
infer_img:
|
||||
# for data or label process
|
||||
character_dict_path: ppocr/utils/korean_dict.txt
|
||||
character_dict_path: ppocr/utils/dict/korean_dict.txt
|
||||
character_type: korean
|
||||
max_text_length: 25
|
||||
infer_mode: False
|
||||
|
|
|
@ -261,6 +261,61 @@ im_show.save('result.jpg')
|
|||
paddleocr --image_dir PaddleOCR/doc/imgs/11.jpg --det_model_dir {your_det_model_dir} --rec_model_dir {your_rec_model_dir} --rec_char_dict_path {your_rec_char_dict_path} --cls_model_dir {your_cls_model_dir} --use_angle_cls true --cls true
|
||||
```
|
||||
|
||||
### 使用网络图片或者numpy数组作为输入
|
||||
|
||||
1. 网络图片
|
||||
|
||||
代码使用
|
||||
```python
|
||||
from paddleocr import PaddleOCR, draw_ocr
|
||||
# Paddleocr目前支持中英文、英文、法语、德语、韩语、日语,可以通过修改lang参数进行切换
|
||||
# 参数依次为`ch`, `en`, `french`, `german`, `korean`, `japan`。
|
||||
ocr = PaddleOCR(use_angle_cls=True, lang="ch") # need to run only once to download and load model into memory
|
||||
img_path = 'http://n.sinaimg.cn/ent/transform/w630h933/20171222/o111-fypvuqf1838418.jpg'
|
||||
result = ocr.ocr(img_path, cls=True)
|
||||
for line in result:
|
||||
print(line)
|
||||
|
||||
# 显示结果
|
||||
from PIL import Image
|
||||
image = Image.open(img_path).convert('RGB')
|
||||
boxes = [line[0] for line in result]
|
||||
txts = [line[1][0] for line in result]
|
||||
scores = [line[1][1] for line in result]
|
||||
im_show = draw_ocr(image, boxes, txts, scores, font_path='/path/to/PaddleOCR/doc/simfang.ttf')
|
||||
im_show = Image.fromarray(im_show)
|
||||
im_show.save('result.jpg')
|
||||
```
|
||||
命令行模式
|
||||
```bash
|
||||
paddleocr --image_dir http://n.sinaimg.cn/ent/transform/w630h933/20171222/o111-fypvuqf1838418.jpg --use_angle_cls=true
|
||||
```
|
||||
|
||||
2. numpy数组
|
||||
仅通过代码使用时支持numpy数组作为输入
|
||||
```python
|
||||
from paddleocr import PaddleOCR, draw_ocr
|
||||
# Paddleocr目前支持中英文、英文、法语、德语、韩语、日语,可以通过修改lang参数进行切换
|
||||
# 参数依次为`ch`, `en`, `french`, `german`, `korean`, `japan`。
|
||||
ocr = PaddleOCR(use_angle_cls=True, lang="ch") # need to run only once to download and load model into memory
|
||||
img_path = 'PaddleOCR/doc/imgs/11.jpg'
|
||||
img = cv2.imread(img_path)
|
||||
# img = cv2.cvtColor(img,cv2.COLOR_BGR2GRAY), 如果你自己训练的模型支持灰度图,可以将这句话的注释取消
|
||||
result = ocr.ocr(img_path, cls=True)
|
||||
for line in result:
|
||||
print(line)
|
||||
|
||||
# 显示结果
|
||||
from PIL import Image
|
||||
image = Image.open(img_path).convert('RGB')
|
||||
boxes = [line[0] for line in result]
|
||||
txts = [line[1][0] for line in result]
|
||||
scores = [line[1][1] for line in result]
|
||||
im_show = draw_ocr(image, boxes, txts, scores, font_path='/path/to/PaddleOCR/doc/simfang.ttf')
|
||||
im_show = Image.fromarray(im_show)
|
||||
im_show.save('result.jpg')
|
||||
```
|
||||
|
||||
## 参数说明
|
||||
|
||||
| 字段 | 说明 | 默认值 |
|
||||
|
@ -285,6 +340,7 @@ paddleocr --image_dir PaddleOCR/doc/imgs/11.jpg --det_model_dir {your_det_model_
|
|||
| max_text_length | 识别算法能识别的最大文字长度 | 25 |
|
||||
| rec_char_dict_path | 识别模型字典路径,当rec_model_dir使用方式2传参时需要修改为自己的字典路径 | ./ppocr/utils/ppocr_keys_v1.txt |
|
||||
| use_space_char | 是否识别空格 | TRUE |
|
||||
| drop_score | 对输出按照分数(来自于识别模型)进行过滤,低于此分数的不返回 | 0.5 |
|
||||
| use_angle_cls | 是否加载分类模型 | FALSE |
|
||||
| cls_model_dir | 分类模型所在文件夹。传参方式有两种,1. None: 自动下载内置模型到 `~/.paddleocr/cls`;2.自己转换好的inference模型路径,模型路径下必须包含model和params文件 | None |
|
||||
| cls_image_shape | 分类算法的输入图片尺寸 | "3, 48, 192" |
|
||||
|
@ -295,4 +351,4 @@ paddleocr --image_dir PaddleOCR/doc/imgs/11.jpg --det_model_dir {your_det_model_
|
|||
| lang | 模型语言类型,目前支持 中文(ch)和英文(en) | ch |
|
||||
| det | 前向时使用启动检测 | TRUE |
|
||||
| rec | 前向时是否启动识别 | TRUE |
|
||||
| cls | 前向时是否启动分类 | FALSE |
|
||||
| cls | 前向时是否启动分类 (命令行模式下使用use_angle_cls控制前向是否启动分类) | FALSE |
|
||||
|
|
|
@ -271,6 +271,59 @@ im_show.save('result.jpg')
|
|||
paddleocr --image_dir PaddleOCR/doc/imgs/11.jpg --det_model_dir {your_det_model_dir} --rec_model_dir {your_rec_model_dir} --rec_char_dict_path {your_rec_char_dict_path} --cls_model_dir {your_cls_model_dir} --use_angle_cls true --cls true
|
||||
```
|
||||
|
||||
### Use web images or numpy array as input
|
||||
|
||||
1. Web image
|
||||
|
||||
Use by code
|
||||
```python
|
||||
from paddleocr import PaddleOCR, draw_ocr
|
||||
ocr = PaddleOCR(use_angle_cls=True, lang="ch") # need to run only once to download and load model into memory
|
||||
img_path = 'http://n.sinaimg.cn/ent/transform/w630h933/20171222/o111-fypvuqf1838418.jpg'
|
||||
result = ocr.ocr(img_path, cls=True)
|
||||
for line in result:
|
||||
print(line)
|
||||
|
||||
# show result
|
||||
from PIL import Image
|
||||
image = Image.open(img_path).convert('RGB')
|
||||
boxes = [line[0] for line in result]
|
||||
txts = [line[1][0] for line in result]
|
||||
scores = [line[1][1] for line in result]
|
||||
im_show = draw_ocr(image, boxes, txts, scores, font_path='/path/to/PaddleOCR/doc/simfang.ttf')
|
||||
im_show = Image.fromarray(im_show)
|
||||
im_show.save('result.jpg')
|
||||
```
|
||||
Use by command line
|
||||
```bash
|
||||
paddleocr --image_dir http://n.sinaimg.cn/ent/transform/w630h933/20171222/o111-fypvuqf1838418.jpg --use_angle_cls=true
|
||||
```
|
||||
|
||||
2. Numpy array
|
||||
Support numpy array as input only when used by code
|
||||
|
||||
```python
|
||||
from paddleocr import PaddleOCR, draw_ocr
|
||||
ocr = PaddleOCR(use_angle_cls=True, lang="ch") # need to run only once to download and load model into memory
|
||||
img_path = 'PaddleOCR/doc/imgs/11.jpg'
|
||||
img = cv2.imread(img_path)
|
||||
# img = cv2.cvtColor(img,cv2.COLOR_BGR2GRAY), If your own training model supports grayscale images, you can uncomment this line
|
||||
result = ocr.ocr(img_path, cls=True)
|
||||
for line in result:
|
||||
print(line)
|
||||
|
||||
# show result
|
||||
from PIL import Image
|
||||
image = Image.open(img_path).convert('RGB')
|
||||
boxes = [line[0] for line in result]
|
||||
txts = [line[1][0] for line in result]
|
||||
scores = [line[1][1] for line in result]
|
||||
im_show = draw_ocr(image, boxes, txts, scores, font_path='/path/to/PaddleOCR/doc/simfang.ttf')
|
||||
im_show = Image.fromarray(im_show)
|
||||
im_show.save('result.jpg')
|
||||
```
|
||||
|
||||
|
||||
## Parameter Description
|
||||
|
||||
| Parameter | Description | Default value |
|
||||
|
@ -295,6 +348,7 @@ paddleocr --image_dir PaddleOCR/doc/imgs/11.jpg --det_model_dir {your_det_model_
|
|||
| max_text_length | The maximum text length that the recognition algorithm can recognize | 25 |
|
||||
| rec_char_dict_path | the alphabet path which needs to be modified to your own path when `rec_model_Name` use mode 2 | ./ppocr/utils/ppocr_keys_v1.txt |
|
||||
| use_space_char | Whether to recognize spaces | TRUE |
|
||||
| drop_score | Filter the output by score (from the recognition model), and those below this score will not be returned | 0.5 |
|
||||
| use_angle_cls | Whether to load classification model | FALSE |
|
||||
| cls_model_dir | the classification inference model folder. There are two ways to transfer parameters, 1. None: Automatically download the built-in model to `~/.paddleocr/cls`; 2. The path of the inference model converted by yourself, the model and params files must be included in the model path | None |
|
||||
| cls_image_shape | image shape of classification algorithm | "3,48,192" |
|
||||
|
@ -305,4 +359,4 @@ paddleocr --image_dir PaddleOCR/doc/imgs/11.jpg --det_model_dir {your_det_model_
|
|||
| lang | The support language, now only Chinese(ch)、English(en)、French(french)、German(german)、Korean(korean)、Japanese(japan) are supported | ch |
|
||||
| det | Enable detction when `ppocr.ocr` func exec | TRUE |
|
||||
| rec | Enable recognition when `ppocr.ocr` func exec | TRUE |
|
||||
| cls | Enable classification when `ppocr.ocr` func exec | FALSE |
|
||||
| cls | Enable classification when `ppocr.ocr` func exec((Use use_angle_cls in command line mode to control whether to start classification in the forward direction) | FALSE |
|
||||
|
|
238
paddleocr.py
238
paddleocr.py
|
@ -26,17 +26,50 @@ import requests
|
|||
from tqdm import tqdm
|
||||
|
||||
from tools.infer import predict_system
|
||||
from ppocr.utils.utility import initial_logger
|
||||
from ppocr.utils.logging import get_logger
|
||||
|
||||
logger = initial_logger()
|
||||
logger = get_logger()
|
||||
from ppocr.utils.utility import check_and_read_gif, get_image_file_list
|
||||
|
||||
__all__ = ['PaddleOCR']
|
||||
|
||||
model_params = {
|
||||
'det': 'https://paddleocr.bj.bcebos.com/ch_models/ch_det_mv3_db_infer.tar',
|
||||
'rec':
|
||||
'https://paddleocr.bj.bcebos.com/ch_models/ch_rec_mv3_crnn_enhance_infer.tar',
|
||||
model_urls = {
|
||||
'det':
|
||||
'https://paddleocr.bj.bcebos.com/20-09-22/mobile/det/ch_ppocr_mobile_v1.1_det_infer.tar',
|
||||
'rec': {
|
||||
'ch': {
|
||||
'url':
|
||||
'https://paddleocr.bj.bcebos.com/20-09-22/mobile/rec/ch_ppocr_mobile_v1.1_rec_infer.tar',
|
||||
'dict_path': './ppocr/utils/ppocr_keys_v1.txt'
|
||||
},
|
||||
'en': {
|
||||
'url':
|
||||
'https://paddleocr.bj.bcebos.com/20-09-22/mobile/en/en_ppocr_mobile_v1.1_rec_infer.tar',
|
||||
'dict_path': './ppocr/utils/ic15_dict.txt'
|
||||
},
|
||||
'french': {
|
||||
'url':
|
||||
'https://paddleocr.bj.bcebos.com/20-09-22/mobile/fr/french_ppocr_mobile_v1.1_rec_infer.tar',
|
||||
'dict_path': './ppocr/utils/dict/french_dict.txt'
|
||||
},
|
||||
'german': {
|
||||
'url':
|
||||
'https://paddleocr.bj.bcebos.com/20-09-22/mobile/ge/german_ppocr_mobile_v1.1_rec_infer.tar',
|
||||
'dict_path': './ppocr/utils/dict/german_dict.txt'
|
||||
},
|
||||
'korean': {
|
||||
'url':
|
||||
'https://paddleocr.bj.bcebos.com/20-09-22/mobile/kr/korean_ppocr_mobile_v1.1_rec_infer.tar',
|
||||
'dict_path': './ppocr/utils/dict/korean_dict.txt'
|
||||
},
|
||||
'japan': {
|
||||
'url':
|
||||
'https://paddleocr.bj.bcebos.com/20-09-22/mobile/jp/japan_ppocr_mobile_v1.1_rec_infer.tar',
|
||||
'dict_path': './ppocr/utils/dict/japan_dict.txt'
|
||||
}
|
||||
},
|
||||
'cls':
|
||||
'https://paddleocr.bj.bcebos.com/20-09-22/cls/ch_ppocr_mobile_v1.1_cls_infer.tar'
|
||||
}
|
||||
|
||||
SUPPORT_DET_MODEL = ['DB']
|
||||
|
@ -54,8 +87,8 @@ def download_with_progressbar(url, save_path):
|
|||
progress_bar.update(len(data))
|
||||
file.write(data)
|
||||
progress_bar.close()
|
||||
if total_size_in_bytes != 0 and progress_bar.n != total_size_in_bytes:
|
||||
logger.error("ERROR, something went wrong")
|
||||
if total_size_in_bytes == 0 or progress_bar.n != total_size_in_bytes:
|
||||
logger.error("Something went wrong while downloading models")
|
||||
sys.exit(0)
|
||||
|
||||
|
||||
|
@ -63,7 +96,7 @@ def maybe_download(model_storage_directory, url):
|
|||
# using custom model
|
||||
if not os.path.exists(os.path.join(
|
||||
model_storage_directory, 'model')) or not os.path.exists(
|
||||
os.path.join(model_storage_directory, 'params')):
|
||||
os.path.join(model_storage_directory, 'params')):
|
||||
tmp_path = os.path.join(model_storage_directory, url.split('/')[-1])
|
||||
print('download {} to {}'.format(url, tmp_path))
|
||||
os.makedirs(model_storage_directory, exist_ok=True)
|
||||
|
@ -84,53 +117,102 @@ def maybe_download(model_storage_directory, url):
|
|||
os.remove(tmp_path)
|
||||
|
||||
|
||||
def parse_args():
|
||||
def parse_args(mMain=True, add_help=True):
|
||||
import argparse
|
||||
|
||||
def str2bool(v):
|
||||
return v.lower() in ("true", "t", "1")
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
# params for prediction engine
|
||||
parser.add_argument("--use_gpu", type=str2bool, default=True)
|
||||
parser.add_argument("--ir_optim", type=str2bool, default=True)
|
||||
parser.add_argument("--use_tensorrt", type=str2bool, default=False)
|
||||
parser.add_argument("--gpu_mem", type=int, default=8000)
|
||||
if mMain:
|
||||
parser = argparse.ArgumentParser(add_help=add_help)
|
||||
# params for prediction engine
|
||||
parser.add_argument("--use_gpu", type=str2bool, default=True)
|
||||
parser.add_argument("--ir_optim", type=str2bool, default=True)
|
||||
parser.add_argument("--use_tensorrt", type=str2bool, default=False)
|
||||
parser.add_argument("--gpu_mem", type=int, default=8000)
|
||||
|
||||
# params for text detector
|
||||
parser.add_argument("--image_dir", type=str)
|
||||
parser.add_argument("--det_algorithm", type=str, default='DB')
|
||||
parser.add_argument("--det_model_dir", type=str, default=None)
|
||||
parser.add_argument("--det_max_side_len", type=float, default=960)
|
||||
# params for text detector
|
||||
parser.add_argument("--image_dir", type=str)
|
||||
parser.add_argument("--det_algorithm", type=str, default='DB')
|
||||
parser.add_argument("--det_model_dir", type=str, default=None)
|
||||
parser.add_argument("--det_limit_side_len", type=float, default=960)
|
||||
parser.add_argument("--det_limit_type", type=str, default='max')
|
||||
|
||||
# DB parmas
|
||||
parser.add_argument("--det_db_thresh", type=float, default=0.3)
|
||||
parser.add_argument("--det_db_box_thresh", type=float, default=0.5)
|
||||
parser.add_argument("--det_db_unclip_ratio", type=float, default=2.0)
|
||||
# DB parmas
|
||||
parser.add_argument("--det_db_thresh", type=float, default=0.3)
|
||||
parser.add_argument("--det_db_box_thresh", type=float, default=0.5)
|
||||
parser.add_argument("--det_db_unclip_ratio", type=float, default=2.0)
|
||||
|
||||
# EAST parmas
|
||||
parser.add_argument("--det_east_score_thresh", type=float, default=0.8)
|
||||
parser.add_argument("--det_east_cover_thresh", type=float, default=0.1)
|
||||
parser.add_argument("--det_east_nms_thresh", type=float, default=0.2)
|
||||
# EAST parmas
|
||||
parser.add_argument("--det_east_score_thresh", type=float, default=0.8)
|
||||
parser.add_argument("--det_east_cover_thresh", type=float, default=0.1)
|
||||
parser.add_argument("--det_east_nms_thresh", type=float, default=0.2)
|
||||
|
||||
# params for text recognizer
|
||||
parser.add_argument("--rec_algorithm", type=str, default='CRNN')
|
||||
parser.add_argument("--rec_model_dir", type=str, default=None)
|
||||
parser.add_argument("--rec_image_shape", type=str, default="3, 32, 320")
|
||||
parser.add_argument("--rec_char_type", type=str, default='ch')
|
||||
parser.add_argument("--rec_batch_num", type=int, default=30)
|
||||
parser.add_argument("--max_text_length", type=int, default=25)
|
||||
parser.add_argument(
|
||||
"--rec_char_dict_path",
|
||||
type=str,
|
||||
default="./ppocr/utils/ppocr_keys_v1.txt")
|
||||
parser.add_argument("--use_space_char", type=bool, default=True)
|
||||
parser.add_argument("--enable_mkldnn", type=bool, default=False)
|
||||
# params for text recognizer
|
||||
parser.add_argument("--rec_algorithm", type=str, default='CRNN')
|
||||
parser.add_argument("--rec_model_dir", type=str, default=None)
|
||||
parser.add_argument("--rec_image_shape", type=str, default="3, 32, 320")
|
||||
parser.add_argument("--rec_char_type", type=str, default='ch')
|
||||
parser.add_argument("--rec_batch_num", type=int, default=30)
|
||||
parser.add_argument("--max_text_length", type=int, default=25)
|
||||
parser.add_argument("--rec_char_dict_path", type=str, default=None)
|
||||
parser.add_argument("--use_space_char", type=bool, default=True)
|
||||
parser.add_argument("--drop_score", type=float, default=0.5)
|
||||
|
||||
parser.add_argument("--det", type=str2bool, default=True)
|
||||
parser.add_argument("--rec", type=str2bool, default=True)
|
||||
parser.add_argument("--use_zero_copy_run", type=bool, default=False)
|
||||
return parser.parse_args()
|
||||
# params for text classifier
|
||||
parser.add_argument("--cls_model_dir", type=str, default=None)
|
||||
parser.add_argument("--cls_image_shape", type=str, default="3, 48, 192")
|
||||
parser.add_argument("--label_list", type=list, default=['0', '180'])
|
||||
parser.add_argument("--cls_batch_num", type=int, default=30)
|
||||
parser.add_argument("--cls_thresh", type=float, default=0.9)
|
||||
|
||||
parser.add_argument("--enable_mkldnn", type=bool, default=False)
|
||||
parser.add_argument("--use_zero_copy_run", type=bool, default=False)
|
||||
parser.add_argument("--use_pdserving", type=str2bool, default=False)
|
||||
|
||||
parser.add_argument("--lang", type=str, default='ch')
|
||||
parser.add_argument("--det", type=str2bool, default=True)
|
||||
parser.add_argument("--rec", type=str2bool, default=True)
|
||||
parser.add_argument("--use_angle_cls", type=str2bool, default=False)
|
||||
return parser.parse_args()
|
||||
else:
|
||||
return argparse.Namespace(use_gpu=True,
|
||||
ir_optim=True,
|
||||
use_tensorrt=False,
|
||||
gpu_mem=8000,
|
||||
image_dir='',
|
||||
det_algorithm='DB',
|
||||
det_model_dir=None,
|
||||
det_limit_side_len=960,
|
||||
det_limit_type='max',
|
||||
det_db_thresh=0.3,
|
||||
det_db_box_thresh=0.5,
|
||||
det_db_unclip_ratio=2.0,
|
||||
det_east_score_thresh=0.8,
|
||||
det_east_cover_thresh=0.1,
|
||||
det_east_nms_thresh=0.2,
|
||||
rec_algorithm='CRNN',
|
||||
rec_model_dir=None,
|
||||
rec_image_shape="3, 32, 320",
|
||||
rec_char_type='ch',
|
||||
rec_batch_num=30,
|
||||
max_text_length=25,
|
||||
rec_char_dict_path=None,
|
||||
use_space_char=True,
|
||||
drop_score=0.5,
|
||||
cls_model_dir=None,
|
||||
cls_image_shape="3, 48, 192",
|
||||
label_list=['0', '180'],
|
||||
cls_batch_num=30,
|
||||
cls_thresh=0.9,
|
||||
enable_mkldnn=False,
|
||||
use_zero_copy_run=False,
|
||||
use_pdserving=False,
|
||||
lang='ch',
|
||||
det=True,
|
||||
rec=True,
|
||||
use_angle_cls=False
|
||||
)
|
||||
|
||||
|
||||
class PaddleOCR(predict_system.TextSystem):
|
||||
|
@ -140,18 +222,31 @@ class PaddleOCR(predict_system.TextSystem):
|
|||
args:
|
||||
**kwargs: other params show in paddleocr --help
|
||||
"""
|
||||
postprocess_params = parse_args()
|
||||
postprocess_params = parse_args(mMain=False, add_help=False)
|
||||
postprocess_params.__dict__.update(**kwargs)
|
||||
self.use_angle_cls = postprocess_params.use_angle_cls
|
||||
lang = postprocess_params.lang
|
||||
assert lang in model_urls[
|
||||
'rec'], 'param lang must in {}, but got {}'.format(
|
||||
model_urls['rec'].keys(), lang)
|
||||
if postprocess_params.rec_char_dict_path is None:
|
||||
postprocess_params.rec_char_dict_path = model_urls['rec'][lang][
|
||||
'dict_path']
|
||||
|
||||
# init model dir
|
||||
if postprocess_params.det_model_dir is None:
|
||||
postprocess_params.det_model_dir = os.path.join(BASE_DIR, 'det')
|
||||
if postprocess_params.rec_model_dir is None:
|
||||
postprocess_params.rec_model_dir = os.path.join(BASE_DIR, 'rec')
|
||||
postprocess_params.rec_model_dir = os.path.join(
|
||||
BASE_DIR, 'rec/{}'.format(lang))
|
||||
if postprocess_params.cls_model_dir is None:
|
||||
postprocess_params.cls_model_dir = os.path.join(BASE_DIR, 'cls')
|
||||
print(postprocess_params)
|
||||
# download model
|
||||
maybe_download(postprocess_params.det_model_dir, model_params['det'])
|
||||
maybe_download(postprocess_params.rec_model_dir, model_params['rec'])
|
||||
maybe_download(postprocess_params.det_model_dir, model_urls['det'])
|
||||
maybe_download(postprocess_params.rec_model_dir,
|
||||
model_urls['rec'][lang]['url'])
|
||||
maybe_download(postprocess_params.cls_model_dir, model_urls['cls'])
|
||||
|
||||
if postprocess_params.det_algorithm not in SUPPORT_DET_MODEL:
|
||||
logger.error('det_algorithm must in {}'.format(SUPPORT_DET_MODEL))
|
||||
|
@ -166,7 +261,7 @@ class PaddleOCR(predict_system.TextSystem):
|
|||
# init det_model and rec_model
|
||||
super().__init__(postprocess_params)
|
||||
|
||||
def ocr(self, img, det=True, rec=True):
|
||||
def ocr(self, img, det=True, rec=True, cls=False):
|
||||
"""
|
||||
ocr with paddleocr
|
||||
args:
|
||||
|
@ -175,7 +270,16 @@ class PaddleOCR(predict_system.TextSystem):
|
|||
rec: use text recognition or not, if false, only det will be exec. default is True
|
||||
"""
|
||||
assert isinstance(img, (np.ndarray, list, str))
|
||||
if isinstance(img, list) and det == True:
|
||||
logger.error('When input a list of images, det must be false')
|
||||
exit(0)
|
||||
|
||||
self.use_angle_cls = cls
|
||||
if isinstance(img, str):
|
||||
# download net image
|
||||
if img.startswith('http'):
|
||||
download_with_progressbar(img, 'tmp.jpg')
|
||||
img = 'tmp.jpg'
|
||||
image_file = img
|
||||
img, flag = check_and_read_gif(image_file)
|
||||
if not flag:
|
||||
|
@ -183,6 +287,8 @@ class PaddleOCR(predict_system.TextSystem):
|
|||
if img is None:
|
||||
logger.error("error in loading image:{}".format(image_file))
|
||||
return None
|
||||
if isinstance(img, np.ndarray) and len(img.shape) == 2:
|
||||
img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
|
||||
if det and rec:
|
||||
dt_boxes, rec_res = self.__call__(img)
|
||||
return [[box.tolist(), res] for box, res in zip(dt_boxes, rec_res)]
|
||||
|
@ -194,20 +300,34 @@ class PaddleOCR(predict_system.TextSystem):
|
|||
else:
|
||||
if not isinstance(img, list):
|
||||
img = [img]
|
||||
if self.use_angle_cls:
|
||||
img, cls_res, elapse = self.text_classifier(img)
|
||||
if not rec:
|
||||
return cls_res
|
||||
rec_res, elapse = self.text_recognizer(img)
|
||||
return rec_res
|
||||
|
||||
|
||||
def main():
|
||||
# for com
|
||||
args = parse_args()
|
||||
image_file_list = get_image_file_list(args.image_dir)
|
||||
# for cmd
|
||||
args = parse_args(mMain=True)
|
||||
image_dir = args.image_dir
|
||||
if image_dir.startswith('http'):
|
||||
download_with_progressbar(image_dir, 'tmp.jpg')
|
||||
image_file_list = ['tmp.jpg']
|
||||
else:
|
||||
image_file_list = get_image_file_list(args.image_dir)
|
||||
if len(image_file_list) == 0:
|
||||
logger.error('no images find in {}'.format(args.image_dir))
|
||||
return
|
||||
ocr_engine = PaddleOCR()
|
||||
|
||||
ocr_engine = PaddleOCR(**(args.__dict__))
|
||||
for img_path in image_file_list:
|
||||
print(img_path)
|
||||
result = ocr_engine.ocr(img_path, det=args.det, rec=args.rec)
|
||||
for line in result:
|
||||
print(line)
|
||||
logger.info('{}{}{}'.format('*' * 10, img_path, '*' * 10))
|
||||
result = ocr_engine.ocr(img_path,
|
||||
det=args.det,
|
||||
rec=args.rec,
|
||||
cls=args.use_angle_cls)
|
||||
if result is not None:
|
||||
for line in result:
|
||||
logger.info(line)
|
||||
|
|
|
@ -26,6 +26,9 @@ from .randaugment import RandAugment
|
|||
from .operators import *
|
||||
from .label_ops import *
|
||||
|
||||
from .east_process import *
|
||||
from .sast_process import *
|
||||
|
||||
|
||||
def transform(data, ops=None):
|
||||
""" transform """
|
||||
|
|
|
@ -0,0 +1,439 @@
|
|||
#copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
#Licensed under the Apache License, Version 2.0 (the "License");
|
||||
#you may not use this file except in compliance with the License.
|
||||
#You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
#Unless required by applicable law or agreed to in writing, software
|
||||
#distributed under the License is distributed on an "AS IS" BASIS,
|
||||
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
#See the License for the specific language governing permissions and
|
||||
#limitations under the License.
|
||||
|
||||
import math
|
||||
import cv2
|
||||
import numpy as np
|
||||
import json
|
||||
import sys
|
||||
import os
|
||||
|
||||
__all__ = ['EASTProcessTrain']
|
||||
|
||||
|
||||
class EASTProcessTrain(object):
|
||||
def __init__(self,
|
||||
image_shape = [512, 512],
|
||||
background_ratio = 0.125,
|
||||
min_crop_side_ratio = 0.1,
|
||||
min_text_size = 10,
|
||||
**kwargs):
|
||||
self.input_size = image_shape[1]
|
||||
self.random_scale = np.array([0.5, 1, 2.0, 3.0])
|
||||
self.background_ratio = background_ratio
|
||||
self.min_crop_side_ratio = min_crop_side_ratio
|
||||
self.min_text_size = min_text_size
|
||||
|
||||
def preprocess(self, im):
|
||||
input_size = self.input_size
|
||||
im_shape = im.shape
|
||||
im_size_min = np.min(im_shape[0:2])
|
||||
im_size_max = np.max(im_shape[0:2])
|
||||
im_scale = float(input_size) / float(im_size_max)
|
||||
im = cv2.resize(im, None, None, fx=im_scale, fy=im_scale)
|
||||
img_mean = [0.485, 0.456, 0.406]
|
||||
img_std = [0.229, 0.224, 0.225]
|
||||
# im = im[:, :, ::-1].astype(np.float32)
|
||||
im = im / 255
|
||||
im -= img_mean
|
||||
im /= img_std
|
||||
new_h, new_w, _ = im.shape
|
||||
im_padded = np.zeros((input_size, input_size, 3), dtype=np.float32)
|
||||
im_padded[:new_h, :new_w, :] = im
|
||||
im_padded = im_padded.transpose((2, 0, 1))
|
||||
im_padded = im_padded[np.newaxis, :]
|
||||
return im_padded, im_scale
|
||||
|
||||
def rotate_im_poly(self, im, text_polys):
|
||||
"""
|
||||
rotate image with 90 / 180 / 270 degre
|
||||
"""
|
||||
im_w, im_h = im.shape[1], im.shape[0]
|
||||
dst_im = im.copy()
|
||||
dst_polys = []
|
||||
rand_degree_ratio = np.random.rand()
|
||||
rand_degree_cnt = 1
|
||||
if 0.333 < rand_degree_ratio < 0.666:
|
||||
rand_degree_cnt = 2
|
||||
elif rand_degree_ratio > 0.666:
|
||||
rand_degree_cnt = 3
|
||||
for i in range(rand_degree_cnt):
|
||||
dst_im = np.rot90(dst_im)
|
||||
rot_degree = -90 * rand_degree_cnt
|
||||
rot_angle = rot_degree * math.pi / 180.0
|
||||
n_poly = text_polys.shape[0]
|
||||
cx, cy = 0.5 * im_w, 0.5 * im_h
|
||||
ncx, ncy = 0.5 * dst_im.shape[1], 0.5 * dst_im.shape[0]
|
||||
for i in range(n_poly):
|
||||
wordBB = text_polys[i]
|
||||
poly = []
|
||||
for j in range(4):
|
||||
sx, sy = wordBB[j][0], wordBB[j][1]
|
||||
dx = math.cos(rot_angle) * (sx - cx)\
|
||||
- math.sin(rot_angle) * (sy - cy) + ncx
|
||||
dy = math.sin(rot_angle) * (sx - cx)\
|
||||
+ math.cos(rot_angle) * (sy - cy) + ncy
|
||||
poly.append([dx, dy])
|
||||
dst_polys.append(poly)
|
||||
dst_polys = np.array(dst_polys, dtype=np.float32)
|
||||
return dst_im, dst_polys
|
||||
|
||||
def polygon_area(self, poly):
|
||||
"""
|
||||
compute area of a polygon
|
||||
:param poly:
|
||||
:return:
|
||||
"""
|
||||
edge = [(poly[1][0] - poly[0][0]) * (poly[1][1] + poly[0][1]),
|
||||
(poly[2][0] - poly[1][0]) * (poly[2][1] + poly[1][1]),
|
||||
(poly[3][0] - poly[2][0]) * (poly[3][1] + poly[2][1]),
|
||||
(poly[0][0] - poly[3][0]) * (poly[0][1] + poly[3][1])]
|
||||
return np.sum(edge) / 2.
|
||||
|
||||
def check_and_validate_polys(self, polys, tags, img_height, img_width):
|
||||
"""
|
||||
check so that the text poly is in the same direction,
|
||||
and also filter some invalid polygons
|
||||
:param polys:
|
||||
:param tags:
|
||||
:return:
|
||||
"""
|
||||
h, w = img_height, img_width
|
||||
if polys.shape[0] == 0:
|
||||
return polys
|
||||
polys[:, :, 0] = np.clip(polys[:, :, 0], 0, w - 1)
|
||||
polys[:, :, 1] = np.clip(polys[:, :, 1], 0, h - 1)
|
||||
|
||||
validated_polys = []
|
||||
validated_tags = []
|
||||
for poly, tag in zip(polys, tags):
|
||||
p_area = self.polygon_area(poly)
|
||||
#invalid poly
|
||||
if abs(p_area) < 1:
|
||||
continue
|
||||
if p_area > 0:
|
||||
#'poly in wrong direction'
|
||||
if not tag:
|
||||
tag = True #reversed cases should be ignore
|
||||
poly = poly[(0, 3, 2, 1), :]
|
||||
validated_polys.append(poly)
|
||||
validated_tags.append(tag)
|
||||
return np.array(validated_polys), np.array(validated_tags)
|
||||
|
||||
def draw_img_polys(self, img, polys):
|
||||
if len(img.shape) == 4:
|
||||
img = np.squeeze(img, axis=0)
|
||||
if img.shape[0] == 3:
|
||||
img = img.transpose((1, 2, 0))
|
||||
img[:, :, 2] += 123.68
|
||||
img[:, :, 1] += 116.78
|
||||
img[:, :, 0] += 103.94
|
||||
cv2.imwrite("tmp.jpg", img)
|
||||
img = cv2.imread("tmp.jpg")
|
||||
for box in polys:
|
||||
box = box.astype(np.int32).reshape((-1, 1, 2))
|
||||
cv2.polylines(img, [box], True, color=(255, 255, 0), thickness=2)
|
||||
import random
|
||||
ino = random.randint(0, 100)
|
||||
cv2.imwrite("tmp_%d.jpg" % ino, img)
|
||||
return
|
||||
|
||||
def shrink_poly(self, poly, r):
|
||||
"""
|
||||
fit a poly inside the origin poly, maybe bugs here...
|
||||
used for generate the score map
|
||||
:param poly: the text poly
|
||||
:param r: r in the paper
|
||||
:return: the shrinked poly
|
||||
"""
|
||||
# shrink ratio
|
||||
R = 0.3
|
||||
# find the longer pair
|
||||
dist0 = np.linalg.norm(poly[0] - poly[1])
|
||||
dist1 = np.linalg.norm(poly[2] - poly[3])
|
||||
dist2 = np.linalg.norm(poly[0] - poly[3])
|
||||
dist3 = np.linalg.norm(poly[1] - poly[2])
|
||||
if dist0 + dist1 > dist2 + dist3:
|
||||
# first move (p0, p1), (p2, p3), then (p0, p3), (p1, p2)
|
||||
## p0, p1
|
||||
theta = np.arctan2((poly[1][1] - poly[0][1]),
|
||||
(poly[1][0] - poly[0][0]))
|
||||
poly[0][0] += R * r[0] * np.cos(theta)
|
||||
poly[0][1] += R * r[0] * np.sin(theta)
|
||||
poly[1][0] -= R * r[1] * np.cos(theta)
|
||||
poly[1][1] -= R * r[1] * np.sin(theta)
|
||||
## p2, p3
|
||||
theta = np.arctan2((poly[2][1] - poly[3][1]),
|
||||
(poly[2][0] - poly[3][0]))
|
||||
poly[3][0] += R * r[3] * np.cos(theta)
|
||||
poly[3][1] += R * r[3] * np.sin(theta)
|
||||
poly[2][0] -= R * r[2] * np.cos(theta)
|
||||
poly[2][1] -= R * r[2] * np.sin(theta)
|
||||
## p0, p3
|
||||
theta = np.arctan2((poly[3][0] - poly[0][0]),
|
||||
(poly[3][1] - poly[0][1]))
|
||||
poly[0][0] += R * r[0] * np.sin(theta)
|
||||
poly[0][1] += R * r[0] * np.cos(theta)
|
||||
poly[3][0] -= R * r[3] * np.sin(theta)
|
||||
poly[3][1] -= R * r[3] * np.cos(theta)
|
||||
## p1, p2
|
||||
theta = np.arctan2((poly[2][0] - poly[1][0]),
|
||||
(poly[2][1] - poly[1][1]))
|
||||
poly[1][0] += R * r[1] * np.sin(theta)
|
||||
poly[1][1] += R * r[1] * np.cos(theta)
|
||||
poly[2][0] -= R * r[2] * np.sin(theta)
|
||||
poly[2][1] -= R * r[2] * np.cos(theta)
|
||||
else:
|
||||
## p0, p3
|
||||
# print poly
|
||||
theta = np.arctan2((poly[3][0] - poly[0][0]),
|
||||
(poly[3][1] - poly[0][1]))
|
||||
poly[0][0] += R * r[0] * np.sin(theta)
|
||||
poly[0][1] += R * r[0] * np.cos(theta)
|
||||
poly[3][0] -= R * r[3] * np.sin(theta)
|
||||
poly[3][1] -= R * r[3] * np.cos(theta)
|
||||
## p1, p2
|
||||
theta = np.arctan2((poly[2][0] - poly[1][0]),
|
||||
(poly[2][1] - poly[1][1]))
|
||||
poly[1][0] += R * r[1] * np.sin(theta)
|
||||
poly[1][1] += R * r[1] * np.cos(theta)
|
||||
poly[2][0] -= R * r[2] * np.sin(theta)
|
||||
poly[2][1] -= R * r[2] * np.cos(theta)
|
||||
## p0, p1
|
||||
theta = np.arctan2((poly[1][1] - poly[0][1]),
|
||||
(poly[1][0] - poly[0][0]))
|
||||
poly[0][0] += R * r[0] * np.cos(theta)
|
||||
poly[0][1] += R * r[0] * np.sin(theta)
|
||||
poly[1][0] -= R * r[1] * np.cos(theta)
|
||||
poly[1][1] -= R * r[1] * np.sin(theta)
|
||||
## p2, p3
|
||||
theta = np.arctan2((poly[2][1] - poly[3][1]),
|
||||
(poly[2][0] - poly[3][0]))
|
||||
poly[3][0] += R * r[3] * np.cos(theta)
|
||||
poly[3][1] += R * r[3] * np.sin(theta)
|
||||
poly[2][0] -= R * r[2] * np.cos(theta)
|
||||
poly[2][1] -= R * r[2] * np.sin(theta)
|
||||
return poly
|
||||
|
||||
def generate_quad(self, im_size, polys, tags):
|
||||
"""
|
||||
Generate quadrangle.
|
||||
"""
|
||||
h, w = im_size
|
||||
poly_mask = np.zeros((h, w), dtype=np.uint8)
|
||||
score_map = np.zeros((h, w), dtype=np.uint8)
|
||||
# (x1, y1, ..., x4, y4, short_edge_norm)
|
||||
geo_map = np.zeros((h, w, 9), dtype=np.float32)
|
||||
# mask used during traning, to ignore some hard areas
|
||||
training_mask = np.ones((h, w), dtype=np.uint8)
|
||||
for poly_idx, poly_tag in enumerate(zip(polys, tags)):
|
||||
poly = poly_tag[0]
|
||||
tag = poly_tag[1]
|
||||
|
||||
r = [None, None, None, None]
|
||||
for i in range(4):
|
||||
dist1 = np.linalg.norm(poly[i] - poly[(i + 1) % 4])
|
||||
dist2 = np.linalg.norm(poly[i] - poly[(i - 1) % 4])
|
||||
r[i] = min(dist1, dist2)
|
||||
# score map
|
||||
shrinked_poly = self.shrink_poly(
|
||||
poly.copy(), r).astype(np.int32)[np.newaxis, :, :]
|
||||
cv2.fillPoly(score_map, shrinked_poly, 1)
|
||||
cv2.fillPoly(poly_mask, shrinked_poly, poly_idx + 1)
|
||||
# if the poly is too small, then ignore it during training
|
||||
poly_h = min(
|
||||
np.linalg.norm(poly[0] - poly[3]),
|
||||
np.linalg.norm(poly[1] - poly[2]))
|
||||
poly_w = min(
|
||||
np.linalg.norm(poly[0] - poly[1]),
|
||||
np.linalg.norm(poly[2] - poly[3]))
|
||||
if min(poly_h, poly_w) < self.min_text_size:
|
||||
cv2.fillPoly(training_mask,
|
||||
poly.astype(np.int32)[np.newaxis, :, :], 0)
|
||||
|
||||
if tag:
|
||||
cv2.fillPoly(training_mask,
|
||||
poly.astype(np.int32)[np.newaxis, :, :], 0)
|
||||
|
||||
xy_in_poly = np.argwhere(poly_mask == (poly_idx + 1))
|
||||
# geo map.
|
||||
y_in_poly = xy_in_poly[:, 0]
|
||||
x_in_poly = xy_in_poly[:, 1]
|
||||
poly[:, 0] = np.minimum(np.maximum(poly[:, 0], 0), w)
|
||||
poly[:, 1] = np.minimum(np.maximum(poly[:, 1], 0), h)
|
||||
for pno in range(4):
|
||||
geo_channel_beg = pno * 2
|
||||
geo_map[y_in_poly, x_in_poly, geo_channel_beg] =\
|
||||
x_in_poly - poly[pno, 0]
|
||||
geo_map[y_in_poly, x_in_poly, geo_channel_beg+1] =\
|
||||
y_in_poly - poly[pno, 1]
|
||||
geo_map[y_in_poly, x_in_poly, 8] = \
|
||||
1.0 / max(min(poly_h, poly_w), 1.0)
|
||||
return score_map, geo_map, training_mask
|
||||
|
||||
def crop_area(self,
|
||||
im,
|
||||
polys,
|
||||
tags,
|
||||
crop_background=False,
|
||||
max_tries=50):
|
||||
"""
|
||||
make random crop from the input image
|
||||
:param im:
|
||||
:param polys:
|
||||
:param tags:
|
||||
:param crop_background:
|
||||
:param max_tries:
|
||||
:return:
|
||||
"""
|
||||
h, w, _ = im.shape
|
||||
pad_h = h // 10
|
||||
pad_w = w // 10
|
||||
h_array = np.zeros((h + pad_h * 2), dtype=np.int32)
|
||||
w_array = np.zeros((w + pad_w * 2), dtype=np.int32)
|
||||
for poly in polys:
|
||||
poly = np.round(poly, decimals=0).astype(np.int32)
|
||||
minx = np.min(poly[:, 0])
|
||||
maxx = np.max(poly[:, 0])
|
||||
w_array[minx + pad_w:maxx + pad_w] = 1
|
||||
miny = np.min(poly[:, 1])
|
||||
maxy = np.max(poly[:, 1])
|
||||
h_array[miny + pad_h:maxy + pad_h] = 1
|
||||
# ensure the cropped area not across a text
|
||||
h_axis = np.where(h_array == 0)[0]
|
||||
w_axis = np.where(w_array == 0)[0]
|
||||
if len(h_axis) == 0 or len(w_axis) == 0:
|
||||
return im, polys, tags
|
||||
|
||||
for i in range(max_tries):
|
||||
xx = np.random.choice(w_axis, size=2)
|
||||
xmin = np.min(xx) - pad_w
|
||||
xmax = np.max(xx) - pad_w
|
||||
xmin = np.clip(xmin, 0, w - 1)
|
||||
xmax = np.clip(xmax, 0, w - 1)
|
||||
yy = np.random.choice(h_axis, size=2)
|
||||
ymin = np.min(yy) - pad_h
|
||||
ymax = np.max(yy) - pad_h
|
||||
ymin = np.clip(ymin, 0, h - 1)
|
||||
ymax = np.clip(ymax, 0, h - 1)
|
||||
if xmax - xmin < self.min_crop_side_ratio * w or \
|
||||
ymax - ymin < self.min_crop_side_ratio * h:
|
||||
# area too small
|
||||
continue
|
||||
if polys.shape[0] != 0:
|
||||
poly_axis_in_area = (polys[:, :, 0] >= xmin)\
|
||||
& (polys[:, :, 0] <= xmax)\
|
||||
& (polys[:, :, 1] >= ymin)\
|
||||
& (polys[:, :, 1] <= ymax)
|
||||
selected_polys = np.where(
|
||||
np.sum(poly_axis_in_area, axis=1) == 4)[0]
|
||||
else:
|
||||
selected_polys = []
|
||||
|
||||
if len(selected_polys) == 0:
|
||||
# no text in this area
|
||||
if crop_background:
|
||||
im = im[ymin:ymax + 1, xmin:xmax + 1, :]
|
||||
polys = []
|
||||
tags = []
|
||||
return im, polys, tags
|
||||
else:
|
||||
continue
|
||||
|
||||
im = im[ymin:ymax + 1, xmin:xmax + 1, :]
|
||||
polys = polys[selected_polys]
|
||||
tags = tags[selected_polys]
|
||||
polys[:, :, 0] -= xmin
|
||||
polys[:, :, 1] -= ymin
|
||||
return im, polys, tags
|
||||
return im, polys, tags
|
||||
|
||||
def crop_background_infor(self, im, text_polys, text_tags):
|
||||
im, text_polys, text_tags = self.crop_area(
|
||||
im, text_polys, text_tags, crop_background=True)
|
||||
|
||||
if len(text_polys) > 0:
|
||||
return None
|
||||
# pad and resize image
|
||||
input_size = self.input_size
|
||||
im, ratio = self.preprocess(im)
|
||||
score_map = np.zeros((input_size, input_size), dtype=np.float32)
|
||||
geo_map = np.zeros((input_size, input_size, 9), dtype=np.float32)
|
||||
training_mask = np.ones((input_size, input_size), dtype=np.float32)
|
||||
return im, score_map, geo_map, training_mask
|
||||
|
||||
def crop_foreground_infor(self, im, text_polys, text_tags):
|
||||
im, text_polys, text_tags = self.crop_area(
|
||||
im, text_polys, text_tags, crop_background=False)
|
||||
|
||||
if text_polys.shape[0] == 0:
|
||||
return None
|
||||
#continue for all ignore case
|
||||
if np.sum((text_tags * 1.0)) >= text_tags.size:
|
||||
return None
|
||||
# pad and resize image
|
||||
input_size = self.input_size
|
||||
im, ratio = self.preprocess(im)
|
||||
text_polys[:, :, 0] *= ratio
|
||||
text_polys[:, :, 1] *= ratio
|
||||
_, _, new_h, new_w = im.shape
|
||||
# print(im.shape)
|
||||
# self.draw_img_polys(im, text_polys)
|
||||
score_map, geo_map, training_mask = self.generate_quad(
|
||||
(new_h, new_w), text_polys, text_tags)
|
||||
return im, score_map, geo_map, training_mask
|
||||
|
||||
def __call__(self, data):
|
||||
im = data['image']
|
||||
text_polys = data['polys']
|
||||
text_tags = data['ignore_tags']
|
||||
if im is None:
|
||||
return None
|
||||
if text_polys.shape[0] == 0:
|
||||
return None
|
||||
|
||||
#add rotate cases
|
||||
if np.random.rand() < 0.5:
|
||||
im, text_polys = self.rotate_im_poly(im, text_polys)
|
||||
h, w, _ = im.shape
|
||||
text_polys, text_tags = self.check_and_validate_polys(text_polys,
|
||||
text_tags, h, w)
|
||||
if text_polys.shape[0] == 0:
|
||||
return None
|
||||
|
||||
# random scale this image
|
||||
rd_scale = np.random.choice(self.random_scale)
|
||||
im = cv2.resize(im, dsize=None, fx=rd_scale, fy=rd_scale)
|
||||
text_polys *= rd_scale
|
||||
if np.random.rand() < self.background_ratio:
|
||||
outs = self.crop_background_infor(im, text_polys, text_tags)
|
||||
else:
|
||||
outs = self.crop_foreground_infor(im, text_polys, text_tags)
|
||||
|
||||
if outs is None:
|
||||
return None
|
||||
im, score_map, geo_map, training_mask = outs
|
||||
score_map = score_map[np.newaxis, ::4, ::4].astype(np.float32)
|
||||
geo_map = np.swapaxes(geo_map, 1, 2)
|
||||
geo_map = np.swapaxes(geo_map, 1, 0)
|
||||
geo_map = geo_map[:, ::4, ::4].astype(np.float32)
|
||||
training_mask = training_mask[np.newaxis, ::4, ::4]
|
||||
training_mask = training_mask.astype(np.float32)
|
||||
|
||||
data['image'] = im[0]
|
||||
data['score_map'] = score_map
|
||||
data['geo_map'] = geo_map
|
||||
data['training_mask'] = training_mask
|
||||
# print(im.shape, score_map.shape, geo_map.shape, training_mask.shape)
|
||||
return data
|
|
@ -52,6 +52,7 @@ class DetLabelEncode(object):
|
|||
txt_tags.append(True)
|
||||
else:
|
||||
txt_tags.append(False)
|
||||
boxes = self.expand_points_num(boxes)
|
||||
boxes = np.array(boxes, dtype=np.float32)
|
||||
txt_tags = np.array(txt_tags, dtype=np.bool)
|
||||
|
||||
|
@ -70,6 +71,17 @@ class DetLabelEncode(object):
|
|||
rect[3] = pts[np.argmax(diff)]
|
||||
return rect
|
||||
|
||||
def expand_points_num(self, boxes):
|
||||
max_points_num = 0
|
||||
for box in boxes:
|
||||
if len(box) > max_points_num:
|
||||
max_points_num = len(box)
|
||||
ex_boxes = []
|
||||
for box in boxes:
|
||||
ex_box = box + [box[-1]] * (max_points_num - len(box))
|
||||
ex_boxes.append(ex_box)
|
||||
return ex_boxes
|
||||
|
||||
|
||||
class BaseRecLabelEncode(object):
|
||||
""" Convert between text-label and text-index """
|
||||
|
@ -79,7 +91,9 @@ class BaseRecLabelEncode(object):
|
|||
character_dict_path=None,
|
||||
character_type='ch',
|
||||
use_space_char=False):
|
||||
support_character_type = ['ch', 'en', 'en_sensitive']
|
||||
support_character_type = [
|
||||
'ch', 'en', 'en_sensitive', 'french', 'german', 'japan', 'korean'
|
||||
]
|
||||
assert character_type in support_character_type, "Only {} are supported now but get {}".format(
|
||||
support_character_type, self.character_str)
|
||||
|
||||
|
@ -87,7 +101,7 @@ class BaseRecLabelEncode(object):
|
|||
if character_type == "en":
|
||||
self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz"
|
||||
dict_character = list(self.character_str)
|
||||
elif character_type == "ch":
|
||||
elif character_type in ["ch", "french", "german", "japan", "korean"]:
|
||||
self.character_str = ""
|
||||
assert character_dict_path is not None, "character_dict_path should not be None when character_type is ch"
|
||||
with open(character_dict_path, "rb") as fin:
|
||||
|
|
|
@ -120,26 +120,37 @@ class DetResizeForTest(object):
|
|||
if 'limit_side_len' in kwargs:
|
||||
self.limit_side_len = kwargs['limit_side_len']
|
||||
self.limit_type = kwargs.get('limit_type', 'min')
|
||||
if 'resize_long' in kwargs:
|
||||
self.resize_type = 2
|
||||
self.resize_long = kwargs.get('resize_long', 960)
|
||||
else:
|
||||
self.limit_side_len = 736
|
||||
self.limit_type = 'min'
|
||||
|
||||
def __call__(self, data):
|
||||
img = data['image']
|
||||
src_h, src_w, _ = img.shape
|
||||
|
||||
if self.resize_type == 0:
|
||||
img, shape = self.resize_image_type0(img)
|
||||
# img, shape = self.resize_image_type0(img)
|
||||
img, [ratio_h, ratio_w] = self.resize_image_type0(img)
|
||||
elif self.resize_type == 2:
|
||||
img, [ratio_h, ratio_w] = self.resize_image_type2(img)
|
||||
else:
|
||||
img, shape = self.resize_image_type1(img)
|
||||
# img, shape = self.resize_image_type1(img)
|
||||
img, [ratio_h, ratio_w] = self.resize_image_type1(img)
|
||||
data['image'] = img
|
||||
data['shape'] = shape
|
||||
data['shape'] = np.array([src_h, src_w, ratio_h, ratio_w])
|
||||
return data
|
||||
|
||||
def resize_image_type1(self, img):
|
||||
resize_h, resize_w = self.image_shape
|
||||
ori_h, ori_w = img.shape[:2] # (h, w, c)
|
||||
ratio_h = float(resize_h) / ori_h
|
||||
ratio_w = float(resize_w) / ori_w
|
||||
img = cv2.resize(img, (int(resize_w), int(resize_h)))
|
||||
return img, np.array([ori_h, ori_w])
|
||||
# return img, np.array([ori_h, ori_w])
|
||||
return img, [ratio_h, ratio_w]
|
||||
|
||||
def resize_image_type0(self, img):
|
||||
"""
|
||||
|
@ -182,4 +193,31 @@ class DetResizeForTest(object):
|
|||
except:
|
||||
print(img.shape, resize_w, resize_h)
|
||||
sys.exit(0)
|
||||
return img, np.array([h, w])
|
||||
ratio_h = resize_h / float(h)
|
||||
ratio_w = resize_w / float(w)
|
||||
# return img, np.array([h, w])
|
||||
return img, [ratio_h, ratio_w]
|
||||
|
||||
def resize_image_type2(self, img):
|
||||
h, w, _ = img.shape
|
||||
|
||||
resize_w = w
|
||||
resize_h = h
|
||||
|
||||
# Fix the longer side
|
||||
if resize_h > resize_w:
|
||||
ratio = float(self.resize_long) / resize_h
|
||||
else:
|
||||
ratio = float(self.resize_long) / resize_w
|
||||
|
||||
resize_h = int(resize_h * ratio)
|
||||
resize_w = int(resize_w * ratio)
|
||||
|
||||
max_stride = 128
|
||||
resize_h = (resize_h + max_stride - 1) // max_stride * max_stride
|
||||
resize_w = (resize_w + max_stride - 1) // max_stride * max_stride
|
||||
img = cv2.resize(img, (int(resize_w), int(resize_h)))
|
||||
ratio_h = resize_h / float(h)
|
||||
ratio_w = resize_w / float(w)
|
||||
|
||||
return img, [ratio_h, ratio_w]
|
||||
|
|
|
@ -0,0 +1,689 @@
|
|||
#copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
#Licensed under the Apache License, Version 2.0 (the "License");
|
||||
#you may not use this file except in compliance with the License.
|
||||
#You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
#Unless required by applicable law or agreed to in writing, software
|
||||
#distributed under the License is distributed on an "AS IS" BASIS,
|
||||
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
#See the License for the specific language governing permissions and
|
||||
#limitations under the License.
|
||||
|
||||
import math
|
||||
import cv2
|
||||
import numpy as np
|
||||
import json
|
||||
import sys
|
||||
import os
|
||||
|
||||
__all__ = ['SASTProcessTrain']
|
||||
|
||||
|
||||
class SASTProcessTrain(object):
|
||||
def __init__(self,
|
||||
image_shape = [512, 512],
|
||||
min_crop_size = 24,
|
||||
min_crop_side_ratio = 0.3,
|
||||
min_text_size = 10,
|
||||
max_text_size = 512,
|
||||
**kwargs):
|
||||
self.input_size = image_shape[1]
|
||||
self.min_crop_size = min_crop_size
|
||||
self.min_crop_side_ratio = min_crop_side_ratio
|
||||
self.min_text_size = min_text_size
|
||||
self.max_text_size = max_text_size
|
||||
|
||||
def quad_area(self, poly):
|
||||
"""
|
||||
compute area of a polygon
|
||||
:param poly:
|
||||
:return:
|
||||
"""
|
||||
edge = [
|
||||
(poly[1][0] - poly[0][0]) * (poly[1][1] + poly[0][1]),
|
||||
(poly[2][0] - poly[1][0]) * (poly[2][1] + poly[1][1]),
|
||||
(poly[3][0] - poly[2][0]) * (poly[3][1] + poly[2][1]),
|
||||
(poly[0][0] - poly[3][0]) * (poly[0][1] + poly[3][1])
|
||||
]
|
||||
return np.sum(edge) / 2.
|
||||
|
||||
def gen_quad_from_poly(self, poly):
|
||||
"""
|
||||
Generate min area quad from poly.
|
||||
"""
|
||||
point_num = poly.shape[0]
|
||||
min_area_quad = np.zeros((4, 2), dtype=np.float32)
|
||||
if True:
|
||||
rect = cv2.minAreaRect(poly.astype(np.int32)) # (center (x,y), (width, height), angle of rotation)
|
||||
center_point = rect[0]
|
||||
box = np.array(cv2.boxPoints(rect))
|
||||
|
||||
first_point_idx = 0
|
||||
min_dist = 1e4
|
||||
for i in range(4):
|
||||
dist = np.linalg.norm(box[(i + 0) % 4] - poly[0]) + \
|
||||
np.linalg.norm(box[(i + 1) % 4] - poly[point_num // 2 - 1]) + \
|
||||
np.linalg.norm(box[(i + 2) % 4] - poly[point_num // 2]) + \
|
||||
np.linalg.norm(box[(i + 3) % 4] - poly[-1])
|
||||
if dist < min_dist:
|
||||
min_dist = dist
|
||||
first_point_idx = i
|
||||
for i in range(4):
|
||||
min_area_quad[i] = box[(first_point_idx + i) % 4]
|
||||
|
||||
return min_area_quad
|
||||
|
||||
def check_and_validate_polys(self, polys, tags, xxx_todo_changeme):
|
||||
"""
|
||||
check so that the text poly is in the same direction,
|
||||
and also filter some invalid polygons
|
||||
:param polys:
|
||||
:param tags:
|
||||
:return:
|
||||
"""
|
||||
(h, w) = xxx_todo_changeme
|
||||
if polys.shape[0] == 0:
|
||||
return polys, np.array([]), np.array([])
|
||||
polys[:, :, 0] = np.clip(polys[:, :, 0], 0, w - 1)
|
||||
polys[:, :, 1] = np.clip(polys[:, :, 1], 0, h - 1)
|
||||
|
||||
validated_polys = []
|
||||
validated_tags = []
|
||||
hv_tags = []
|
||||
for poly, tag in zip(polys, tags):
|
||||
quad = self.gen_quad_from_poly(poly)
|
||||
p_area = self.quad_area(quad)
|
||||
if abs(p_area) < 1:
|
||||
print('invalid poly')
|
||||
continue
|
||||
if p_area > 0:
|
||||
if tag == False:
|
||||
print('poly in wrong direction')
|
||||
tag = True # reversed cases should be ignore
|
||||
poly = poly[(0, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1), :]
|
||||
quad = quad[(0, 3, 2, 1), :]
|
||||
|
||||
len_w = np.linalg.norm(quad[0] - quad[1]) + np.linalg.norm(quad[3] - quad[2])
|
||||
len_h = np.linalg.norm(quad[0] - quad[3]) + np.linalg.norm(quad[1] - quad[2])
|
||||
hv_tag = 1
|
||||
|
||||
if len_w * 2.0 < len_h:
|
||||
hv_tag = 0
|
||||
|
||||
validated_polys.append(poly)
|
||||
validated_tags.append(tag)
|
||||
hv_tags.append(hv_tag)
|
||||
return np.array(validated_polys), np.array(validated_tags), np.array(hv_tags)
|
||||
|
||||
def crop_area(self, im, polys, tags, hv_tags, crop_background=False, max_tries=25):
|
||||
"""
|
||||
make random crop from the input image
|
||||
:param im:
|
||||
:param polys:
|
||||
:param tags:
|
||||
:param crop_background:
|
||||
:param max_tries: 50 -> 25
|
||||
:return:
|
||||
"""
|
||||
h, w, _ = im.shape
|
||||
pad_h = h // 10
|
||||
pad_w = w // 10
|
||||
h_array = np.zeros((h + pad_h * 2), dtype=np.int32)
|
||||
w_array = np.zeros((w + pad_w * 2), dtype=np.int32)
|
||||
for poly in polys:
|
||||
poly = np.round(poly, decimals=0).astype(np.int32)
|
||||
minx = np.min(poly[:, 0])
|
||||
maxx = np.max(poly[:, 0])
|
||||
w_array[minx + pad_w: maxx + pad_w] = 1
|
||||
miny = np.min(poly[:, 1])
|
||||
maxy = np.max(poly[:, 1])
|
||||
h_array[miny + pad_h: maxy + pad_h] = 1
|
||||
# ensure the cropped area not across a text
|
||||
h_axis = np.where(h_array == 0)[0]
|
||||
w_axis = np.where(w_array == 0)[0]
|
||||
if len(h_axis) == 0 or len(w_axis) == 0:
|
||||
return im, polys, tags, hv_tags
|
||||
for i in range(max_tries):
|
||||
xx = np.random.choice(w_axis, size=2)
|
||||
xmin = np.min(xx) - pad_w
|
||||
xmax = np.max(xx) - pad_w
|
||||
xmin = np.clip(xmin, 0, w - 1)
|
||||
xmax = np.clip(xmax, 0, w - 1)
|
||||
yy = np.random.choice(h_axis, size=2)
|
||||
ymin = np.min(yy) - pad_h
|
||||
ymax = np.max(yy) - pad_h
|
||||
ymin = np.clip(ymin, 0, h - 1)
|
||||
ymax = np.clip(ymax, 0, h - 1)
|
||||
# if xmax - xmin < ARGS.min_crop_side_ratio * w or \
|
||||
# ymax - ymin < ARGS.min_crop_side_ratio * h:
|
||||
if xmax - xmin < self.min_crop_size or \
|
||||
ymax - ymin < self.min_crop_size:
|
||||
# area too small
|
||||
continue
|
||||
if polys.shape[0] != 0:
|
||||
poly_axis_in_area = (polys[:, :, 0] >= xmin) & (polys[:, :, 0] <= xmax) \
|
||||
& (polys[:, :, 1] >= ymin) & (polys[:, :, 1] <= ymax)
|
||||
selected_polys = np.where(np.sum(poly_axis_in_area, axis=1) == 4)[0]
|
||||
else:
|
||||
selected_polys = []
|
||||
if len(selected_polys) == 0:
|
||||
# no text in this area
|
||||
if crop_background:
|
||||
return im[ymin : ymax + 1, xmin : xmax + 1, :], \
|
||||
polys[selected_polys], tags[selected_polys], hv_tags[selected_polys], txts
|
||||
else:
|
||||
continue
|
||||
im = im[ymin: ymax + 1, xmin: xmax + 1, :]
|
||||
polys = polys[selected_polys]
|
||||
tags = tags[selected_polys]
|
||||
hv_tags = hv_tags[selected_polys]
|
||||
polys[:, :, 0] -= xmin
|
||||
polys[:, :, 1] -= ymin
|
||||
return im, polys, tags, hv_tags
|
||||
|
||||
return im, polys, tags, hv_tags
|
||||
|
||||
def generate_direction_map(self, poly_quads, direction_map):
|
||||
"""
|
||||
"""
|
||||
width_list = []
|
||||
height_list = []
|
||||
for quad in poly_quads:
|
||||
quad_w = (np.linalg.norm(quad[0] - quad[1]) + np.linalg.norm(quad[2] - quad[3])) / 2.0
|
||||
quad_h = (np.linalg.norm(quad[0] - quad[3]) + np.linalg.norm(quad[2] - quad[1])) / 2.0
|
||||
width_list.append(quad_w)
|
||||
height_list.append(quad_h)
|
||||
norm_width = max(sum(width_list) / (len(width_list) + 1e-6), 1.0)
|
||||
average_height = max(sum(height_list) / (len(height_list) + 1e-6), 1.0)
|
||||
|
||||
for quad in poly_quads:
|
||||
direct_vector_full = ((quad[1] + quad[2]) - (quad[0] + quad[3])) / 2.0
|
||||
direct_vector = direct_vector_full / (np.linalg.norm(direct_vector_full) + 1e-6) * norm_width
|
||||
direction_label = tuple(map(float, [direct_vector[0], direct_vector[1], 1.0 / (average_height + 1e-6)]))
|
||||
cv2.fillPoly(direction_map, quad.round().astype(np.int32)[np.newaxis, :, :], direction_label)
|
||||
return direction_map
|
||||
|
||||
def calculate_average_height(self, poly_quads):
|
||||
"""
|
||||
"""
|
||||
height_list = []
|
||||
for quad in poly_quads:
|
||||
quad_h = (np.linalg.norm(quad[0] - quad[3]) + np.linalg.norm(quad[2] - quad[1])) / 2.0
|
||||
height_list.append(quad_h)
|
||||
average_height = max(sum(height_list) / len(height_list), 1.0)
|
||||
return average_height
|
||||
|
||||
def generate_tcl_label(self, hw, polys, tags, ds_ratio,
|
||||
tcl_ratio=0.3, shrink_ratio_of_width=0.15):
|
||||
"""
|
||||
Generate polygon.
|
||||
"""
|
||||
h, w = hw
|
||||
h, w = int(h * ds_ratio), int(w * ds_ratio)
|
||||
polys = polys * ds_ratio
|
||||
|
||||
score_map = np.zeros((h, w,), dtype=np.float32)
|
||||
tbo_map = np.zeros((h, w, 5), dtype=np.float32)
|
||||
training_mask = np.ones((h, w,), dtype=np.float32)
|
||||
direction_map = np.ones((h, w, 3)) * np.array([0, 0, 1]).reshape([1, 1, 3]).astype(np.float32)
|
||||
|
||||
for poly_idx, poly_tag in enumerate(zip(polys, tags)):
|
||||
poly = poly_tag[0]
|
||||
tag = poly_tag[1]
|
||||
|
||||
# generate min_area_quad
|
||||
min_area_quad, center_point = self.gen_min_area_quad_from_poly(poly)
|
||||
min_area_quad_h = 0.5 * (np.linalg.norm(min_area_quad[0] - min_area_quad[3]) +
|
||||
np.linalg.norm(min_area_quad[1] - min_area_quad[2]))
|
||||
min_area_quad_w = 0.5 * (np.linalg.norm(min_area_quad[0] - min_area_quad[1]) +
|
||||
np.linalg.norm(min_area_quad[2] - min_area_quad[3]))
|
||||
|
||||
if min(min_area_quad_h, min_area_quad_w) < self.min_text_size * ds_ratio \
|
||||
or min(min_area_quad_h, min_area_quad_w) > self.max_text_size * ds_ratio:
|
||||
continue
|
||||
|
||||
if tag:
|
||||
# continue
|
||||
cv2.fillPoly(training_mask, poly.astype(np.int32)[np.newaxis, :, :], 0.15)
|
||||
else:
|
||||
tcl_poly = self.poly2tcl(poly, tcl_ratio)
|
||||
tcl_quads = self.poly2quads(tcl_poly)
|
||||
poly_quads = self.poly2quads(poly)
|
||||
# stcl map
|
||||
stcl_quads, quad_index = self.shrink_poly_along_width(tcl_quads, shrink_ratio_of_width=shrink_ratio_of_width,
|
||||
expand_height_ratio=1.0 / tcl_ratio)
|
||||
# generate tcl map
|
||||
cv2.fillPoly(score_map, np.round(stcl_quads).astype(np.int32), 1.0)
|
||||
|
||||
# generate tbo map
|
||||
for idx, quad in enumerate(stcl_quads):
|
||||
quad_mask = np.zeros((h, w), dtype=np.float32)
|
||||
quad_mask = cv2.fillPoly(quad_mask, np.round(quad[np.newaxis, :, :]).astype(np.int32), 1.0)
|
||||
tbo_map = self.gen_quad_tbo(poly_quads[quad_index[idx]], quad_mask, tbo_map)
|
||||
return score_map, tbo_map, training_mask
|
||||
|
||||
def generate_tvo_and_tco(self, hw, polys, tags, tcl_ratio=0.3, ds_ratio=0.25):
|
||||
"""
|
||||
Generate tcl map, tvo map and tbo map.
|
||||
"""
|
||||
h, w = hw
|
||||
h, w = int(h * ds_ratio), int(w * ds_ratio)
|
||||
polys = polys * ds_ratio
|
||||
poly_mask = np.zeros((h, w), dtype=np.float32)
|
||||
|
||||
tvo_map = np.ones((9, h, w), dtype=np.float32)
|
||||
tvo_map[0:-1:2] = np.tile(np.arange(0, w), (h, 1))
|
||||
tvo_map[1:-1:2] = np.tile(np.arange(0, w), (h, 1)).T
|
||||
poly_tv_xy_map = np.zeros((8, h, w), dtype=np.float32)
|
||||
|
||||
# tco map
|
||||
tco_map = np.ones((3, h, w), dtype=np.float32)
|
||||
tco_map[0] = np.tile(np.arange(0, w), (h, 1))
|
||||
tco_map[1] = np.tile(np.arange(0, w), (h, 1)).T
|
||||
poly_tc_xy_map = np.zeros((2, h, w), dtype=np.float32)
|
||||
|
||||
poly_short_edge_map = np.ones((h, w), dtype=np.float32)
|
||||
|
||||
for poly, poly_tag in zip(polys, tags):
|
||||
|
||||
if poly_tag == True:
|
||||
continue
|
||||
|
||||
# adjust point order for vertical poly
|
||||
poly = self.adjust_point(poly)
|
||||
|
||||
# generate min_area_quad
|
||||
min_area_quad, center_point = self.gen_min_area_quad_from_poly(poly)
|
||||
min_area_quad_h = 0.5 * (np.linalg.norm(min_area_quad[0] - min_area_quad[3]) +
|
||||
np.linalg.norm(min_area_quad[1] - min_area_quad[2]))
|
||||
min_area_quad_w = 0.5 * (np.linalg.norm(min_area_quad[0] - min_area_quad[1]) +
|
||||
np.linalg.norm(min_area_quad[2] - min_area_quad[3]))
|
||||
|
||||
# generate tcl map and text, 128 * 128
|
||||
tcl_poly = self.poly2tcl(poly, tcl_ratio)
|
||||
|
||||
# generate poly_tv_xy_map
|
||||
for idx in range(4):
|
||||
cv2.fillPoly(poly_tv_xy_map[2 * idx],
|
||||
np.round(tcl_poly[np.newaxis, :, :]).astype(np.int32),
|
||||
float(min(max(min_area_quad[idx, 0], 0), w)))
|
||||
cv2.fillPoly(poly_tv_xy_map[2 * idx + 1],
|
||||
np.round(tcl_poly[np.newaxis, :, :]).astype(np.int32),
|
||||
float(min(max(min_area_quad[idx, 1], 0), h)))
|
||||
|
||||
# generate poly_tc_xy_map
|
||||
for idx in range(2):
|
||||
cv2.fillPoly(poly_tc_xy_map[idx],
|
||||
np.round(tcl_poly[np.newaxis, :, :]).astype(np.int32), float(center_point[idx]))
|
||||
|
||||
# generate poly_short_edge_map
|
||||
cv2.fillPoly(poly_short_edge_map,
|
||||
np.round(tcl_poly[np.newaxis, :, :]).astype(np.int32),
|
||||
float(max(min(min_area_quad_h, min_area_quad_w), 1.0)))
|
||||
|
||||
# generate poly_mask and training_mask
|
||||
cv2.fillPoly(poly_mask, np.round(tcl_poly[np.newaxis, :, :]).astype(np.int32), 1)
|
||||
|
||||
tvo_map *= poly_mask
|
||||
tvo_map[:8] -= poly_tv_xy_map
|
||||
tvo_map[-1] /= poly_short_edge_map
|
||||
tvo_map = tvo_map.transpose((1, 2, 0))
|
||||
|
||||
tco_map *= poly_mask
|
||||
tco_map[:2] -= poly_tc_xy_map
|
||||
tco_map[-1] /= poly_short_edge_map
|
||||
tco_map = tco_map.transpose((1, 2, 0))
|
||||
|
||||
return tvo_map, tco_map
|
||||
|
||||
def adjust_point(self, poly):
|
||||
"""
|
||||
adjust point order.
|
||||
"""
|
||||
point_num = poly.shape[0]
|
||||
if point_num == 4:
|
||||
len_1 = np.linalg.norm(poly[0] - poly[1])
|
||||
len_2 = np.linalg.norm(poly[1] - poly[2])
|
||||
len_3 = np.linalg.norm(poly[2] - poly[3])
|
||||
len_4 = np.linalg.norm(poly[3] - poly[0])
|
||||
|
||||
if (len_1 + len_3) * 1.5 < (len_2 + len_4):
|
||||
poly = poly[[1, 2, 3, 0], :]
|
||||
|
||||
elif point_num > 4:
|
||||
vector_1 = poly[0] - poly[1]
|
||||
vector_2 = poly[1] - poly[2]
|
||||
cos_theta = np.dot(vector_1, vector_2) / (np.linalg.norm(vector_1) * np.linalg.norm(vector_2) + 1e-6)
|
||||
theta = np.arccos(np.round(cos_theta, decimals=4))
|
||||
|
||||
if abs(theta) > (70 / 180 * math.pi):
|
||||
index = list(range(1, point_num)) + [0]
|
||||
poly = poly[np.array(index), :]
|
||||
return poly
|
||||
|
||||
def gen_min_area_quad_from_poly(self, poly):
|
||||
"""
|
||||
Generate min area quad from poly.
|
||||
"""
|
||||
point_num = poly.shape[0]
|
||||
min_area_quad = np.zeros((4, 2), dtype=np.float32)
|
||||
if point_num == 4:
|
||||
min_area_quad = poly
|
||||
center_point = np.sum(poly, axis=0) / 4
|
||||
else:
|
||||
rect = cv2.minAreaRect(poly.astype(np.int32)) # (center (x,y), (width, height), angle of rotation)
|
||||
center_point = rect[0]
|
||||
box = np.array(cv2.boxPoints(rect))
|
||||
|
||||
first_point_idx = 0
|
||||
min_dist = 1e4
|
||||
for i in range(4):
|
||||
dist = np.linalg.norm(box[(i + 0) % 4] - poly[0]) + \
|
||||
np.linalg.norm(box[(i + 1) % 4] - poly[point_num // 2 - 1]) + \
|
||||
np.linalg.norm(box[(i + 2) % 4] - poly[point_num // 2]) + \
|
||||
np.linalg.norm(box[(i + 3) % 4] - poly[-1])
|
||||
if dist < min_dist:
|
||||
min_dist = dist
|
||||
first_point_idx = i
|
||||
|
||||
for i in range(4):
|
||||
min_area_quad[i] = box[(first_point_idx + i) % 4]
|
||||
|
||||
return min_area_quad, center_point
|
||||
|
||||
def shrink_quad_along_width(self, quad, begin_width_ratio=0., end_width_ratio=1.):
|
||||
"""
|
||||
Generate shrink_quad_along_width.
|
||||
"""
|
||||
ratio_pair = np.array([[begin_width_ratio], [end_width_ratio]], dtype=np.float32)
|
||||
p0_1 = quad[0] + (quad[1] - quad[0]) * ratio_pair
|
||||
p3_2 = quad[3] + (quad[2] - quad[3]) * ratio_pair
|
||||
return np.array([p0_1[0], p0_1[1], p3_2[1], p3_2[0]])
|
||||
|
||||
def shrink_poly_along_width(self, quads, shrink_ratio_of_width, expand_height_ratio=1.0):
|
||||
"""
|
||||
shrink poly with given length.
|
||||
"""
|
||||
upper_edge_list = []
|
||||
|
||||
def get_cut_info(edge_len_list, cut_len):
|
||||
for idx, edge_len in enumerate(edge_len_list):
|
||||
cut_len -= edge_len
|
||||
if cut_len <= 0.000001:
|
||||
ratio = (cut_len + edge_len_list[idx]) / edge_len_list[idx]
|
||||
return idx, ratio
|
||||
|
||||
for quad in quads:
|
||||
upper_edge_len = np.linalg.norm(quad[0] - quad[1])
|
||||
upper_edge_list.append(upper_edge_len)
|
||||
|
||||
# length of left edge and right edge.
|
||||
left_length = np.linalg.norm(quads[0][0] - quads[0][3]) * expand_height_ratio
|
||||
right_length = np.linalg.norm(quads[-1][1] - quads[-1][2]) * expand_height_ratio
|
||||
|
||||
shrink_length = min(left_length, right_length, sum(upper_edge_list)) * shrink_ratio_of_width
|
||||
# shrinking length
|
||||
upper_len_left = shrink_length
|
||||
upper_len_right = sum(upper_edge_list) - shrink_length
|
||||
|
||||
left_idx, left_ratio = get_cut_info(upper_edge_list, upper_len_left)
|
||||
left_quad = self.shrink_quad_along_width(quads[left_idx], begin_width_ratio=left_ratio, end_width_ratio=1)
|
||||
right_idx, right_ratio = get_cut_info(upper_edge_list, upper_len_right)
|
||||
right_quad = self.shrink_quad_along_width(quads[right_idx], begin_width_ratio=0, end_width_ratio=right_ratio)
|
||||
|
||||
out_quad_list = []
|
||||
if left_idx == right_idx:
|
||||
out_quad_list.append([left_quad[0], right_quad[1], right_quad[2], left_quad[3]])
|
||||
else:
|
||||
out_quad_list.append(left_quad)
|
||||
for idx in range(left_idx + 1, right_idx):
|
||||
out_quad_list.append(quads[idx])
|
||||
out_quad_list.append(right_quad)
|
||||
|
||||
return np.array(out_quad_list), list(range(left_idx, right_idx + 1))
|
||||
|
||||
def vector_angle(self, A, B):
|
||||
"""
|
||||
Calculate the angle between vector AB and x-axis positive direction.
|
||||
"""
|
||||
AB = np.array([B[1] - A[1], B[0] - A[0]])
|
||||
return np.arctan2(*AB)
|
||||
|
||||
def theta_line_cross_point(self, theta, point):
|
||||
"""
|
||||
Calculate the line through given point and angle in ax + by + c =0 form.
|
||||
"""
|
||||
x, y = point
|
||||
cos = np.cos(theta)
|
||||
sin = np.sin(theta)
|
||||
return [sin, -cos, cos * y - sin * x]
|
||||
|
||||
def line_cross_two_point(self, A, B):
|
||||
"""
|
||||
Calculate the line through given point A and B in ax + by + c =0 form.
|
||||
"""
|
||||
angle = self.vector_angle(A, B)
|
||||
return self.theta_line_cross_point(angle, A)
|
||||
|
||||
def average_angle(self, poly):
|
||||
"""
|
||||
Calculate the average angle between left and right edge in given poly.
|
||||
"""
|
||||
p0, p1, p2, p3 = poly
|
||||
angle30 = self.vector_angle(p3, p0)
|
||||
angle21 = self.vector_angle(p2, p1)
|
||||
return (angle30 + angle21) / 2
|
||||
|
||||
def line_cross_point(self, line1, line2):
|
||||
"""
|
||||
line1 and line2 in 0=ax+by+c form, compute the cross point of line1 and line2
|
||||
"""
|
||||
a1, b1, c1 = line1
|
||||
a2, b2, c2 = line2
|
||||
d = a1 * b2 - a2 * b1
|
||||
|
||||
if d == 0:
|
||||
#print("line1", line1)
|
||||
#print("line2", line2)
|
||||
print('Cross point does not exist')
|
||||
return np.array([0, 0], dtype=np.float32)
|
||||
else:
|
||||
x = (b1 * c2 - b2 * c1) / d
|
||||
y = (a2 * c1 - a1 * c2) / d
|
||||
|
||||
return np.array([x, y], dtype=np.float32)
|
||||
|
||||
def quad2tcl(self, poly, ratio):
|
||||
"""
|
||||
Generate center line by poly clock-wise point. (4, 2)
|
||||
"""
|
||||
ratio_pair = np.array([[0.5 - ratio / 2], [0.5 + ratio / 2]], dtype=np.float32)
|
||||
p0_3 = poly[0] + (poly[3] - poly[0]) * ratio_pair
|
||||
p1_2 = poly[1] + (poly[2] - poly[1]) * ratio_pair
|
||||
return np.array([p0_3[0], p1_2[0], p1_2[1], p0_3[1]])
|
||||
|
||||
def poly2tcl(self, poly, ratio):
|
||||
"""
|
||||
Generate center line by poly clock-wise point.
|
||||
"""
|
||||
ratio_pair = np.array([[0.5 - ratio / 2], [0.5 + ratio / 2]], dtype=np.float32)
|
||||
tcl_poly = np.zeros_like(poly)
|
||||
point_num = poly.shape[0]
|
||||
|
||||
for idx in range(point_num // 2):
|
||||
point_pair = poly[idx] + (poly[point_num - 1 - idx] - poly[idx]) * ratio_pair
|
||||
tcl_poly[idx] = point_pair[0]
|
||||
tcl_poly[point_num - 1 - idx] = point_pair[1]
|
||||
return tcl_poly
|
||||
|
||||
def gen_quad_tbo(self, quad, tcl_mask, tbo_map):
|
||||
"""
|
||||
Generate tbo_map for give quad.
|
||||
"""
|
||||
# upper and lower line function: ax + by + c = 0;
|
||||
up_line = self.line_cross_two_point(quad[0], quad[1])
|
||||
lower_line = self.line_cross_two_point(quad[3], quad[2])
|
||||
|
||||
quad_h = 0.5 * (np.linalg.norm(quad[0] - quad[3]) + np.linalg.norm(quad[1] - quad[2]))
|
||||
quad_w = 0.5 * (np.linalg.norm(quad[0] - quad[1]) + np.linalg.norm(quad[2] - quad[3]))
|
||||
|
||||
# average angle of left and right line.
|
||||
angle = self.average_angle(quad)
|
||||
|
||||
xy_in_poly = np.argwhere(tcl_mask == 1)
|
||||
for y, x in xy_in_poly:
|
||||
point = (x, y)
|
||||
line = self.theta_line_cross_point(angle, point)
|
||||
cross_point_upper = self.line_cross_point(up_line, line)
|
||||
cross_point_lower = self.line_cross_point(lower_line, line)
|
||||
##FIX, offset reverse
|
||||
upper_offset_x, upper_offset_y = cross_point_upper - point
|
||||
lower_offset_x, lower_offset_y = cross_point_lower - point
|
||||
tbo_map[y, x, 0] = upper_offset_y
|
||||
tbo_map[y, x, 1] = upper_offset_x
|
||||
tbo_map[y, x, 2] = lower_offset_y
|
||||
tbo_map[y, x, 3] = lower_offset_x
|
||||
tbo_map[y, x, 4] = 1.0 / max(min(quad_h, quad_w), 1.0) * 2
|
||||
return tbo_map
|
||||
|
||||
def poly2quads(self, poly):
|
||||
"""
|
||||
Split poly into quads.
|
||||
"""
|
||||
quad_list = []
|
||||
point_num = poly.shape[0]
|
||||
|
||||
# point pair
|
||||
point_pair_list = []
|
||||
for idx in range(point_num // 2):
|
||||
point_pair = [poly[idx], poly[point_num - 1 - idx]]
|
||||
point_pair_list.append(point_pair)
|
||||
|
||||
quad_num = point_num // 2 - 1
|
||||
for idx in range(quad_num):
|
||||
# reshape and adjust to clock-wise
|
||||
quad_list.append((np.array(point_pair_list)[[idx, idx + 1]]).reshape(4, 2)[[0, 2, 3, 1]])
|
||||
|
||||
return np.array(quad_list)
|
||||
|
||||
def __call__(self, data):
|
||||
im = data['image']
|
||||
text_polys = data['polys']
|
||||
text_tags = data['ignore_tags']
|
||||
if im is None:
|
||||
return None
|
||||
if text_polys.shape[0] == 0:
|
||||
return None
|
||||
|
||||
h, w, _ = im.shape
|
||||
text_polys, text_tags, hv_tags = self.check_and_validate_polys(text_polys, text_tags, (h, w))
|
||||
|
||||
if text_polys.shape[0] == 0:
|
||||
return None
|
||||
|
||||
#set aspect ratio and keep area fix
|
||||
asp_scales = np.arange(1.0, 1.55, 0.1)
|
||||
asp_scale = np.random.choice(asp_scales)
|
||||
|
||||
if np.random.rand() < 0.5:
|
||||
asp_scale = 1.0 / asp_scale
|
||||
asp_scale = math.sqrt(asp_scale)
|
||||
|
||||
asp_wx = asp_scale
|
||||
asp_hy = 1.0 / asp_scale
|
||||
im = cv2.resize(im, dsize=None, fx=asp_wx, fy=asp_hy)
|
||||
text_polys[:, :, 0] *= asp_wx
|
||||
text_polys[:, :, 1] *= asp_hy
|
||||
|
||||
h, w, _ = im.shape
|
||||
if max(h, w) > 2048:
|
||||
rd_scale = 2048.0 / max(h, w)
|
||||
im = cv2.resize(im, dsize=None, fx=rd_scale, fy=rd_scale)
|
||||
text_polys *= rd_scale
|
||||
h, w, _ = im.shape
|
||||
if min(h, w) < 16:
|
||||
return None
|
||||
|
||||
#no background
|
||||
im, text_polys, text_tags, hv_tags = self.crop_area(im, \
|
||||
text_polys, text_tags, hv_tags, crop_background=False)
|
||||
|
||||
if text_polys.shape[0] == 0:
|
||||
return None
|
||||
#continue for all ignore case
|
||||
if np.sum((text_tags * 1.0)) >= text_tags.size:
|
||||
return None
|
||||
new_h, new_w, _ = im.shape
|
||||
if (new_h is None) or (new_w is None):
|
||||
return None
|
||||
#resize image
|
||||
std_ratio = float(self.input_size) / max(new_w, new_h)
|
||||
rand_scales = np.array([0.25, 0.375, 0.5, 0.625, 0.75, 0.875, 1.0, 1.0, 1.0, 1.0, 1.0])
|
||||
rz_scale = std_ratio * np.random.choice(rand_scales)
|
||||
im = cv2.resize(im, dsize=None, fx=rz_scale, fy=rz_scale)
|
||||
text_polys[:, :, 0] *= rz_scale
|
||||
text_polys[:, :, 1] *= rz_scale
|
||||
|
||||
#add gaussian blur
|
||||
if np.random.rand() < 0.1 * 0.5:
|
||||
ks = np.random.permutation(5)[0] + 1
|
||||
ks = int(ks/2)*2 + 1
|
||||
im = cv2.GaussianBlur(im, ksize=(ks, ks), sigmaX=0, sigmaY=0)
|
||||
#add brighter
|
||||
if np.random.rand() < 0.1 * 0.5:
|
||||
im = im * (1.0 + np.random.rand() * 0.5)
|
||||
im = np.clip(im, 0.0, 255.0)
|
||||
#add darker
|
||||
if np.random.rand() < 0.1 * 0.5:
|
||||
im = im * (1.0 - np.random.rand() * 0.5)
|
||||
im = np.clip(im, 0.0, 255.0)
|
||||
|
||||
# Padding the im to [input_size, input_size]
|
||||
new_h, new_w, _ = im.shape
|
||||
if min(new_w, new_h) < self.input_size * 0.5:
|
||||
return None
|
||||
|
||||
im_padded = np.ones((self.input_size, self.input_size, 3), dtype=np.float32)
|
||||
im_padded[:, :, 2] = 0.485 * 255
|
||||
im_padded[:, :, 1] = 0.456 * 255
|
||||
im_padded[:, :, 0] = 0.406 * 255
|
||||
|
||||
# Random the start position
|
||||
del_h = self.input_size - new_h
|
||||
del_w = self.input_size - new_w
|
||||
sh, sw = 0, 0
|
||||
if del_h > 1:
|
||||
sh = int(np.random.rand() * del_h)
|
||||
if del_w > 1:
|
||||
sw = int(np.random.rand() * del_w)
|
||||
|
||||
# Padding
|
||||
im_padded[sh: sh + new_h, sw: sw + new_w, :] = im.copy()
|
||||
text_polys[:, :, 0] += sw
|
||||
text_polys[:, :, 1] += sh
|
||||
|
||||
score_map, border_map, training_mask = self.generate_tcl_label((self.input_size, self.input_size),
|
||||
text_polys, text_tags, 0.25)
|
||||
|
||||
# SAST head
|
||||
tvo_map, tco_map = self.generate_tvo_and_tco((self.input_size, self.input_size), text_polys, text_tags, tcl_ratio=0.3, ds_ratio=0.25)
|
||||
# print("test--------tvo_map shape:", tvo_map.shape)
|
||||
|
||||
im_padded[:, :, 2] -= 0.485 * 255
|
||||
im_padded[:, :, 1] -= 0.456 * 255
|
||||
im_padded[:, :, 0] -= 0.406 * 255
|
||||
im_padded[:, :, 2] /= (255.0 * 0.229)
|
||||
im_padded[:, :, 1] /= (255.0 * 0.224)
|
||||
im_padded[:, :, 0] /= (255.0 * 0.225)
|
||||
im_padded = im_padded.transpose((2, 0, 1))
|
||||
|
||||
data['image'] = im_padded[::-1, :, :]
|
||||
data['score_map'] = score_map[np.newaxis, :, :]
|
||||
data['border_map'] = border_map.transpose((2, 0, 1))
|
||||
data['training_mask'] = training_mask[np.newaxis, :, :]
|
||||
data['tvo_map'] = tvo_map.transpose((2, 0, 1))
|
||||
data['tco_map'] = tco_map.transpose((2, 0, 1))
|
||||
return data
|
|
@ -18,6 +18,8 @@ import copy
|
|||
def build_loss(config):
|
||||
# det loss
|
||||
from .det_db_loss import DBLoss
|
||||
from .det_east_loss import EASTLoss
|
||||
from .det_sast_loss import SASTLoss
|
||||
|
||||
# rec loss
|
||||
from .rec_ctc_loss import CTCLoss
|
||||
|
@ -25,7 +27,7 @@ def build_loss(config):
|
|||
# cls loss
|
||||
from .cls_loss import ClsLoss
|
||||
|
||||
support_dict = ['DBLoss', 'CTCLoss', 'ClsLoss']
|
||||
support_dict = ['DBLoss', 'EASTLoss', 'SASTLoss', 'CTCLoss', 'ClsLoss']
|
||||
|
||||
config = copy.deepcopy(config)
|
||||
module_name = config.pop('name')
|
||||
|
|
|
@ -0,0 +1,63 @@
|
|||
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import paddle
|
||||
from paddle import nn
|
||||
from .det_basic_loss import DiceLoss
|
||||
|
||||
|
||||
class EASTLoss(nn.Layer):
|
||||
"""
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
eps=1e-6,
|
||||
**kwargs):
|
||||
super(EASTLoss, self).__init__()
|
||||
self.dice_loss = DiceLoss(eps=eps)
|
||||
|
||||
def forward(self, predicts, labels):
|
||||
l_score, l_geo, l_mask = labels[1:]
|
||||
f_score = predicts['f_score']
|
||||
f_geo = predicts['f_geo']
|
||||
|
||||
dice_loss = self.dice_loss(f_score, l_score, l_mask)
|
||||
|
||||
#smoooth_l1_loss
|
||||
channels = 8
|
||||
l_geo_split = paddle.split(
|
||||
l_geo, num_or_sections=channels + 1, axis=1)
|
||||
f_geo_split = paddle.split(f_geo, num_or_sections=channels, axis=1)
|
||||
smooth_l1 = 0
|
||||
for i in range(0, channels):
|
||||
geo_diff = l_geo_split[i] - f_geo_split[i]
|
||||
abs_geo_diff = paddle.abs(geo_diff)
|
||||
smooth_l1_sign = paddle.less_than(abs_geo_diff, l_score)
|
||||
smooth_l1_sign = paddle.cast(smooth_l1_sign, dtype='float32')
|
||||
in_loss = abs_geo_diff * abs_geo_diff * smooth_l1_sign + \
|
||||
(abs_geo_diff - 0.5) * (1.0 - smooth_l1_sign)
|
||||
out_loss = l_geo_split[-1] / channels * in_loss * l_score
|
||||
smooth_l1 += out_loss
|
||||
smooth_l1_loss = paddle.mean(smooth_l1 * l_score)
|
||||
|
||||
dice_loss = dice_loss * 0.01
|
||||
total_loss = dice_loss + smooth_l1_loss
|
||||
losses = {"loss":total_loss, \
|
||||
"dice_loss":dice_loss,\
|
||||
"smooth_l1_loss":smooth_l1_loss}
|
||||
return losses
|
|
@ -0,0 +1,121 @@
|
|||
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import paddle
|
||||
from paddle import nn
|
||||
from .det_basic_loss import DiceLoss
|
||||
import paddle.fluid as fluid
|
||||
import numpy as np
|
||||
|
||||
|
||||
class SASTLoss(nn.Layer):
|
||||
"""
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
eps=1e-6,
|
||||
**kwargs):
|
||||
super(SASTLoss, self).__init__()
|
||||
self.dice_loss = DiceLoss(eps=eps)
|
||||
|
||||
def forward(self, predicts, labels):
|
||||
"""
|
||||
tcl_pos: N x 128 x 3
|
||||
tcl_mask: N x 128 x 1
|
||||
tcl_label: N x X list or LoDTensor
|
||||
"""
|
||||
|
||||
f_score = predicts['f_score']
|
||||
f_border = predicts['f_border']
|
||||
f_tvo = predicts['f_tvo']
|
||||
f_tco = predicts['f_tco']
|
||||
|
||||
l_score, l_border, l_mask, l_tvo, l_tco = labels[1:]
|
||||
|
||||
#score_loss
|
||||
intersection = paddle.sum(f_score * l_score * l_mask)
|
||||
union = paddle.sum(f_score * l_mask) + paddle.sum(l_score * l_mask)
|
||||
score_loss = 1.0 - 2 * intersection / (union + 1e-5)
|
||||
|
||||
#border loss
|
||||
l_border_split, l_border_norm = paddle.split(l_border, num_or_sections=[4, 1], axis=1)
|
||||
f_border_split = f_border
|
||||
border_ex_shape = l_border_norm.shape * np.array([1, 4, 1, 1])
|
||||
l_border_norm_split = paddle.expand(x=l_border_norm, shape=border_ex_shape)
|
||||
l_border_score = paddle.expand(x=l_score, shape=border_ex_shape)
|
||||
l_border_mask = paddle.expand(x=l_mask, shape=border_ex_shape)
|
||||
|
||||
border_diff = l_border_split - f_border_split
|
||||
abs_border_diff = paddle.abs(border_diff)
|
||||
border_sign = abs_border_diff < 1.0
|
||||
border_sign = paddle.cast(border_sign, dtype='float32')
|
||||
border_sign.stop_gradient = True
|
||||
border_in_loss = 0.5 * abs_border_diff * abs_border_diff * border_sign + \
|
||||
(abs_border_diff - 0.5) * (1.0 - border_sign)
|
||||
border_out_loss = l_border_norm_split * border_in_loss
|
||||
border_loss = paddle.sum(border_out_loss * l_border_score * l_border_mask) / \
|
||||
(paddle.sum(l_border_score * l_border_mask) + 1e-5)
|
||||
|
||||
#tvo_loss
|
||||
l_tvo_split, l_tvo_norm = paddle.split(l_tvo, num_or_sections=[8, 1], axis=1)
|
||||
f_tvo_split = f_tvo
|
||||
tvo_ex_shape = l_tvo_norm.shape * np.array([1, 8, 1, 1])
|
||||
l_tvo_norm_split = paddle.expand(x=l_tvo_norm, shape=tvo_ex_shape)
|
||||
l_tvo_score = paddle.expand(x=l_score, shape=tvo_ex_shape)
|
||||
l_tvo_mask = paddle.expand(x=l_mask, shape=tvo_ex_shape)
|
||||
#
|
||||
tvo_geo_diff = l_tvo_split - f_tvo_split
|
||||
abs_tvo_geo_diff = paddle.abs(tvo_geo_diff)
|
||||
tvo_sign = abs_tvo_geo_diff < 1.0
|
||||
tvo_sign = paddle.cast(tvo_sign, dtype='float32')
|
||||
tvo_sign.stop_gradient = True
|
||||
tvo_in_loss = 0.5 * abs_tvo_geo_diff * abs_tvo_geo_diff * tvo_sign + \
|
||||
(abs_tvo_geo_diff - 0.5) * (1.0 - tvo_sign)
|
||||
tvo_out_loss = l_tvo_norm_split * tvo_in_loss
|
||||
tvo_loss = paddle.sum(tvo_out_loss * l_tvo_score * l_tvo_mask) / \
|
||||
(paddle.sum(l_tvo_score * l_tvo_mask) + 1e-5)
|
||||
|
||||
#tco_loss
|
||||
l_tco_split, l_tco_norm = paddle.split(l_tco, num_or_sections=[2, 1], axis=1)
|
||||
f_tco_split = f_tco
|
||||
tco_ex_shape = l_tco_norm.shape * np.array([1, 2, 1, 1])
|
||||
l_tco_norm_split = paddle.expand(x=l_tco_norm, shape=tco_ex_shape)
|
||||
l_tco_score = paddle.expand(x=l_score, shape=tco_ex_shape)
|
||||
l_tco_mask = paddle.expand(x=l_mask, shape=tco_ex_shape)
|
||||
|
||||
tco_geo_diff = l_tco_split - f_tco_split
|
||||
abs_tco_geo_diff = paddle.abs(tco_geo_diff)
|
||||
tco_sign = abs_tco_geo_diff < 1.0
|
||||
tco_sign = paddle.cast(tco_sign, dtype='float32')
|
||||
tco_sign.stop_gradient = True
|
||||
tco_in_loss = 0.5 * abs_tco_geo_diff * abs_tco_geo_diff * tco_sign + \
|
||||
(abs_tco_geo_diff - 0.5) * (1.0 - tco_sign)
|
||||
tco_out_loss = l_tco_norm_split * tco_in_loss
|
||||
tco_loss = paddle.sum(tco_out_loss * l_tco_score * l_tco_mask) / \
|
||||
(paddle.sum(l_tco_score * l_tco_mask) + 1e-5)
|
||||
|
||||
|
||||
# total loss
|
||||
tvo_lw, tco_lw = 1.5, 1.5
|
||||
score_lw, border_lw = 1.0, 1.0
|
||||
total_loss = score_loss * score_lw + border_loss * border_lw + \
|
||||
tvo_loss * tvo_lw + tco_loss * tco_lw
|
||||
|
||||
losses = {'loss':total_loss, "score_loss":score_loss,\
|
||||
"border_loss":border_loss, 'tvo_loss':tvo_loss, 'tco_loss':tco_loss}
|
||||
return losses
|
|
@ -19,6 +19,7 @@ def build_backbone(config, model_type):
|
|||
if model_type == 'det':
|
||||
from .det_mobilenet_v3 import MobileNetV3
|
||||
from .det_resnet_vd import ResNet
|
||||
from .det_resnet_vd_sast import ResNet_SAST
|
||||
support_dict = ['MobileNetV3', 'ResNet', 'ResNet_SAST']
|
||||
elif model_type == 'rec' or model_type == 'cls':
|
||||
from .rec_mobilenet_v3 import MobileNetV3
|
||||
|
|
|
@ -0,0 +1,285 @@
|
|||
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import paddle
|
||||
from paddle import ParamAttr
|
||||
import paddle.nn as nn
|
||||
import paddle.nn.functional as F
|
||||
|
||||
__all__ = ["ResNet_SAST"]
|
||||
|
||||
|
||||
class ConvBNLayer(nn.Layer):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size,
|
||||
stride=1,
|
||||
groups=1,
|
||||
is_vd_mode=False,
|
||||
act=None,
|
||||
name=None, ):
|
||||
super(ConvBNLayer, self).__init__()
|
||||
|
||||
self.is_vd_mode = is_vd_mode
|
||||
self._pool2d_avg = nn.AvgPool2D(
|
||||
kernel_size=2, stride=2, padding=0, ceil_mode=True)
|
||||
self._conv = nn.Conv2D(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=kernel_size,
|
||||
stride=stride,
|
||||
padding=(kernel_size - 1) // 2,
|
||||
groups=groups,
|
||||
weight_attr=ParamAttr(name=name + "_weights"),
|
||||
bias_attr=False)
|
||||
if name == "conv1":
|
||||
bn_name = "bn_" + name
|
||||
else:
|
||||
bn_name = "bn" + name[3:]
|
||||
self._batch_norm = nn.BatchNorm(
|
||||
out_channels,
|
||||
act=act,
|
||||
param_attr=ParamAttr(name=bn_name + '_scale'),
|
||||
bias_attr=ParamAttr(bn_name + '_offset'),
|
||||
moving_mean_name=bn_name + '_mean',
|
||||
moving_variance_name=bn_name + '_variance')
|
||||
|
||||
def forward(self, inputs):
|
||||
if self.is_vd_mode:
|
||||
inputs = self._pool2d_avg(inputs)
|
||||
y = self._conv(inputs)
|
||||
y = self._batch_norm(y)
|
||||
return y
|
||||
|
||||
|
||||
class BottleneckBlock(nn.Layer):
|
||||
def __init__(self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
stride,
|
||||
shortcut=True,
|
||||
if_first=False,
|
||||
name=None):
|
||||
super(BottleneckBlock, self).__init__()
|
||||
|
||||
self.conv0 = ConvBNLayer(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=1,
|
||||
act='relu',
|
||||
name=name + "_branch2a")
|
||||
self.conv1 = ConvBNLayer(
|
||||
in_channels=out_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=3,
|
||||
stride=stride,
|
||||
act='relu',
|
||||
name=name + "_branch2b")
|
||||
self.conv2 = ConvBNLayer(
|
||||
in_channels=out_channels,
|
||||
out_channels=out_channels * 4,
|
||||
kernel_size=1,
|
||||
act=None,
|
||||
name=name + "_branch2c")
|
||||
|
||||
if not shortcut:
|
||||
self.short = ConvBNLayer(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels * 4,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
is_vd_mode=False if if_first else True,
|
||||
name=name + "_branch1")
|
||||
|
||||
self.shortcut = shortcut
|
||||
|
||||
def forward(self, inputs):
|
||||
y = self.conv0(inputs)
|
||||
conv1 = self.conv1(y)
|
||||
conv2 = self.conv2(conv1)
|
||||
|
||||
if self.shortcut:
|
||||
short = inputs
|
||||
else:
|
||||
short = self.short(inputs)
|
||||
y = paddle.add(x=short, y=conv2)
|
||||
y = F.relu(y)
|
||||
return y
|
||||
|
||||
|
||||
class BasicBlock(nn.Layer):
|
||||
def __init__(self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
stride,
|
||||
shortcut=True,
|
||||
if_first=False,
|
||||
name=None):
|
||||
super(BasicBlock, self).__init__()
|
||||
self.stride = stride
|
||||
self.conv0 = ConvBNLayer(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=3,
|
||||
stride=stride,
|
||||
act='relu',
|
||||
name=name + "_branch2a")
|
||||
self.conv1 = ConvBNLayer(
|
||||
in_channels=out_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=3,
|
||||
act=None,
|
||||
name=name + "_branch2b")
|
||||
|
||||
if not shortcut:
|
||||
self.short = ConvBNLayer(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
is_vd_mode=False if if_first else True,
|
||||
name=name + "_branch1")
|
||||
|
||||
self.shortcut = shortcut
|
||||
|
||||
def forward(self, inputs):
|
||||
y = self.conv0(inputs)
|
||||
conv1 = self.conv1(y)
|
||||
|
||||
if self.shortcut:
|
||||
short = inputs
|
||||
else:
|
||||
short = self.short(inputs)
|
||||
y = paddle.add(x=short, y=conv1)
|
||||
y = F.relu(y)
|
||||
return y
|
||||
|
||||
|
||||
class ResNet_SAST(nn.Layer):
|
||||
def __init__(self, in_channels=3, layers=50, **kwargs):
|
||||
super(ResNet_SAST, self).__init__()
|
||||
|
||||
self.layers = layers
|
||||
supported_layers = [18, 34, 50, 101, 152, 200]
|
||||
assert layers in supported_layers, \
|
||||
"supported layers are {} but input layer is {}".format(
|
||||
supported_layers, layers)
|
||||
|
||||
if layers == 18:
|
||||
depth = [2, 2, 2, 2]
|
||||
elif layers == 34 or layers == 50:
|
||||
# depth = [3, 4, 6, 3]
|
||||
depth = [3, 4, 6, 3, 3]
|
||||
elif layers == 101:
|
||||
depth = [3, 4, 23, 3]
|
||||
elif layers == 152:
|
||||
depth = [3, 8, 36, 3]
|
||||
elif layers == 200:
|
||||
depth = [3, 12, 48, 3]
|
||||
# num_channels = [64, 256, 512,
|
||||
# 1024] if layers >= 50 else [64, 64, 128, 256]
|
||||
# num_filters = [64, 128, 256, 512]
|
||||
num_channels = [64, 256, 512,
|
||||
1024, 2048] if layers >= 50 else [64, 64, 128, 256]
|
||||
num_filters = [64, 128, 256, 512, 512]
|
||||
|
||||
self.conv1_1 = ConvBNLayer(
|
||||
in_channels=in_channels,
|
||||
out_channels=32,
|
||||
kernel_size=3,
|
||||
stride=2,
|
||||
act='relu',
|
||||
name="conv1_1")
|
||||
self.conv1_2 = ConvBNLayer(
|
||||
in_channels=32,
|
||||
out_channels=32,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
act='relu',
|
||||
name="conv1_2")
|
||||
self.conv1_3 = ConvBNLayer(
|
||||
in_channels=32,
|
||||
out_channels=64,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
act='relu',
|
||||
name="conv1_3")
|
||||
self.pool2d_max = nn.MaxPool2D(kernel_size=3, stride=2, padding=1)
|
||||
|
||||
self.stages = []
|
||||
self.out_channels = [3, 64]
|
||||
if layers >= 50:
|
||||
for block in range(len(depth)):
|
||||
block_list = []
|
||||
shortcut = False
|
||||
for i in range(depth[block]):
|
||||
if layers in [101, 152] and block == 2:
|
||||
if i == 0:
|
||||
conv_name = "res" + str(block + 2) + "a"
|
||||
else:
|
||||
conv_name = "res" + str(block + 2) + "b" + str(i)
|
||||
else:
|
||||
conv_name = "res" + str(block + 2) + chr(97 + i)
|
||||
bottleneck_block = self.add_sublayer(
|
||||
'bb_%d_%d' % (block, i),
|
||||
BottleneckBlock(
|
||||
in_channels=num_channels[block]
|
||||
if i == 0 else num_filters[block] * 4,
|
||||
out_channels=num_filters[block],
|
||||
stride=2 if i == 0 and block != 0 else 1,
|
||||
shortcut=shortcut,
|
||||
if_first=block == i == 0,
|
||||
name=conv_name))
|
||||
shortcut = True
|
||||
block_list.append(bottleneck_block)
|
||||
self.out_channels.append(num_filters[block] * 4)
|
||||
self.stages.append(nn.Sequential(*block_list))
|
||||
else:
|
||||
for block in range(len(depth)):
|
||||
block_list = []
|
||||
shortcut = False
|
||||
for i in range(depth[block]):
|
||||
conv_name = "res" + str(block + 2) + chr(97 + i)
|
||||
basic_block = self.add_sublayer(
|
||||
'bb_%d_%d' % (block, i),
|
||||
BasicBlock(
|
||||
in_channels=num_channels[block]
|
||||
if i == 0 else num_filters[block],
|
||||
out_channels=num_filters[block],
|
||||
stride=2 if i == 0 and block != 0 else 1,
|
||||
shortcut=shortcut,
|
||||
if_first=block == i == 0,
|
||||
name=conv_name))
|
||||
shortcut = True
|
||||
block_list.append(basic_block)
|
||||
self.out_channels.append(num_filters[block])
|
||||
self.stages.append(nn.Sequential(*block_list))
|
||||
|
||||
def forward(self, inputs):
|
||||
out = [inputs]
|
||||
y = self.conv1_1(inputs)
|
||||
y = self.conv1_2(y)
|
||||
y = self.conv1_3(y)
|
||||
out.append(y)
|
||||
y = self.pool2d_max(y)
|
||||
for block in self.stages:
|
||||
y = block(y)
|
||||
out.append(y)
|
||||
return out
|
|
@ -18,13 +18,15 @@ __all__ = ['build_head']
|
|||
def build_head(config):
|
||||
# det head
|
||||
from .det_db_head import DBHead
|
||||
from .det_east_head import EASTHead
|
||||
from .det_sast_head import SASTHead
|
||||
|
||||
# rec head
|
||||
from .rec_ctc_head import CTCHead
|
||||
|
||||
# cls head
|
||||
from .cls_head import ClsHead
|
||||
support_dict = ['DBHead', 'CTCHead', 'ClsHead']
|
||||
support_dict = ['DBHead', 'EASTHead', 'SASTHead', 'CTCHead', 'ClsHead']
|
||||
|
||||
module_name = config.pop('name')
|
||||
assert module_name in support_dict, Exception('head only support {}'.format(
|
||||
|
|
|
@ -0,0 +1,121 @@
|
|||
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import math
|
||||
import paddle
|
||||
from paddle import nn
|
||||
import paddle.nn.functional as F
|
||||
from paddle import ParamAttr
|
||||
|
||||
|
||||
class ConvBNLayer(nn.Layer):
|
||||
def __init__(self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size,
|
||||
stride,
|
||||
padding,
|
||||
groups=1,
|
||||
if_act=True,
|
||||
act=None,
|
||||
name=None):
|
||||
super(ConvBNLayer, self).__init__()
|
||||
self.if_act = if_act
|
||||
self.act = act
|
||||
self.conv = nn.Conv2D(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=kernel_size,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
groups=groups,
|
||||
weight_attr=ParamAttr(name=name + '_weights'),
|
||||
bias_attr=False)
|
||||
|
||||
self.bn = nn.BatchNorm(
|
||||
num_channels=out_channels,
|
||||
act=act,
|
||||
param_attr=ParamAttr(name="bn_" + name + "_scale"),
|
||||
bias_attr=ParamAttr(name="bn_" + name + "_offset"),
|
||||
moving_mean_name="bn_" + name + "_mean",
|
||||
moving_variance_name="bn_" + name + "_variance")
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv(x)
|
||||
x = self.bn(x)
|
||||
return x
|
||||
|
||||
|
||||
class EASTHead(nn.Layer):
|
||||
"""
|
||||
"""
|
||||
def __init__(self, in_channels, model_name, **kwargs):
|
||||
super(EASTHead, self).__init__()
|
||||
self.model_name = model_name
|
||||
if self.model_name == "large":
|
||||
num_outputs = [128, 64, 1, 8]
|
||||
else:
|
||||
num_outputs = [64, 32, 1, 8]
|
||||
|
||||
self.det_conv1 = ConvBNLayer(
|
||||
in_channels=in_channels,
|
||||
out_channels=num_outputs[0],
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
if_act=True,
|
||||
act='relu',
|
||||
name="det_head1")
|
||||
self.det_conv2 = ConvBNLayer(
|
||||
in_channels=num_outputs[0],
|
||||
out_channels=num_outputs[1],
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
if_act=True,
|
||||
act='relu',
|
||||
name="det_head2")
|
||||
self.score_conv = ConvBNLayer(
|
||||
in_channels=num_outputs[1],
|
||||
out_channels=num_outputs[2],
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0,
|
||||
if_act=False,
|
||||
act=None,
|
||||
name="f_score")
|
||||
self.geo_conv = ConvBNLayer(
|
||||
in_channels=num_outputs[1],
|
||||
out_channels=num_outputs[3],
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0,
|
||||
if_act=False,
|
||||
act=None,
|
||||
name="f_geo")
|
||||
|
||||
def forward(self, x):
|
||||
f_det = self.det_conv1(x)
|
||||
f_det = self.det_conv2(f_det)
|
||||
f_score = self.score_conv(f_det)
|
||||
f_score = F.sigmoid(f_score)
|
||||
f_geo = self.geo_conv(f_det)
|
||||
f_geo = (F.sigmoid(f_geo) - 0.5) * 2 * 800
|
||||
|
||||
pred = {'f_score': f_score, 'f_geo': f_geo}
|
||||
return pred
|
|
@ -0,0 +1,128 @@
|
|||
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import math
|
||||
import paddle
|
||||
from paddle import nn
|
||||
import paddle.nn.functional as F
|
||||
from paddle import ParamAttr
|
||||
|
||||
|
||||
class ConvBNLayer(nn.Layer):
|
||||
def __init__(self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size,
|
||||
stride,
|
||||
groups=1,
|
||||
if_act=True,
|
||||
act=None,
|
||||
name=None):
|
||||
super(ConvBNLayer, self).__init__()
|
||||
self.if_act = if_act
|
||||
self.act = act
|
||||
self.conv = nn.Conv2D(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=kernel_size,
|
||||
stride=stride,
|
||||
padding=(kernel_size - 1) // 2,
|
||||
groups=groups,
|
||||
weight_attr=ParamAttr(name=name + '_weights'),
|
||||
bias_attr=False)
|
||||
|
||||
self.bn = nn.BatchNorm(
|
||||
num_channels=out_channels,
|
||||
act=act,
|
||||
param_attr=ParamAttr(name="bn_" + name + "_scale"),
|
||||
bias_attr=ParamAttr(name="bn_" + name + "_offset"),
|
||||
moving_mean_name="bn_" + name + "_mean",
|
||||
moving_variance_name="bn_" + name + "_variance")
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv(x)
|
||||
x = self.bn(x)
|
||||
return x
|
||||
|
||||
|
||||
class SAST_Header1(nn.Layer):
|
||||
def __init__(self, in_channels, **kwargs):
|
||||
super(SAST_Header1, self).__init__()
|
||||
out_channels = [64, 64, 128]
|
||||
self.score_conv = nn.Sequential(
|
||||
ConvBNLayer(in_channels, out_channels[0], 1, 1, act='relu', name='f_score1'),
|
||||
ConvBNLayer(out_channels[0], out_channels[1], 3, 1, act='relu', name='f_score2'),
|
||||
ConvBNLayer(out_channels[1], out_channels[2], 1, 1, act='relu', name='f_score3'),
|
||||
ConvBNLayer(out_channels[2], 1, 3, 1, act=None, name='f_score4')
|
||||
)
|
||||
self.border_conv = nn.Sequential(
|
||||
ConvBNLayer(in_channels, out_channels[0], 1, 1, act='relu', name='f_border1'),
|
||||
ConvBNLayer(out_channels[0], out_channels[1], 3, 1, act='relu', name='f_border2'),
|
||||
ConvBNLayer(out_channels[1], out_channels[2], 1, 1, act='relu', name='f_border3'),
|
||||
ConvBNLayer(out_channels[2], 4, 3, 1, act=None, name='f_border4')
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
f_score = self.score_conv(x)
|
||||
f_score = F.sigmoid(f_score)
|
||||
f_border = self.border_conv(x)
|
||||
return f_score, f_border
|
||||
|
||||
|
||||
class SAST_Header2(nn.Layer):
|
||||
def __init__(self, in_channels, **kwargs):
|
||||
super(SAST_Header2, self).__init__()
|
||||
out_channels = [64, 64, 128]
|
||||
self.tvo_conv = nn.Sequential(
|
||||
ConvBNLayer(in_channels, out_channels[0], 1, 1, act='relu', name='f_tvo1'),
|
||||
ConvBNLayer(out_channels[0], out_channels[1], 3, 1, act='relu', name='f_tvo2'),
|
||||
ConvBNLayer(out_channels[1], out_channels[2], 1, 1, act='relu', name='f_tvo3'),
|
||||
ConvBNLayer(out_channels[2], 8, 3, 1, act=None, name='f_tvo4')
|
||||
)
|
||||
self.tco_conv = nn.Sequential(
|
||||
ConvBNLayer(in_channels, out_channels[0], 1, 1, act='relu', name='f_tco1'),
|
||||
ConvBNLayer(out_channels[0], out_channels[1], 3, 1, act='relu', name='f_tco2'),
|
||||
ConvBNLayer(out_channels[1], out_channels[2], 1, 1, act='relu', name='f_tco3'),
|
||||
ConvBNLayer(out_channels[2], 2, 3, 1, act=None, name='f_tco4')
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
f_tvo = self.tvo_conv(x)
|
||||
f_tco = self.tco_conv(x)
|
||||
return f_tvo, f_tco
|
||||
|
||||
|
||||
class SASTHead(nn.Layer):
|
||||
"""
|
||||
"""
|
||||
def __init__(self, in_channels, **kwargs):
|
||||
super(SASTHead, self).__init__()
|
||||
|
||||
self.head1 = SAST_Header1(in_channels)
|
||||
self.head2 = SAST_Header2(in_channels)
|
||||
|
||||
def forward(self, x):
|
||||
f_score, f_border = self.head1(x)
|
||||
f_tvo, f_tco = self.head2(x)
|
||||
|
||||
predicts = {}
|
||||
predicts['f_score'] = f_score
|
||||
predicts['f_border'] = f_border
|
||||
predicts['f_tvo'] = f_tvo
|
||||
predicts['f_tco'] = f_tco
|
||||
return predicts
|
|
@ -16,8 +16,10 @@ __all__ = ['build_neck']
|
|||
|
||||
def build_neck(config):
|
||||
from .db_fpn import DBFPN
|
||||
from .east_fpn import EASTFPN
|
||||
from .sast_fpn import SASTFPN
|
||||
from .rnn import SequenceEncoder
|
||||
support_dict = ['DBFPN', 'SequenceEncoder']
|
||||
support_dict = ['DBFPN', 'EASTFPN', 'SASTFPN', 'SequenceEncoder']
|
||||
|
||||
module_name = config.pop('name')
|
||||
assert module_name in support_dict, Exception('neck only support {}'.format(
|
||||
|
|
|
@ -0,0 +1,188 @@
|
|||
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import paddle
|
||||
from paddle import nn
|
||||
import paddle.nn.functional as F
|
||||
from paddle import ParamAttr
|
||||
|
||||
|
||||
class ConvBNLayer(nn.Layer):
|
||||
def __init__(self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size,
|
||||
stride,
|
||||
padding,
|
||||
groups=1,
|
||||
if_act=True,
|
||||
act=None,
|
||||
name=None):
|
||||
super(ConvBNLayer, self).__init__()
|
||||
self.if_act = if_act
|
||||
self.act = act
|
||||
self.conv = nn.Conv2D(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=kernel_size,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
groups=groups,
|
||||
weight_attr=ParamAttr(name=name + '_weights'),
|
||||
bias_attr=False)
|
||||
|
||||
self.bn = nn.BatchNorm(
|
||||
num_channels=out_channels,
|
||||
act=act,
|
||||
param_attr=ParamAttr(name="bn_" + name + "_scale"),
|
||||
bias_attr=ParamAttr(name="bn_" + name + "_offset"),
|
||||
moving_mean_name="bn_" + name + "_mean",
|
||||
moving_variance_name="bn_" + name + "_variance")
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv(x)
|
||||
x = self.bn(x)
|
||||
return x
|
||||
|
||||
|
||||
class DeConvBNLayer(nn.Layer):
|
||||
def __init__(self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size,
|
||||
stride,
|
||||
padding,
|
||||
groups=1,
|
||||
if_act=True,
|
||||
act=None,
|
||||
name=None):
|
||||
super(DeConvBNLayer, self).__init__()
|
||||
self.if_act = if_act
|
||||
self.act = act
|
||||
self.deconv = nn.Conv2DTranspose(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=kernel_size,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
groups=groups,
|
||||
weight_attr=ParamAttr(name=name + '_weights'),
|
||||
bias_attr=False)
|
||||
self.bn = nn.BatchNorm(
|
||||
num_channels=out_channels,
|
||||
act=act,
|
||||
param_attr=ParamAttr(name="bn_" + name + "_scale"),
|
||||
bias_attr=ParamAttr(name="bn_" + name + "_offset"),
|
||||
moving_mean_name="bn_" + name + "_mean",
|
||||
moving_variance_name="bn_" + name + "_variance")
|
||||
|
||||
def forward(self, x):
|
||||
x = self.deconv(x)
|
||||
x = self.bn(x)
|
||||
return x
|
||||
|
||||
|
||||
class EASTFPN(nn.Layer):
|
||||
def __init__(self, in_channels, model_name, **kwargs):
|
||||
super(EASTFPN, self).__init__()
|
||||
self.model_name = model_name
|
||||
if self.model_name == "large":
|
||||
self.out_channels = 128
|
||||
else:
|
||||
self.out_channels = 64
|
||||
self.in_channels = in_channels[::-1]
|
||||
self.h1_conv = ConvBNLayer(
|
||||
in_channels=self.out_channels+self.in_channels[1],
|
||||
out_channels=self.out_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
if_act=True,
|
||||
act='relu',
|
||||
name="unet_h_1")
|
||||
self.h2_conv = ConvBNLayer(
|
||||
in_channels=self.out_channels+self.in_channels[2],
|
||||
out_channels=self.out_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
if_act=True,
|
||||
act='relu',
|
||||
name="unet_h_2")
|
||||
self.h3_conv = ConvBNLayer(
|
||||
in_channels=self.out_channels+self.in_channels[3],
|
||||
out_channels=self.out_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
if_act=True,
|
||||
act='relu',
|
||||
name="unet_h_3")
|
||||
self.g0_deconv = DeConvBNLayer(
|
||||
in_channels=self.in_channels[0],
|
||||
out_channels=self.out_channels,
|
||||
kernel_size=4,
|
||||
stride=2,
|
||||
padding=1,
|
||||
if_act=True,
|
||||
act='relu',
|
||||
name="unet_g_0")
|
||||
self.g1_deconv = DeConvBNLayer(
|
||||
in_channels=self.out_channels,
|
||||
out_channels=self.out_channels,
|
||||
kernel_size=4,
|
||||
stride=2,
|
||||
padding=1,
|
||||
if_act=True,
|
||||
act='relu',
|
||||
name="unet_g_1")
|
||||
self.g2_deconv = DeConvBNLayer(
|
||||
in_channels=self.out_channels,
|
||||
out_channels=self.out_channels,
|
||||
kernel_size=4,
|
||||
stride=2,
|
||||
padding=1,
|
||||
if_act=True,
|
||||
act='relu',
|
||||
name="unet_g_2")
|
||||
self.g3_conv = ConvBNLayer(
|
||||
in_channels=self.out_channels,
|
||||
out_channels=self.out_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
if_act=True,
|
||||
act='relu',
|
||||
name="unet_g_3")
|
||||
|
||||
def forward(self, x):
|
||||
f = x[::-1]
|
||||
|
||||
h = f[0]
|
||||
g = self.g0_deconv(h)
|
||||
h = paddle.concat([g, f[1]], axis=1)
|
||||
h = self.h1_conv(h)
|
||||
g = self.g1_deconv(h)
|
||||
h = paddle.concat([g, f[2]], axis=1)
|
||||
h = self.h2_conv(h)
|
||||
g = self.g2_deconv(h)
|
||||
h = paddle.concat([g, f[3]], axis=1)
|
||||
h = self.h3_conv(h)
|
||||
g = self.g3_conv(h)
|
||||
|
||||
return g
|
|
@ -0,0 +1,284 @@
|
|||
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import paddle
|
||||
from paddle import nn
|
||||
import paddle.nn.functional as F
|
||||
from paddle import ParamAttr
|
||||
|
||||
|
||||
class ConvBNLayer(nn.Layer):
|
||||
def __init__(self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size,
|
||||
stride,
|
||||
groups=1,
|
||||
if_act=True,
|
||||
act=None,
|
||||
name=None):
|
||||
super(ConvBNLayer, self).__init__()
|
||||
self.if_act = if_act
|
||||
self.act = act
|
||||
self.conv = nn.Conv2D(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=kernel_size,
|
||||
stride=stride,
|
||||
padding=(kernel_size - 1) // 2,
|
||||
groups=groups,
|
||||
weight_attr=ParamAttr(name=name + '_weights'),
|
||||
bias_attr=False)
|
||||
|
||||
self.bn = nn.BatchNorm(
|
||||
num_channels=out_channels,
|
||||
act=act,
|
||||
param_attr=ParamAttr(name="bn_" + name + "_scale"),
|
||||
bias_attr=ParamAttr(name="bn_" + name + "_offset"),
|
||||
moving_mean_name="bn_" + name + "_mean",
|
||||
moving_variance_name="bn_" + name + "_variance")
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv(x)
|
||||
x = self.bn(x)
|
||||
return x
|
||||
|
||||
|
||||
class DeConvBNLayer(nn.Layer):
|
||||
def __init__(self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size,
|
||||
stride,
|
||||
groups=1,
|
||||
if_act=True,
|
||||
act=None,
|
||||
name=None):
|
||||
super(DeConvBNLayer, self).__init__()
|
||||
self.if_act = if_act
|
||||
self.act = act
|
||||
self.deconv = nn.Conv2DTranspose(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=kernel_size,
|
||||
stride=stride,
|
||||
padding=(kernel_size - 1) // 2,
|
||||
groups=groups,
|
||||
weight_attr=ParamAttr(name=name + '_weights'),
|
||||
bias_attr=False)
|
||||
self.bn = nn.BatchNorm(
|
||||
num_channels=out_channels,
|
||||
act=act,
|
||||
param_attr=ParamAttr(name="bn_" + name + "_scale"),
|
||||
bias_attr=ParamAttr(name="bn_" + name + "_offset"),
|
||||
moving_mean_name="bn_" + name + "_mean",
|
||||
moving_variance_name="bn_" + name + "_variance")
|
||||
|
||||
def forward(self, x):
|
||||
x = self.deconv(x)
|
||||
x = self.bn(x)
|
||||
return x
|
||||
|
||||
|
||||
class FPN_Up_Fusion(nn.Layer):
|
||||
def __init__(self, in_channels):
|
||||
super(FPN_Up_Fusion, self).__init__()
|
||||
in_channels = in_channels[::-1]
|
||||
out_channels = [256, 256, 192, 192, 128]
|
||||
|
||||
self.h0_conv = ConvBNLayer(in_channels[0], out_channels[0], 1, 1, act=None, name='fpn_up_h0')
|
||||
self.h1_conv = ConvBNLayer(in_channels[1], out_channels[1], 1, 1, act=None, name='fpn_up_h1')
|
||||
self.h2_conv = ConvBNLayer(in_channels[2], out_channels[2], 1, 1, act=None, name='fpn_up_h2')
|
||||
self.h3_conv = ConvBNLayer(in_channels[3], out_channels[3], 1, 1, act=None, name='fpn_up_h3')
|
||||
self.h4_conv = ConvBNLayer(in_channels[4], out_channels[4], 1, 1, act=None, name='fpn_up_h4')
|
||||
|
||||
self.g0_conv = DeConvBNLayer(out_channels[0], out_channels[1], 4, 2, act=None, name='fpn_up_g0')
|
||||
|
||||
self.g1_conv = nn.Sequential(
|
||||
ConvBNLayer(out_channels[1], out_channels[1], 3, 1, act='relu', name='fpn_up_g1_1'),
|
||||
DeConvBNLayer(out_channels[1], out_channels[2], 4, 2, act=None, name='fpn_up_g1_2')
|
||||
)
|
||||
self.g2_conv = nn.Sequential(
|
||||
ConvBNLayer(out_channels[2], out_channels[2], 3, 1, act='relu', name='fpn_up_g2_1'),
|
||||
DeConvBNLayer(out_channels[2], out_channels[3], 4, 2, act=None, name='fpn_up_g2_2')
|
||||
)
|
||||
self.g3_conv = nn.Sequential(
|
||||
ConvBNLayer(out_channels[3], out_channels[3], 3, 1, act='relu', name='fpn_up_g3_1'),
|
||||
DeConvBNLayer(out_channels[3], out_channels[4], 4, 2, act=None, name='fpn_up_g3_2')
|
||||
)
|
||||
|
||||
self.g4_conv = nn.Sequential(
|
||||
ConvBNLayer(out_channels[4], out_channels[4], 3, 1, act='relu', name='fpn_up_fusion_1'),
|
||||
ConvBNLayer(out_channels[4], out_channels[4], 1, 1, act=None, name='fpn_up_fusion_2')
|
||||
)
|
||||
|
||||
def _add_relu(self, x1, x2):
|
||||
x = paddle.add(x=x1, y=x2)
|
||||
x = F.relu(x)
|
||||
return x
|
||||
|
||||
def forward(self, x):
|
||||
f = x[2:][::-1]
|
||||
h0 = self.h0_conv(f[0])
|
||||
h1 = self.h1_conv(f[1])
|
||||
h2 = self.h2_conv(f[2])
|
||||
h3 = self.h3_conv(f[3])
|
||||
h4 = self.h4_conv(f[4])
|
||||
|
||||
g0 = self.g0_conv(h0)
|
||||
g1 = self._add_relu(g0, h1)
|
||||
g1 = self.g1_conv(g1)
|
||||
g2 = self.g2_conv(self._add_relu(g1, h2))
|
||||
g3 = self.g3_conv(self._add_relu(g2, h3))
|
||||
g4 = self.g4_conv(self._add_relu(g3, h4))
|
||||
|
||||
return g4
|
||||
|
||||
|
||||
class FPN_Down_Fusion(nn.Layer):
|
||||
def __init__(self, in_channels):
|
||||
super(FPN_Down_Fusion, self).__init__()
|
||||
out_channels = [32, 64, 128]
|
||||
|
||||
self.h0_conv = ConvBNLayer(in_channels[0], out_channels[0], 3, 1, act=None, name='fpn_down_h0')
|
||||
self.h1_conv = ConvBNLayer(in_channels[1], out_channels[1], 3, 1, act=None, name='fpn_down_h1')
|
||||
self.h2_conv = ConvBNLayer(in_channels[2], out_channels[2], 3, 1, act=None, name='fpn_down_h2')
|
||||
|
||||
self.g0_conv = ConvBNLayer(out_channels[0], out_channels[1], 3, 2, act=None, name='fpn_down_g0')
|
||||
|
||||
self.g1_conv = nn.Sequential(
|
||||
ConvBNLayer(out_channels[1], out_channels[1], 3, 1, act='relu', name='fpn_down_g1_1'),
|
||||
ConvBNLayer(out_channels[1], out_channels[2], 3, 2, act=None, name='fpn_down_g1_2')
|
||||
)
|
||||
|
||||
self.g2_conv = nn.Sequential(
|
||||
ConvBNLayer(out_channels[2], out_channels[2], 3, 1, act='relu', name='fpn_down_fusion_1'),
|
||||
ConvBNLayer(out_channels[2], out_channels[2], 1, 1, act=None, name='fpn_down_fusion_2')
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
f = x[:3]
|
||||
h0 = self.h0_conv(f[0])
|
||||
h1 = self.h1_conv(f[1])
|
||||
h2 = self.h2_conv(f[2])
|
||||
g0 = self.g0_conv(h0)
|
||||
g1 = paddle.add(x=g0, y=h1)
|
||||
g1 = F.relu(g1)
|
||||
g1 = self.g1_conv(g1)
|
||||
g2 = paddle.add(x=g1, y=h2)
|
||||
g2 = F.relu(g2)
|
||||
g2 = self.g2_conv(g2)
|
||||
return g2
|
||||
|
||||
|
||||
class Cross_Attention(nn.Layer):
|
||||
def __init__(self, in_channels):
|
||||
super(Cross_Attention, self).__init__()
|
||||
self.theta_conv = ConvBNLayer(in_channels, in_channels, 1, 1, act='relu', name='f_theta')
|
||||
self.phi_conv = ConvBNLayer(in_channels, in_channels, 1, 1, act='relu', name='f_phi')
|
||||
self.g_conv = ConvBNLayer(in_channels, in_channels, 1, 1, act='relu', name='f_g')
|
||||
|
||||
self.fh_weight_conv = ConvBNLayer(in_channels, in_channels, 1, 1, act=None, name='fh_weight')
|
||||
self.fh_sc_conv = ConvBNLayer(in_channels, in_channels, 1, 1, act=None, name='fh_sc')
|
||||
|
||||
self.fv_weight_conv = ConvBNLayer(in_channels, in_channels, 1, 1, act=None, name='fv_weight')
|
||||
self.fv_sc_conv = ConvBNLayer(in_channels, in_channels, 1, 1, act=None, name='fv_sc')
|
||||
|
||||
self.f_attn_conv = ConvBNLayer(in_channels * 2, in_channels, 1, 1, act='relu', name='f_attn')
|
||||
|
||||
def _cal_fweight(self, f, shape):
|
||||
f_theta, f_phi, f_g = f
|
||||
#flatten
|
||||
f_theta = paddle.transpose(f_theta, [0, 2, 3, 1])
|
||||
f_theta = paddle.reshape(f_theta, [shape[0] * shape[1], shape[2], 128])
|
||||
f_phi = paddle.transpose(f_phi, [0, 2, 3, 1])
|
||||
f_phi = paddle.reshape(f_phi, [shape[0] * shape[1], shape[2], 128])
|
||||
f_g = paddle.transpose(f_g, [0, 2, 3, 1])
|
||||
f_g = paddle.reshape(f_g, [shape[0] * shape[1], shape[2], 128])
|
||||
#correlation
|
||||
f_attn = paddle.matmul(f_theta, paddle.transpose(f_phi, [0, 2, 1]))
|
||||
#scale
|
||||
f_attn = f_attn / (128**0.5)
|
||||
f_attn = F.softmax(f_attn)
|
||||
#weighted sum
|
||||
f_weight = paddle.matmul(f_attn, f_g)
|
||||
f_weight = paddle.reshape(
|
||||
f_weight, [shape[0], shape[1], shape[2], 128])
|
||||
return f_weight
|
||||
|
||||
def forward(self, f_common):
|
||||
f_shape = paddle.shape(f_common)
|
||||
# print('f_shape: ', f_shape)
|
||||
|
||||
f_theta = self.theta_conv(f_common)
|
||||
f_phi = self.phi_conv(f_common)
|
||||
f_g = self.g_conv(f_common)
|
||||
|
||||
######## horizon ########
|
||||
fh_weight = self._cal_fweight([f_theta, f_phi, f_g],
|
||||
[f_shape[0], f_shape[2], f_shape[3]])
|
||||
fh_weight = paddle.transpose(fh_weight, [0, 3, 1, 2])
|
||||
fh_weight = self.fh_weight_conv(fh_weight)
|
||||
#short cut
|
||||
fh_sc = self.fh_sc_conv(f_common)
|
||||
f_h = F.relu(fh_weight + fh_sc)
|
||||
|
||||
######## vertical ########
|
||||
fv_theta = paddle.transpose(f_theta, [0, 1, 3, 2])
|
||||
fv_phi = paddle.transpose(f_phi, [0, 1, 3, 2])
|
||||
fv_g = paddle.transpose(f_g, [0, 1, 3, 2])
|
||||
fv_weight = self._cal_fweight([fv_theta, fv_phi, fv_g],
|
||||
[f_shape[0], f_shape[3], f_shape[2]])
|
||||
fv_weight = paddle.transpose(fv_weight, [0, 3, 2, 1])
|
||||
fv_weight = self.fv_weight_conv(fv_weight)
|
||||
#short cut
|
||||
fv_sc = self.fv_sc_conv(f_common)
|
||||
f_v = F.relu(fv_weight + fv_sc)
|
||||
|
||||
######## merge ########
|
||||
f_attn = paddle.concat([f_h, f_v], axis=1)
|
||||
f_attn = self.f_attn_conv(f_attn)
|
||||
return f_attn
|
||||
|
||||
|
||||
class SASTFPN(nn.Layer):
|
||||
def __init__(self, in_channels, with_cab=False, **kwargs):
|
||||
super(SASTFPN, self).__init__()
|
||||
self.in_channels = in_channels
|
||||
self.with_cab = with_cab
|
||||
self.FPN_Down_Fusion = FPN_Down_Fusion(self.in_channels)
|
||||
self.FPN_Up_Fusion = FPN_Up_Fusion(self.in_channels)
|
||||
self.out_channels = 128
|
||||
self.cross_attention = Cross_Attention(self.out_channels)
|
||||
|
||||
def forward(self, x):
|
||||
#down fpn
|
||||
f_down = self.FPN_Down_Fusion(x)
|
||||
|
||||
#up fpn
|
||||
f_up = self.FPN_Up_Fusion(x)
|
||||
|
||||
#fusion
|
||||
f_common = paddle.add(x=f_down, y=f_up)
|
||||
f_common = F.relu(f_common)
|
||||
|
||||
if self.with_cab:
|
||||
# print('enhence f_common with CAB.')
|
||||
f_common = self.cross_attention(f_common)
|
||||
|
||||
return f_common
|
|
@ -24,11 +24,13 @@ __all__ = ['build_post_process']
|
|||
|
||||
def build_post_process(config, global_config=None):
|
||||
from .db_postprocess import DBPostProcess
|
||||
from .east_postprocess import EASTPostProcess
|
||||
from .sast_postprocess import SASTPostProcess
|
||||
from .rec_postprocess import CTCLabelDecode, AttnLabelDecode
|
||||
from .cls_postprocess import ClsPostProcess
|
||||
|
||||
support_dict = [
|
||||
'DBPostProcess', 'CTCLabelDecode', 'AttnLabelDecode', 'ClsPostProcess'
|
||||
'DBPostProcess', 'EASTPostProcess', 'SASTPostProcess', 'CTCLabelDecode', 'AttnLabelDecode', 'ClsPostProcess'
|
||||
]
|
||||
|
||||
config = copy.deepcopy(config)
|
||||
|
|
|
@ -0,0 +1,141 @@
|
|||
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
from .locality_aware_nms import nms_locality
|
||||
import cv2
|
||||
|
||||
import os
|
||||
import sys
|
||||
# __dir__ = os.path.dirname(os.path.abspath(__file__))
|
||||
# sys.path.append(__dir__)
|
||||
# sys.path.append(os.path.abspath(os.path.join(__dir__, '..')))
|
||||
|
||||
|
||||
class EASTPostProcess(object):
|
||||
"""
|
||||
The post process for EAST.
|
||||
"""
|
||||
def __init__(self,
|
||||
score_thresh=0.8,
|
||||
cover_thresh=0.1,
|
||||
nms_thresh=0.2,
|
||||
**kwargs):
|
||||
|
||||
self.score_thresh = score_thresh
|
||||
self.cover_thresh = cover_thresh
|
||||
self.nms_thresh = nms_thresh
|
||||
|
||||
# c++ la-nms is faster, but only support python 3.5
|
||||
self.is_python35 = False
|
||||
if sys.version_info.major == 3 and sys.version_info.minor == 5:
|
||||
self.is_python35 = True
|
||||
|
||||
def restore_rectangle_quad(self, origin, geometry):
|
||||
"""
|
||||
Restore rectangle from quadrangle.
|
||||
"""
|
||||
# quad
|
||||
origin_concat = np.concatenate(
|
||||
(origin, origin, origin, origin), axis=1) # (n, 8)
|
||||
pred_quads = origin_concat - geometry
|
||||
pred_quads = pred_quads.reshape((-1, 4, 2)) # (n, 4, 2)
|
||||
return pred_quads
|
||||
|
||||
def detect(self,
|
||||
score_map,
|
||||
geo_map,
|
||||
score_thresh=0.8,
|
||||
cover_thresh=0.1,
|
||||
nms_thresh=0.2):
|
||||
"""
|
||||
restore text boxes from score map and geo map
|
||||
"""
|
||||
score_map = score_map[0]
|
||||
geo_map = np.swapaxes(geo_map, 1, 0)
|
||||
geo_map = np.swapaxes(geo_map, 1, 2)
|
||||
# filter the score map
|
||||
xy_text = np.argwhere(score_map > score_thresh)
|
||||
if len(xy_text) == 0:
|
||||
return []
|
||||
# sort the text boxes via the y axis
|
||||
xy_text = xy_text[np.argsort(xy_text[:, 0])]
|
||||
#restore quad proposals
|
||||
text_box_restored = self.restore_rectangle_quad(
|
||||
xy_text[:, ::-1] * 4, geo_map[xy_text[:, 0], xy_text[:, 1], :])
|
||||
boxes = np.zeros((text_box_restored.shape[0], 9), dtype=np.float32)
|
||||
boxes[:, :8] = text_box_restored.reshape((-1, 8))
|
||||
boxes[:, 8] = score_map[xy_text[:, 0], xy_text[:, 1]]
|
||||
if self.is_python35:
|
||||
import lanms
|
||||
boxes = lanms.merge_quadrangle_n9(boxes, nms_thresh)
|
||||
else:
|
||||
boxes = nms_locality(boxes.astype(np.float64), nms_thresh)
|
||||
if boxes.shape[0] == 0:
|
||||
return []
|
||||
# Here we filter some low score boxes by the average score map,
|
||||
# this is different from the orginal paper.
|
||||
for i, box in enumerate(boxes):
|
||||
mask = np.zeros_like(score_map, dtype=np.uint8)
|
||||
cv2.fillPoly(mask, box[:8].reshape(
|
||||
(-1, 4, 2)).astype(np.int32) // 4, 1)
|
||||
boxes[i, 8] = cv2.mean(score_map, mask)[0]
|
||||
boxes = boxes[boxes[:, 8] > cover_thresh]
|
||||
return boxes
|
||||
|
||||
def sort_poly(self, p):
|
||||
"""
|
||||
Sort polygons.
|
||||
"""
|
||||
min_axis = np.argmin(np.sum(p, axis=1))
|
||||
p = p[[min_axis, (min_axis + 1) % 4,\
|
||||
(min_axis + 2) % 4, (min_axis + 3) % 4]]
|
||||
if abs(p[0, 0] - p[1, 0]) > abs(p[0, 1] - p[1, 1]):
|
||||
return p
|
||||
else:
|
||||
return p[[0, 3, 2, 1]]
|
||||
|
||||
def __call__(self, outs_dict, shape_list):
|
||||
score_list = outs_dict['f_score']
|
||||
geo_list = outs_dict['f_geo']
|
||||
img_num = len(shape_list)
|
||||
dt_boxes_list = []
|
||||
for ino in range(img_num):
|
||||
score = score_list[ino].numpy()
|
||||
geo = geo_list[ino].numpy()
|
||||
boxes = self.detect(
|
||||
score_map=score,
|
||||
geo_map=geo,
|
||||
score_thresh=self.score_thresh,
|
||||
cover_thresh=self.cover_thresh,
|
||||
nms_thresh=self.nms_thresh)
|
||||
boxes_norm = []
|
||||
if len(boxes) > 0:
|
||||
h, w = score.shape[1:]
|
||||
src_h, src_w, ratio_h, ratio_w = shape_list[ino]
|
||||
boxes = boxes[:, :8].reshape((-1, 4, 2))
|
||||
boxes[:, :, 0] /= ratio_w
|
||||
boxes[:, :, 1] /= ratio_h
|
||||
for i_box, box in enumerate(boxes):
|
||||
box = self.sort_poly(box.astype(np.int32))
|
||||
if np.linalg.norm(box[0] - box[1]) < 5 \
|
||||
or np.linalg.norm(box[3] - box[0]) < 5:
|
||||
continue
|
||||
boxes_norm.append(box)
|
||||
dt_boxes_list.append({'points': np.array(boxes_norm)})
|
||||
return dt_boxes_list
|
|
@ -0,0 +1,199 @@
|
|||
"""
|
||||
Locality aware nms.
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
from shapely.geometry import Polygon
|
||||
|
||||
|
||||
def intersection(g, p):
|
||||
"""
|
||||
Intersection.
|
||||
"""
|
||||
g = Polygon(g[:8].reshape((4, 2)))
|
||||
p = Polygon(p[:8].reshape((4, 2)))
|
||||
g = g.buffer(0)
|
||||
p = p.buffer(0)
|
||||
if not g.is_valid or not p.is_valid:
|
||||
return 0
|
||||
inter = Polygon(g).intersection(Polygon(p)).area
|
||||
union = g.area + p.area - inter
|
||||
if union == 0:
|
||||
return 0
|
||||
else:
|
||||
return inter / union
|
||||
|
||||
|
||||
def intersection_iog(g, p):
|
||||
"""
|
||||
Intersection_iog.
|
||||
"""
|
||||
g = Polygon(g[:8].reshape((4, 2)))
|
||||
p = Polygon(p[:8].reshape((4, 2)))
|
||||
if not g.is_valid or not p.is_valid:
|
||||
return 0
|
||||
inter = Polygon(g).intersection(Polygon(p)).area
|
||||
#union = g.area + p.area - inter
|
||||
union = p.area
|
||||
if union == 0:
|
||||
print("p_area is very small")
|
||||
return 0
|
||||
else:
|
||||
return inter / union
|
||||
|
||||
|
||||
def weighted_merge(g, p):
|
||||
"""
|
||||
Weighted merge.
|
||||
"""
|
||||
g[:8] = (g[8] * g[:8] + p[8] * p[:8]) / (g[8] + p[8])
|
||||
g[8] = (g[8] + p[8])
|
||||
return g
|
||||
|
||||
|
||||
def standard_nms(S, thres):
|
||||
"""
|
||||
Standard nms.
|
||||
"""
|
||||
order = np.argsort(S[:, 8])[::-1]
|
||||
keep = []
|
||||
while order.size > 0:
|
||||
i = order[0]
|
||||
keep.append(i)
|
||||
ovr = np.array([intersection(S[i], S[t]) for t in order[1:]])
|
||||
|
||||
inds = np.where(ovr <= thres)[0]
|
||||
order = order[inds + 1]
|
||||
|
||||
return S[keep]
|
||||
|
||||
|
||||
def standard_nms_inds(S, thres):
|
||||
"""
|
||||
Standard nms, retun inds.
|
||||
"""
|
||||
order = np.argsort(S[:, 8])[::-1]
|
||||
keep = []
|
||||
while order.size > 0:
|
||||
i = order[0]
|
||||
keep.append(i)
|
||||
ovr = np.array([intersection(S[i], S[t]) for t in order[1:]])
|
||||
|
||||
inds = np.where(ovr <= thres)[0]
|
||||
order = order[inds + 1]
|
||||
|
||||
return keep
|
||||
|
||||
|
||||
def nms(S, thres):
|
||||
"""
|
||||
nms.
|
||||
"""
|
||||
order = np.argsort(S[:, 8])[::-1]
|
||||
keep = []
|
||||
while order.size > 0:
|
||||
i = order[0]
|
||||
keep.append(i)
|
||||
ovr = np.array([intersection(S[i], S[t]) for t in order[1:]])
|
||||
|
||||
inds = np.where(ovr <= thres)[0]
|
||||
order = order[inds + 1]
|
||||
|
||||
return keep
|
||||
|
||||
|
||||
def soft_nms(boxes_in, Nt_thres=0.3, threshold=0.8, sigma=0.5, method=2):
|
||||
"""
|
||||
soft_nms
|
||||
:para boxes_in, N x 9 (coords + score)
|
||||
:para threshould, eliminate cases min score(0.001)
|
||||
:para Nt_thres, iou_threshi
|
||||
:para sigma, gaussian weght
|
||||
:method, linear or gaussian
|
||||
"""
|
||||
boxes = boxes_in.copy()
|
||||
N = boxes.shape[0]
|
||||
if N is None or N < 1:
|
||||
return np.array([])
|
||||
pos, maxpos = 0, 0
|
||||
weight = 0.0
|
||||
inds = np.arange(N)
|
||||
tbox, sbox = boxes[0].copy(), boxes[0].copy()
|
||||
for i in range(N):
|
||||
maxscore = boxes[i, 8]
|
||||
maxpos = i
|
||||
tbox = boxes[i].copy()
|
||||
ti = inds[i]
|
||||
pos = i + 1
|
||||
#get max box
|
||||
while pos < N:
|
||||
if maxscore < boxes[pos, 8]:
|
||||
maxscore = boxes[pos, 8]
|
||||
maxpos = pos
|
||||
pos = pos + 1
|
||||
#add max box as a detection
|
||||
boxes[i, :] = boxes[maxpos, :]
|
||||
inds[i] = inds[maxpos]
|
||||
#swap
|
||||
boxes[maxpos, :] = tbox
|
||||
inds[maxpos] = ti
|
||||
tbox = boxes[i].copy()
|
||||
pos = i + 1
|
||||
#NMS iteration
|
||||
while pos < N:
|
||||
sbox = boxes[pos].copy()
|
||||
ts_iou_val = intersection(tbox, sbox)
|
||||
if ts_iou_val > 0:
|
||||
if method == 1:
|
||||
if ts_iou_val > Nt_thres:
|
||||
weight = 1 - ts_iou_val
|
||||
else:
|
||||
weight = 1
|
||||
elif method == 2:
|
||||
weight = np.exp(-1.0 * ts_iou_val**2 / sigma)
|
||||
else:
|
||||
if ts_iou_val > Nt_thres:
|
||||
weight = 0
|
||||
else:
|
||||
weight = 1
|
||||
boxes[pos, 8] = weight * boxes[pos, 8]
|
||||
#if box score falls below thresold, discard the box by
|
||||
#swaping last box update N
|
||||
if boxes[pos, 8] < threshold:
|
||||
boxes[pos, :] = boxes[N - 1, :]
|
||||
inds[pos] = inds[N - 1]
|
||||
N = N - 1
|
||||
pos = pos - 1
|
||||
pos = pos + 1
|
||||
|
||||
return boxes[:N]
|
||||
|
||||
|
||||
def nms_locality(polys, thres=0.3):
|
||||
"""
|
||||
locality aware nms of EAST
|
||||
:param polys: a N*9 numpy array. first 8 coordinates, then prob
|
||||
:return: boxes after nms
|
||||
"""
|
||||
S = []
|
||||
p = None
|
||||
for g in polys:
|
||||
if p is not None and intersection(g, p) > thres:
|
||||
p = weighted_merge(g, p)
|
||||
else:
|
||||
if p is not None:
|
||||
S.append(p)
|
||||
p = g
|
||||
if p is not None:
|
||||
S.append(p)
|
||||
|
||||
if len(S) == 0:
|
||||
return np.array([])
|
||||
return standard_nms(np.array(S), thres)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
# 343,350,448,135,474,143,369,359
|
||||
print(
|
||||
Polygon(np.array([[343, 350], [448, 135], [474, 143], [369, 359]]))
|
||||
.area)
|
|
@ -23,14 +23,16 @@ class BaseRecLabelDecode(object):
|
|||
character_dict_path=None,
|
||||
character_type='ch',
|
||||
use_space_char=False):
|
||||
support_character_type = ['ch', 'en', 'en_sensitive']
|
||||
support_character_type = [
|
||||
'ch', 'en', 'en_sensitive', 'french', 'german', 'japan', 'korean'
|
||||
]
|
||||
assert character_type in support_character_type, "Only {} are supported now but get {}".format(
|
||||
support_character_type, self.character_str)
|
||||
|
||||
if character_type == "en":
|
||||
self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz"
|
||||
dict_character = list(self.character_str)
|
||||
elif character_type == "ch":
|
||||
elif character_type in ["ch", "french", "german", "japan", "korean"]:
|
||||
self.character_str = ""
|
||||
assert character_dict_path is not None, "character_dict_path should not be None when character_type is ch"
|
||||
with open(character_dict_path, "rb") as fin:
|
||||
|
@ -150,4 +152,4 @@ class AttnLabelDecode(BaseRecLabelDecode):
|
|||
else:
|
||||
assert False, "unsupport type %s in get_beg_end_flag_idx" \
|
||||
% beg_or_end
|
||||
return idx
|
||||
return idx
|
||||
|
|
|
@ -0,0 +1,295 @@
|
|||
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
import sys
|
||||
__dir__ = os.path.dirname(__file__)
|
||||
sys.path.append(__dir__)
|
||||
sys.path.append(os.path.join(__dir__, '..'))
|
||||
|
||||
import numpy as np
|
||||
from .locality_aware_nms import nms_locality
|
||||
# import lanms
|
||||
import cv2
|
||||
import time
|
||||
|
||||
|
||||
class SASTPostProcess(object):
|
||||
"""
|
||||
The post process for SAST.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
score_thresh=0.5,
|
||||
nms_thresh=0.2,
|
||||
sample_pts_num=2,
|
||||
shrink_ratio_of_width=0.3,
|
||||
expand_scale=1.0,
|
||||
tcl_map_thresh=0.5,
|
||||
**kwargs):
|
||||
|
||||
self.score_thresh = score_thresh
|
||||
self.nms_thresh = nms_thresh
|
||||
self.sample_pts_num = sample_pts_num
|
||||
self.shrink_ratio_of_width = shrink_ratio_of_width
|
||||
self.expand_scale = expand_scale
|
||||
self.tcl_map_thresh = tcl_map_thresh
|
||||
|
||||
# c++ la-nms is faster, but only support python 3.5
|
||||
self.is_python35 = False
|
||||
if sys.version_info.major == 3 and sys.version_info.minor == 5:
|
||||
self.is_python35 = True
|
||||
|
||||
def point_pair2poly(self, point_pair_list):
|
||||
"""
|
||||
Transfer vertical point_pairs into poly point in clockwise.
|
||||
"""
|
||||
# constract poly
|
||||
point_num = len(point_pair_list) * 2
|
||||
point_list = [0] * point_num
|
||||
for idx, point_pair in enumerate(point_pair_list):
|
||||
point_list[idx] = point_pair[0]
|
||||
point_list[point_num - 1 - idx] = point_pair[1]
|
||||
return np.array(point_list).reshape(-1, 2)
|
||||
|
||||
def shrink_quad_along_width(self, quad, begin_width_ratio=0., end_width_ratio=1.):
|
||||
"""
|
||||
Generate shrink_quad_along_width.
|
||||
"""
|
||||
ratio_pair = np.array([[begin_width_ratio], [end_width_ratio]], dtype=np.float32)
|
||||
p0_1 = quad[0] + (quad[1] - quad[0]) * ratio_pair
|
||||
p3_2 = quad[3] + (quad[2] - quad[3]) * ratio_pair
|
||||
return np.array([p0_1[0], p0_1[1], p3_2[1], p3_2[0]])
|
||||
|
||||
def expand_poly_along_width(self, poly, shrink_ratio_of_width=0.3):
|
||||
"""
|
||||
expand poly along width.
|
||||
"""
|
||||
point_num = poly.shape[0]
|
||||
left_quad = np.array([poly[0], poly[1], poly[-2], poly[-1]], dtype=np.float32)
|
||||
left_ratio = -shrink_ratio_of_width * np.linalg.norm(left_quad[0] - left_quad[3]) / \
|
||||
(np.linalg.norm(left_quad[0] - left_quad[1]) + 1e-6)
|
||||
left_quad_expand = self.shrink_quad_along_width(left_quad, left_ratio, 1.0)
|
||||
right_quad = np.array([poly[point_num // 2 - 2], poly[point_num // 2 - 1],
|
||||
poly[point_num // 2], poly[point_num // 2 + 1]], dtype=np.float32)
|
||||
right_ratio = 1.0 + \
|
||||
shrink_ratio_of_width * np.linalg.norm(right_quad[0] - right_quad[3]) / \
|
||||
(np.linalg.norm(right_quad[0] - right_quad[1]) + 1e-6)
|
||||
right_quad_expand = self.shrink_quad_along_width(right_quad, 0.0, right_ratio)
|
||||
poly[0] = left_quad_expand[0]
|
||||
poly[-1] = left_quad_expand[-1]
|
||||
poly[point_num // 2 - 1] = right_quad_expand[1]
|
||||
poly[point_num // 2] = right_quad_expand[2]
|
||||
return poly
|
||||
|
||||
def restore_quad(self, tcl_map, tcl_map_thresh, tvo_map):
|
||||
"""Restore quad."""
|
||||
xy_text = np.argwhere(tcl_map[:, :, 0] > tcl_map_thresh)
|
||||
xy_text = xy_text[:, ::-1] # (n, 2)
|
||||
|
||||
# Sort the text boxes via the y axis
|
||||
xy_text = xy_text[np.argsort(xy_text[:, 1])]
|
||||
|
||||
scores = tcl_map[xy_text[:, 1], xy_text[:, 0], 0]
|
||||
scores = scores[:, np.newaxis]
|
||||
|
||||
# Restore
|
||||
point_num = int(tvo_map.shape[-1] / 2)
|
||||
assert point_num == 4
|
||||
tvo_map = tvo_map[xy_text[:, 1], xy_text[:, 0], :]
|
||||
xy_text_tile = np.tile(xy_text, (1, point_num)) # (n, point_num * 2)
|
||||
quads = xy_text_tile - tvo_map
|
||||
|
||||
return scores, quads, xy_text
|
||||
|
||||
def quad_area(self, quad):
|
||||
"""
|
||||
compute area of a quad.
|
||||
"""
|
||||
edge = [
|
||||
(quad[1][0] - quad[0][0]) * (quad[1][1] + quad[0][1]),
|
||||
(quad[2][0] - quad[1][0]) * (quad[2][1] + quad[1][1]),
|
||||
(quad[3][0] - quad[2][0]) * (quad[3][1] + quad[2][1]),
|
||||
(quad[0][0] - quad[3][0]) * (quad[0][1] + quad[3][1])
|
||||
]
|
||||
return np.sum(edge) / 2.
|
||||
|
||||
def nms(self, dets):
|
||||
if self.is_python35:
|
||||
import lanms
|
||||
dets = lanms.merge_quadrangle_n9(dets, self.nms_thresh)
|
||||
else:
|
||||
dets = nms_locality(dets, self.nms_thresh)
|
||||
return dets
|
||||
|
||||
def cluster_by_quads_tco(self, tcl_map, tcl_map_thresh, quads, tco_map):
|
||||
"""
|
||||
Cluster pixels in tcl_map based on quads.
|
||||
"""
|
||||
instance_count = quads.shape[0] + 1 # contain background
|
||||
instance_label_map = np.zeros(tcl_map.shape[:2], dtype=np.int32)
|
||||
if instance_count == 1:
|
||||
return instance_count, instance_label_map
|
||||
|
||||
# predict text center
|
||||
xy_text = np.argwhere(tcl_map[:, :, 0] > tcl_map_thresh)
|
||||
n = xy_text.shape[0]
|
||||
xy_text = xy_text[:, ::-1] # (n, 2)
|
||||
tco = tco_map[xy_text[:, 1], xy_text[:, 0], :] # (n, 2)
|
||||
pred_tc = xy_text - tco
|
||||
|
||||
# get gt text center
|
||||
m = quads.shape[0]
|
||||
gt_tc = np.mean(quads, axis=1) # (m, 2)
|
||||
|
||||
pred_tc_tile = np.tile(pred_tc[:, np.newaxis, :], (1, m, 1)) # (n, m, 2)
|
||||
gt_tc_tile = np.tile(gt_tc[np.newaxis, :, :], (n, 1, 1)) # (n, m, 2)
|
||||
dist_mat = np.linalg.norm(pred_tc_tile - gt_tc_tile, axis=2) # (n, m)
|
||||
xy_text_assign = np.argmin(dist_mat, axis=1) + 1 # (n,)
|
||||
|
||||
instance_label_map[xy_text[:, 1], xy_text[:, 0]] = xy_text_assign
|
||||
return instance_count, instance_label_map
|
||||
|
||||
def estimate_sample_pts_num(self, quad, xy_text):
|
||||
"""
|
||||
Estimate sample points number.
|
||||
"""
|
||||
eh = (np.linalg.norm(quad[0] - quad[3]) + np.linalg.norm(quad[1] - quad[2])) / 2.0
|
||||
ew = (np.linalg.norm(quad[0] - quad[1]) + np.linalg.norm(quad[2] - quad[3])) / 2.0
|
||||
|
||||
dense_sample_pts_num = max(2, int(ew))
|
||||
dense_xy_center_line = xy_text[np.linspace(0, xy_text.shape[0] - 1, dense_sample_pts_num,
|
||||
endpoint=True, dtype=np.float32).astype(np.int32)]
|
||||
|
||||
dense_xy_center_line_diff = dense_xy_center_line[1:] - dense_xy_center_line[:-1]
|
||||
estimate_arc_len = np.sum(np.linalg.norm(dense_xy_center_line_diff, axis=1))
|
||||
|
||||
sample_pts_num = max(2, int(estimate_arc_len / eh))
|
||||
return sample_pts_num
|
||||
|
||||
def detect_sast(self, tcl_map, tvo_map, tbo_map, tco_map, ratio_w, ratio_h, src_w, src_h,
|
||||
shrink_ratio_of_width=0.3, tcl_map_thresh=0.5, offset_expand=1.0, out_strid=4.0):
|
||||
"""
|
||||
first resize the tcl_map, tvo_map and tbo_map to the input_size, then restore the polys
|
||||
"""
|
||||
# restore quad
|
||||
scores, quads, xy_text = self.restore_quad(tcl_map, tcl_map_thresh, tvo_map)
|
||||
dets = np.hstack((quads, scores)).astype(np.float32, copy=False)
|
||||
dets = self.nms(dets)
|
||||
if dets.shape[0] == 0:
|
||||
return []
|
||||
quads = dets[:, :-1].reshape(-1, 4, 2)
|
||||
|
||||
# Compute quad area
|
||||
quad_areas = []
|
||||
for quad in quads:
|
||||
quad_areas.append(-self.quad_area(quad))
|
||||
|
||||
# instance segmentation
|
||||
# instance_count, instance_label_map = cv2.connectedComponents(tcl_map.astype(np.uint8), connectivity=8)
|
||||
instance_count, instance_label_map = self.cluster_by_quads_tco(tcl_map, tcl_map_thresh, quads, tco_map)
|
||||
|
||||
# restore single poly with tcl instance.
|
||||
poly_list = []
|
||||
for instance_idx in range(1, instance_count):
|
||||
xy_text = np.argwhere(instance_label_map == instance_idx)[:, ::-1]
|
||||
quad = quads[instance_idx - 1]
|
||||
q_area = quad_areas[instance_idx - 1]
|
||||
if q_area < 5:
|
||||
continue
|
||||
|
||||
#
|
||||
len1 = float(np.linalg.norm(quad[0] -quad[1]))
|
||||
len2 = float(np.linalg.norm(quad[1] -quad[2]))
|
||||
min_len = min(len1, len2)
|
||||
if min_len < 3:
|
||||
continue
|
||||
|
||||
# filter small CC
|
||||
if xy_text.shape[0] <= 0:
|
||||
continue
|
||||
|
||||
# filter low confidence instance
|
||||
xy_text_scores = tcl_map[xy_text[:, 1], xy_text[:, 0], 0]
|
||||
if np.sum(xy_text_scores) / quad_areas[instance_idx - 1] < 0.1:
|
||||
# if np.sum(xy_text_scores) / quad_areas[instance_idx - 1] < 0.05:
|
||||
continue
|
||||
|
||||
# sort xy_text
|
||||
left_center_pt = np.array([[(quad[0, 0] + quad[-1, 0]) / 2.0,
|
||||
(quad[0, 1] + quad[-1, 1]) / 2.0]]) # (1, 2)
|
||||
right_center_pt = np.array([[(quad[1, 0] + quad[2, 0]) / 2.0,
|
||||
(quad[1, 1] + quad[2, 1]) / 2.0]]) # (1, 2)
|
||||
proj_unit_vec = (right_center_pt - left_center_pt) / \
|
||||
(np.linalg.norm(right_center_pt - left_center_pt) + 1e-6)
|
||||
proj_value = np.sum(xy_text * proj_unit_vec, axis=1)
|
||||
xy_text = xy_text[np.argsort(proj_value)]
|
||||
|
||||
# Sample pts in tcl map
|
||||
if self.sample_pts_num == 0:
|
||||
sample_pts_num = self.estimate_sample_pts_num(quad, xy_text)
|
||||
else:
|
||||
sample_pts_num = self.sample_pts_num
|
||||
xy_center_line = xy_text[np.linspace(0, xy_text.shape[0] - 1, sample_pts_num,
|
||||
endpoint=True, dtype=np.float32).astype(np.int32)]
|
||||
|
||||
point_pair_list = []
|
||||
for x, y in xy_center_line:
|
||||
# get corresponding offset
|
||||
offset = tbo_map[y, x, :].reshape(2, 2)
|
||||
if offset_expand != 1.0:
|
||||
offset_length = np.linalg.norm(offset, axis=1, keepdims=True)
|
||||
expand_length = np.clip(offset_length * (offset_expand - 1), a_min=0.5, a_max=3.0)
|
||||
offset_detal = offset / offset_length * expand_length
|
||||
offset = offset + offset_detal
|
||||
# original point
|
||||
ori_yx = np.array([y, x], dtype=np.float32)
|
||||
point_pair = (ori_yx + offset)[:, ::-1]* out_strid / np.array([ratio_w, ratio_h]).reshape(-1, 2)
|
||||
point_pair_list.append(point_pair)
|
||||
|
||||
# ndarry: (x, 2), expand poly along width
|
||||
detected_poly = self.point_pair2poly(point_pair_list)
|
||||
detected_poly = self.expand_poly_along_width(detected_poly, shrink_ratio_of_width)
|
||||
detected_poly[:, 0] = np.clip(detected_poly[:, 0], a_min=0, a_max=src_w)
|
||||
detected_poly[:, 1] = np.clip(detected_poly[:, 1], a_min=0, a_max=src_h)
|
||||
poly_list.append(detected_poly)
|
||||
|
||||
return poly_list
|
||||
|
||||
def __call__(self, outs_dict, shape_list):
|
||||
score_list = outs_dict['f_score']
|
||||
border_list = outs_dict['f_border']
|
||||
tvo_list = outs_dict['f_tvo']
|
||||
tco_list = outs_dict['f_tco']
|
||||
|
||||
img_num = len(shape_list)
|
||||
poly_lists = []
|
||||
for ino in range(img_num):
|
||||
p_score = score_list[ino].transpose((1,2,0)).numpy()
|
||||
p_border = border_list[ino].transpose((1,2,0)).numpy()
|
||||
p_tvo = tvo_list[ino].transpose((1,2,0)).numpy()
|
||||
p_tco = tco_list[ino].transpose((1,2,0)).numpy()
|
||||
src_h, src_w, ratio_h, ratio_w = shape_list[ino]
|
||||
|
||||
poly_list = self.detect_sast(p_score, p_tvo, p_border, p_tco, ratio_w, ratio_h, src_w, src_h,
|
||||
shrink_ratio_of_width=self.shrink_ratio_of_width,
|
||||
tcl_map_thresh=self.tcl_map_thresh, offset_expand=self.expand_scale)
|
||||
poly_lists.append({'points': np.array(poly_list)})
|
||||
|
||||
return poly_lists
|
||||
|
2
setup.py
2
setup.py
|
@ -32,7 +32,7 @@ setup(
|
|||
package_dir={'paddleocr': ''},
|
||||
include_package_data=True,
|
||||
entry_points={"console_scripts": ["paddleocr= paddleocr.paddleocr:main"]},
|
||||
version='0.0.3',
|
||||
version='2.0',
|
||||
install_requires=requirements,
|
||||
license='Apache License 2.0',
|
||||
description='Awesome OCR toolkits based on PaddlePaddle (8.6M ultra-lightweight pre-trained model, support training and deployment among server, mobile, embeded and IoT devices',
|
||||
|
|
|
@ -13,6 +13,7 @@
|
|||
# limitations under the License.
|
||||
import os
|
||||
import sys
|
||||
|
||||
__dir__ = os.path.dirname(os.path.abspath(__file__))
|
||||
sys.path.append(__dir__)
|
||||
sys.path.append(os.path.abspath(os.path.join(__dir__, '../..')))
|
||||
|
@ -30,12 +31,15 @@ from ppocr.utils.utility import get_image_file_list, check_and_read_gif
|
|||
from ppocr.utils.logging import get_logger
|
||||
from tools.infer.utility import draw_ocr_box_txt
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
|
||||
class TextSystem(object):
|
||||
def __init__(self, args):
|
||||
self.text_detector = predict_det.TextDetector(args)
|
||||
self.text_recognizer = predict_rec.TextRecognizer(args)
|
||||
self.use_angle_cls = args.use_angle_cls
|
||||
self.drop_score = args.drop_score
|
||||
if self.use_angle_cls:
|
||||
self.text_classifier = predict_cls.TextClassifier(args)
|
||||
|
||||
|
@ -81,7 +85,8 @@ class TextSystem(object):
|
|||
def __call__(self, img):
|
||||
ori_im = img.copy()
|
||||
dt_boxes, elapse = self.text_detector(img)
|
||||
logger.info("dt_boxes num : {}, elapse : {}".format(len(dt_boxes), elapse))
|
||||
logger.info("dt_boxes num : {}, elapse : {}".format(
|
||||
len(dt_boxes), elapse))
|
||||
if dt_boxes is None:
|
||||
return None, None
|
||||
img_crop_list = []
|
||||
|
@ -99,9 +104,16 @@ class TextSystem(object):
|
|||
len(img_crop_list), elapse))
|
||||
|
||||
rec_res, elapse = self.text_recognizer(img_crop_list)
|
||||
logger.info("rec_res num : {}, elapse : {}".format(len(rec_res), elapse))
|
||||
logger.info("rec_res num : {}, elapse : {}".format(
|
||||
len(rec_res), elapse))
|
||||
# self.print_draw_crop_rec_res(img_crop_list, rec_res)
|
||||
return dt_boxes, rec_res
|
||||
filter_boxes, filter_rec_res = [], []
|
||||
for box, rec_reuslt in zip(dt_boxes, rec_res):
|
||||
text, score = rec_reuslt
|
||||
if score >= self.drop_score:
|
||||
filter_boxes.append(box)
|
||||
filter_rec_res.append(rec_reuslt)
|
||||
return filter_boxes, filter_rec_res
|
||||
|
||||
|
||||
def sorted_boxes(dt_boxes):
|
||||
|
@ -117,8 +129,8 @@ def sorted_boxes(dt_boxes):
|
|||
_boxes = list(sorted_boxes)
|
||||
|
||||
for i in range(num_boxes - 1):
|
||||
if abs(_boxes[i+1][0][1] - _boxes[i][0][1]) < 10 and \
|
||||
(_boxes[i + 1][0][0] < _boxes[i][0][0]):
|
||||
if abs(_boxes[i + 1][0][1] - _boxes[i][0][1]) < 10 and \
|
||||
(_boxes[i + 1][0][0] < _boxes[i][0][0]):
|
||||
tmp = _boxes[i]
|
||||
_boxes[i] = _boxes[i + 1]
|
||||
_boxes[i + 1] = tmp
|
||||
|
@ -143,12 +155,8 @@ def main(args):
|
|||
elapse = time.time() - starttime
|
||||
logger.info("Predict time of %s: %.3fs" % (image_file, elapse))
|
||||
|
||||
dt_num = len(dt_boxes)
|
||||
for dno in range(dt_num):
|
||||
text, score = rec_res[dno]
|
||||
if score >= drop_score:
|
||||
text_str = "%s, %.3f" % (text, score)
|
||||
logger.info(text_str)
|
||||
for text, score in rec_res:
|
||||
logger.info("{}, {:.3f}".format(text, score))
|
||||
|
||||
if is_visualize:
|
||||
image = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
|
||||
|
@ -174,5 +182,4 @@ def main(args):
|
|||
|
||||
|
||||
if __name__ == "__main__":
|
||||
logger = get_logger()
|
||||
main(utility.parse_args())
|
||||
main(utility.parse_args())
|
Loading…
Reference in New Issue