add east & sast
This commit is contained in:
parent
8a5566c974
commit
021c1132a9
|
@ -1,8 +1,7 @@
|
||||||
include LICENSE.txt
|
include LICENSE.txt
|
||||||
include README.md
|
include README.md
|
||||||
|
|
||||||
recursive-include ppocr/utils *.txt utility.py character.py check.py
|
recursive-include ppocr/utils *.txt utility.py logging.py
|
||||||
recursive-include ppocr/data/det *.py
|
recursive-include ppocr/data/ *.py
|
||||||
recursive-include ppocr/postprocess *.py
|
recursive-include ppocr/postprocess *.py
|
||||||
recursive-include ppocr/postprocess/lanms *.*
|
|
||||||
recursive-include tools/infer *.py
|
recursive-include tools/infer *.py
|
|
@ -0,0 +1,111 @@
|
||||||
|
Global:
|
||||||
|
use_gpu: true
|
||||||
|
epoch_num: 10000
|
||||||
|
log_smooth_window: 20
|
||||||
|
print_batch_step: 2
|
||||||
|
save_model_dir: ./output/east_mv3/
|
||||||
|
save_epoch_step: 1000
|
||||||
|
# evaluation is run every 5000 iterations after the 4000th iteration
|
||||||
|
eval_batch_step: [4000, 5000]
|
||||||
|
# if pretrained_model is saved in static mode, load_static_weights must set to True
|
||||||
|
load_static_weights: True
|
||||||
|
cal_metric_during_train: False
|
||||||
|
pretrained_model: ./pretrain_models/MobileNetV3_large_x0_5_pretrained
|
||||||
|
checkpoints:
|
||||||
|
save_inference_dir:
|
||||||
|
use_visualdl: False
|
||||||
|
infer_img:
|
||||||
|
save_res_path: ./output/det_east/predicts_east.txt
|
||||||
|
|
||||||
|
Architecture:
|
||||||
|
model_type: det
|
||||||
|
algorithm: EAST
|
||||||
|
Transform:
|
||||||
|
Backbone:
|
||||||
|
name: MobileNetV3
|
||||||
|
scale: 0.5
|
||||||
|
model_name: large
|
||||||
|
Neck:
|
||||||
|
name: EASTFPN
|
||||||
|
model_name: small
|
||||||
|
Head:
|
||||||
|
name: EASTHead
|
||||||
|
model_name: small
|
||||||
|
|
||||||
|
Loss:
|
||||||
|
name: EASTLoss
|
||||||
|
|
||||||
|
Optimizer:
|
||||||
|
name: Adam
|
||||||
|
beta1: 0.9
|
||||||
|
beta2: 0.999
|
||||||
|
lr:
|
||||||
|
# name: Cosine
|
||||||
|
learning_rate: 0.001
|
||||||
|
# warmup_epoch: 0
|
||||||
|
regularizer:
|
||||||
|
name: 'L2'
|
||||||
|
factor: 0
|
||||||
|
|
||||||
|
PostProcess:
|
||||||
|
name: EASTPostProcess
|
||||||
|
score_thresh: 0.8
|
||||||
|
cover_thresh: 0.1
|
||||||
|
nms_thresh: 0.2
|
||||||
|
|
||||||
|
Metric:
|
||||||
|
name: DetMetric
|
||||||
|
main_indicator: hmean
|
||||||
|
|
||||||
|
Train:
|
||||||
|
dataset:
|
||||||
|
name: SimpleDataSet
|
||||||
|
data_dir: ./train_data/icdar2015/text_localization/
|
||||||
|
label_file_list:
|
||||||
|
- ./train_data/icdar2015/text_localization/train_icdar2015_label.txt
|
||||||
|
ratio_list: [1.0]
|
||||||
|
transforms:
|
||||||
|
- DecodeImage: # load image
|
||||||
|
img_mode: BGR
|
||||||
|
channel_first: False
|
||||||
|
- DetLabelEncode: # Class handling label
|
||||||
|
- EASTProcessTrain:
|
||||||
|
image_shape: [512, 512]
|
||||||
|
background_ratio: 0.125
|
||||||
|
min_crop_side_ratio: 0.1
|
||||||
|
min_text_size: 10
|
||||||
|
- KeepKeys:
|
||||||
|
keep_keys: ['image', 'score_map', 'geo_map', 'training_mask'] # dataloader will return list in this order
|
||||||
|
loader:
|
||||||
|
shuffle: True
|
||||||
|
drop_last: False
|
||||||
|
batch_size_per_card: 16
|
||||||
|
num_workers: 8
|
||||||
|
|
||||||
|
Eval:
|
||||||
|
dataset:
|
||||||
|
name: SimpleDataSet
|
||||||
|
data_dir: ./train_data/icdar2015/text_localization/
|
||||||
|
label_file_list:
|
||||||
|
- ./train_data/icdar2015/text_localization/test_icdar2015_label.txt
|
||||||
|
transforms:
|
||||||
|
- DecodeImage: # load image
|
||||||
|
img_mode: BGR
|
||||||
|
channel_first: False
|
||||||
|
- DetLabelEncode: # Class handling label
|
||||||
|
- DetResizeForTest:
|
||||||
|
limit_side_len: 2400
|
||||||
|
limit_type: max
|
||||||
|
- NormalizeImage:
|
||||||
|
scale: 1./255.
|
||||||
|
mean: [0.485, 0.456, 0.406]
|
||||||
|
std: [0.229, 0.224, 0.225]
|
||||||
|
order: 'hwc'
|
||||||
|
- ToCHWImage:
|
||||||
|
- KeepKeys:
|
||||||
|
keep_keys: ['image', 'shape', 'polys', 'ignore_tags']
|
||||||
|
loader:
|
||||||
|
shuffle: False
|
||||||
|
drop_last: False
|
||||||
|
batch_size_per_card: 1 # must be 1
|
||||||
|
num_workers: 2
|
|
@ -0,0 +1,110 @@
|
||||||
|
Global:
|
||||||
|
use_gpu: true
|
||||||
|
epoch_num: 10000
|
||||||
|
log_smooth_window: 20
|
||||||
|
print_batch_step: 2
|
||||||
|
save_model_dir: ./output/east_r50_vd/
|
||||||
|
save_epoch_step: 1000
|
||||||
|
# evaluation is run every 5000 iterations after the 4000th iteration
|
||||||
|
eval_batch_step: [4000, 5000]
|
||||||
|
# if pretrained_model is saved in static mode, load_static_weights must set to True
|
||||||
|
load_static_weights: True
|
||||||
|
cal_metric_during_train: False
|
||||||
|
pretrained_model: ./pretrain_models/ResNet50_vd_pretrained/
|
||||||
|
checkpoints:
|
||||||
|
save_inference_dir:
|
||||||
|
use_visualdl: False
|
||||||
|
infer_img:
|
||||||
|
save_res_path: ./output/det_east/predicts_east.txt
|
||||||
|
|
||||||
|
Architecture:
|
||||||
|
model_type: det
|
||||||
|
algorithm: EAST
|
||||||
|
Transform:
|
||||||
|
Backbone:
|
||||||
|
name: ResNet
|
||||||
|
layers: 50
|
||||||
|
Neck:
|
||||||
|
name: EASTFPN
|
||||||
|
model_name: large
|
||||||
|
Head:
|
||||||
|
name: EASTHead
|
||||||
|
model_name: large
|
||||||
|
|
||||||
|
Loss:
|
||||||
|
name: EASTLoss
|
||||||
|
|
||||||
|
Optimizer:
|
||||||
|
name: Adam
|
||||||
|
beta1: 0.9
|
||||||
|
beta2: 0.999
|
||||||
|
lr:
|
||||||
|
# name: Cosine
|
||||||
|
learning_rate: 0.001
|
||||||
|
# warmup_epoch: 0
|
||||||
|
regularizer:
|
||||||
|
name: 'L2'
|
||||||
|
factor: 0
|
||||||
|
|
||||||
|
PostProcess:
|
||||||
|
name: EASTPostProcess
|
||||||
|
score_thresh: 0.8
|
||||||
|
cover_thresh: 0.1
|
||||||
|
nms_thresh: 0.2
|
||||||
|
|
||||||
|
Metric:
|
||||||
|
name: DetMetric
|
||||||
|
main_indicator: hmean
|
||||||
|
|
||||||
|
Train:
|
||||||
|
dataset:
|
||||||
|
name: SimpleDataSet
|
||||||
|
data_dir: ./train_data/icdar2015/text_localization/
|
||||||
|
label_file_list:
|
||||||
|
- ./train_data/icdar2015/text_localization/train_icdar2015_label.txt
|
||||||
|
ratio_list: [1.0]
|
||||||
|
transforms:
|
||||||
|
- DecodeImage: # load image
|
||||||
|
img_mode: BGR
|
||||||
|
channel_first: False
|
||||||
|
- DetLabelEncode: # Class handling label
|
||||||
|
- EASTProcessTrain:
|
||||||
|
image_shape: [512, 512]
|
||||||
|
background_ratio: 0.125
|
||||||
|
min_crop_side_ratio: 0.1
|
||||||
|
min_text_size: 10
|
||||||
|
- KeepKeys:
|
||||||
|
keep_keys: ['image', 'score_map', 'geo_map', 'training_mask'] # dataloader will return list in this order
|
||||||
|
loader:
|
||||||
|
shuffle: True
|
||||||
|
drop_last: False
|
||||||
|
batch_size_per_card: 8
|
||||||
|
num_workers: 8
|
||||||
|
|
||||||
|
Eval:
|
||||||
|
dataset:
|
||||||
|
name: SimpleDataSet
|
||||||
|
data_dir: ./train_data/icdar2015/text_localization/
|
||||||
|
label_file_list:
|
||||||
|
- ./train_data/icdar2015/text_localization/test_icdar2015_label.txt
|
||||||
|
transforms:
|
||||||
|
- DecodeImage: # load image
|
||||||
|
img_mode: BGR
|
||||||
|
channel_first: False
|
||||||
|
- DetLabelEncode: # Class handling label
|
||||||
|
- DetResizeForTest:
|
||||||
|
limit_side_len: 2400
|
||||||
|
limit_type: max
|
||||||
|
- NormalizeImage:
|
||||||
|
scale: 1./255.
|
||||||
|
mean: [0.485, 0.456, 0.406]
|
||||||
|
std: [0.229, 0.224, 0.225]
|
||||||
|
order: 'hwc'
|
||||||
|
- ToCHWImage:
|
||||||
|
- KeepKeys:
|
||||||
|
keep_keys: ['image', 'shape', 'polys', 'ignore_tags']
|
||||||
|
loader:
|
||||||
|
shuffle: False
|
||||||
|
drop_last: False
|
||||||
|
batch_size_per_card: 1 # must be 1
|
||||||
|
num_workers: 2
|
|
@ -0,0 +1,110 @@
|
||||||
|
Global:
|
||||||
|
use_gpu: true
|
||||||
|
epoch_num: 5000
|
||||||
|
log_smooth_window: 20
|
||||||
|
print_batch_step: 2
|
||||||
|
save_model_dir: ./output/sast_r50_vd_ic15/
|
||||||
|
save_epoch_step: 1000
|
||||||
|
# evaluation is run every 5000 iterations after the 4000th iteration
|
||||||
|
eval_batch_step: [4000, 5000]
|
||||||
|
# if pretrained_model is saved in static mode, load_static_weights must set to True
|
||||||
|
load_static_weights: True
|
||||||
|
cal_metric_during_train: False
|
||||||
|
pretrained_model: ./pretrain_models/ResNet50_vd_ssld_pretrained/
|
||||||
|
checkpoints:
|
||||||
|
save_inference_dir:
|
||||||
|
use_visualdl: False
|
||||||
|
infer_img:
|
||||||
|
save_res_path: ./output/sast_r50_vd_ic15/predicts_sast.txt
|
||||||
|
|
||||||
|
Architecture:
|
||||||
|
model_type: det
|
||||||
|
algorithm: SAST
|
||||||
|
Transform:
|
||||||
|
Backbone:
|
||||||
|
name: ResNet_SAST
|
||||||
|
layers: 50
|
||||||
|
Neck:
|
||||||
|
name: SASTFPN
|
||||||
|
with_cab: True
|
||||||
|
Head:
|
||||||
|
name: SASTHead
|
||||||
|
|
||||||
|
Loss:
|
||||||
|
name: SASTLoss
|
||||||
|
|
||||||
|
Optimizer:
|
||||||
|
name: Adam
|
||||||
|
beta1: 0.9
|
||||||
|
beta2: 0.999
|
||||||
|
lr:
|
||||||
|
# name: Cosine
|
||||||
|
learning_rate: 0.001
|
||||||
|
# warmup_epoch: 0
|
||||||
|
regularizer:
|
||||||
|
name: 'L2'
|
||||||
|
factor: 0
|
||||||
|
|
||||||
|
PostProcess:
|
||||||
|
name: SASTPostProcess
|
||||||
|
score_thresh: 0.5
|
||||||
|
sample_pts_num: 2
|
||||||
|
nms_thresh: 0.2
|
||||||
|
expand_scale: 1.0
|
||||||
|
shrink_ratio_of_width: 0.3
|
||||||
|
|
||||||
|
Metric:
|
||||||
|
name: DetMetric
|
||||||
|
main_indicator: hmean
|
||||||
|
|
||||||
|
Train:
|
||||||
|
dataset:
|
||||||
|
name: SimpleDataSet
|
||||||
|
data_dir: ./train_data/
|
||||||
|
label_file_path: [./train_data/art_latin_icdar_14pt/train_no_tt_test/train_label_json.txt, ./train_data/total_text_icdar_14pt/train_label_json.txt]
|
||||||
|
data_ratio_list: [0.5, 0.5]
|
||||||
|
transforms:
|
||||||
|
- DecodeImage: # load image
|
||||||
|
img_mode: BGR
|
||||||
|
channel_first: False
|
||||||
|
- DetLabelEncode: # Class handling label
|
||||||
|
- SASTProcessTrain:
|
||||||
|
image_shape: [512, 512]
|
||||||
|
min_crop_side_ratio: 0.3
|
||||||
|
min_crop_size: 24
|
||||||
|
min_text_size: 4
|
||||||
|
max_text_size: 512
|
||||||
|
- KeepKeys:
|
||||||
|
keep_keys: ['image', 'score_map', 'border_map', 'training_mask', 'tvo_map', 'tco_map'] # dataloader will return list in this order
|
||||||
|
loader:
|
||||||
|
shuffle: True
|
||||||
|
drop_last: False
|
||||||
|
batch_size_per_card: 4
|
||||||
|
num_workers: 4
|
||||||
|
|
||||||
|
Eval:
|
||||||
|
dataset:
|
||||||
|
name: SimpleDataSet
|
||||||
|
data_dir: ./train_data/icdar2015/text_localization/
|
||||||
|
label_file_list:
|
||||||
|
- ./train_data/icdar2015/text_localization/test_icdar2015_label.txt
|
||||||
|
transforms:
|
||||||
|
- DecodeImage: # load image
|
||||||
|
img_mode: BGR
|
||||||
|
channel_first: False
|
||||||
|
- DetLabelEncode: # Class handling label
|
||||||
|
- DetResizeForTest:
|
||||||
|
resize_long: 1536
|
||||||
|
- NormalizeImage:
|
||||||
|
scale: 1./255.
|
||||||
|
mean: [0.485, 0.456, 0.406]
|
||||||
|
std: [0.229, 0.224, 0.225]
|
||||||
|
order: 'hwc'
|
||||||
|
- ToCHWImage:
|
||||||
|
- KeepKeys:
|
||||||
|
keep_keys: ['image', 'shape', 'polys', 'ignore_tags']
|
||||||
|
loader:
|
||||||
|
shuffle: False
|
||||||
|
drop_last: False
|
||||||
|
batch_size_per_card: 1 # must be 1
|
||||||
|
num_workers: 2
|
|
@ -0,0 +1,109 @@
|
||||||
|
Global:
|
||||||
|
use_gpu: true
|
||||||
|
epoch_num: 5000
|
||||||
|
log_smooth_window: 20
|
||||||
|
print_batch_step: 2
|
||||||
|
save_model_dir: ./output/sast_r50_vd_tt/
|
||||||
|
save_epoch_step: 1000
|
||||||
|
# evaluation is run every 5000 iterations after the 4000th iteration
|
||||||
|
eval_batch_step: [4000, 5000]
|
||||||
|
# if pretrained_model is saved in static mode, load_static_weights must set to True
|
||||||
|
load_static_weights: True
|
||||||
|
cal_metric_during_train: False
|
||||||
|
pretrained_model: ./pretrain_models/ResNet50_vd_ssld_pretrained/
|
||||||
|
checkpoints:
|
||||||
|
save_inference_dir:
|
||||||
|
use_visualdl: False
|
||||||
|
infer_img:
|
||||||
|
save_res_path: ./output/sast_r50_vd_tt/predicts_sast.txt
|
||||||
|
|
||||||
|
Architecture:
|
||||||
|
model_type: det
|
||||||
|
algorithm: SAST
|
||||||
|
Transform:
|
||||||
|
Backbone:
|
||||||
|
name: ResNet_SAST
|
||||||
|
layers: 50
|
||||||
|
Neck:
|
||||||
|
name: SASTFPN
|
||||||
|
with_cab: True
|
||||||
|
Head:
|
||||||
|
name: SASTHead
|
||||||
|
|
||||||
|
Loss:
|
||||||
|
name: SASTLoss
|
||||||
|
|
||||||
|
Optimizer:
|
||||||
|
name: Adam
|
||||||
|
beta1: 0.9
|
||||||
|
beta2: 0.999
|
||||||
|
lr:
|
||||||
|
# name: Cosine
|
||||||
|
learning_rate: 0.001
|
||||||
|
# warmup_epoch: 0
|
||||||
|
regularizer:
|
||||||
|
name: 'L2'
|
||||||
|
factor: 0
|
||||||
|
|
||||||
|
PostProcess:
|
||||||
|
name: SASTPostProcess
|
||||||
|
score_thresh: 0.5
|
||||||
|
sample_pts_num: 6
|
||||||
|
nms_thresh: 0.2
|
||||||
|
expand_scale: 1.2
|
||||||
|
shrink_ratio_of_width: 0.2
|
||||||
|
|
||||||
|
Metric:
|
||||||
|
name: DetMetric
|
||||||
|
main_indicator: hmean
|
||||||
|
|
||||||
|
Train:
|
||||||
|
dataset:
|
||||||
|
name: SimpleDataSet
|
||||||
|
label_file_list: [./train_data/icdar2013/train_label_json.txt, ./train_data/icdar2015/train_label_json.txt, ./train_data/icdar17_mlt_latin/train_label_json.txt, ./train_data/coco_text_icdar_4pts/train_label_json.txt]
|
||||||
|
ratio_list: [0.1, 0.45, 0.3, 0.15]
|
||||||
|
transforms:
|
||||||
|
- DecodeImage: # load image
|
||||||
|
img_mode: BGR
|
||||||
|
channel_first: False
|
||||||
|
- DetLabelEncode: # Class handling label
|
||||||
|
- SASTProcessTrain:
|
||||||
|
image_shape: [512, 512]
|
||||||
|
min_crop_side_ratio: 0.3
|
||||||
|
min_crop_size: 24
|
||||||
|
min_text_size: 4
|
||||||
|
max_text_size: 512
|
||||||
|
- KeepKeys:
|
||||||
|
keep_keys: ['image', 'score_map', 'border_map', 'training_mask', 'tvo_map', 'tco_map'] # dataloader will return list in this order
|
||||||
|
loader:
|
||||||
|
shuffle: True
|
||||||
|
drop_last: False
|
||||||
|
batch_size_per_card: 4
|
||||||
|
num_workers: 4
|
||||||
|
|
||||||
|
Eval:
|
||||||
|
dataset:
|
||||||
|
name: SimpleDataSet
|
||||||
|
data_dir: ./train_data/
|
||||||
|
label_file_list:
|
||||||
|
- ./train_data/total_text_icdar_14pt/test_label_json.txt
|
||||||
|
transforms:
|
||||||
|
- DecodeImage: # load image
|
||||||
|
img_mode: BGR
|
||||||
|
channel_first: False
|
||||||
|
- DetLabelEncode: # Class handling label
|
||||||
|
- DetResizeForTest:
|
||||||
|
resize_long: 768
|
||||||
|
- NormalizeImage:
|
||||||
|
scale: 1./255.
|
||||||
|
mean: [0.485, 0.456, 0.406]
|
||||||
|
std: [0.229, 0.224, 0.225]
|
||||||
|
order: 'hwc'
|
||||||
|
- ToCHWImage:
|
||||||
|
- KeepKeys:
|
||||||
|
keep_keys: ['image', 'shape', 'polys', 'ignore_tags']
|
||||||
|
loader:
|
||||||
|
shuffle: False
|
||||||
|
drop_last: False
|
||||||
|
batch_size_per_card: 1 # must be 1
|
||||||
|
num_workers: 2
|
|
@ -15,7 +15,7 @@ Global:
|
||||||
use_visualdl: False
|
use_visualdl: False
|
||||||
infer_img:
|
infer_img:
|
||||||
# for data or label process
|
# for data or label process
|
||||||
character_dict_path: ppocr/utils/ic15_dict.txt
|
character_dict_path: ppocr/utils/dict/ic15_dict.txt
|
||||||
character_type: ch
|
character_type: ch
|
||||||
max_text_length: 25
|
max_text_length: 25
|
||||||
infer_mode: False
|
infer_mode: False
|
||||||
|
|
|
@ -15,7 +15,7 @@ Global:
|
||||||
use_visualdl: False
|
use_visualdl: False
|
||||||
infer_img:
|
infer_img:
|
||||||
# for data or label process
|
# for data or label process
|
||||||
character_dict_path: ppocr/utils/french_dict.txt
|
character_dict_path: ppocr/utils/dict/french_dict.txt
|
||||||
character_type: french
|
character_type: french
|
||||||
max_text_length: 25
|
max_text_length: 25
|
||||||
infer_mode: False
|
infer_mode: False
|
||||||
|
|
|
@ -15,7 +15,7 @@ Global:
|
||||||
use_visualdl: False
|
use_visualdl: False
|
||||||
infer_img:
|
infer_img:
|
||||||
# for data or label process
|
# for data or label process
|
||||||
character_dict_path: ppocr/utils/german_dict.txt
|
character_dict_path: ppocr/utils/dict/german_dict.txt
|
||||||
character_type: german
|
character_type: german
|
||||||
max_text_length: 25
|
max_text_length: 25
|
||||||
infer_mode: False
|
infer_mode: False
|
||||||
|
|
|
@ -15,7 +15,7 @@ Global:
|
||||||
use_visualdl: False
|
use_visualdl: False
|
||||||
infer_img:
|
infer_img:
|
||||||
# for data or label process
|
# for data or label process
|
||||||
character_dict_path: ppocr/utils/japan_dict.txt
|
character_dict_path: ppocr/utils/dict/japan_dict.txt
|
||||||
character_type: japan
|
character_type: japan
|
||||||
max_text_length: 25
|
max_text_length: 25
|
||||||
infer_mode: False
|
infer_mode: False
|
||||||
|
|
|
@ -15,7 +15,7 @@ Global:
|
||||||
use_visualdl: False
|
use_visualdl: False
|
||||||
infer_img:
|
infer_img:
|
||||||
# for data or label process
|
# for data or label process
|
||||||
character_dict_path: ppocr/utils/korean_dict.txt
|
character_dict_path: ppocr/utils/dict/korean_dict.txt
|
||||||
character_type: korean
|
character_type: korean
|
||||||
max_text_length: 25
|
max_text_length: 25
|
||||||
infer_mode: False
|
infer_mode: False
|
||||||
|
|
|
@ -261,6 +261,61 @@ im_show.save('result.jpg')
|
||||||
paddleocr --image_dir PaddleOCR/doc/imgs/11.jpg --det_model_dir {your_det_model_dir} --rec_model_dir {your_rec_model_dir} --rec_char_dict_path {your_rec_char_dict_path} --cls_model_dir {your_cls_model_dir} --use_angle_cls true --cls true
|
paddleocr --image_dir PaddleOCR/doc/imgs/11.jpg --det_model_dir {your_det_model_dir} --rec_model_dir {your_rec_model_dir} --rec_char_dict_path {your_rec_char_dict_path} --cls_model_dir {your_cls_model_dir} --use_angle_cls true --cls true
|
||||||
```
|
```
|
||||||
|
|
||||||
|
### 使用网络图片或者numpy数组作为输入
|
||||||
|
|
||||||
|
1. 网络图片
|
||||||
|
|
||||||
|
代码使用
|
||||||
|
```python
|
||||||
|
from paddleocr import PaddleOCR, draw_ocr
|
||||||
|
# Paddleocr目前支持中英文、英文、法语、德语、韩语、日语,可以通过修改lang参数进行切换
|
||||||
|
# 参数依次为`ch`, `en`, `french`, `german`, `korean`, `japan`。
|
||||||
|
ocr = PaddleOCR(use_angle_cls=True, lang="ch") # need to run only once to download and load model into memory
|
||||||
|
img_path = 'http://n.sinaimg.cn/ent/transform/w630h933/20171222/o111-fypvuqf1838418.jpg'
|
||||||
|
result = ocr.ocr(img_path, cls=True)
|
||||||
|
for line in result:
|
||||||
|
print(line)
|
||||||
|
|
||||||
|
# 显示结果
|
||||||
|
from PIL import Image
|
||||||
|
image = Image.open(img_path).convert('RGB')
|
||||||
|
boxes = [line[0] for line in result]
|
||||||
|
txts = [line[1][0] for line in result]
|
||||||
|
scores = [line[1][1] for line in result]
|
||||||
|
im_show = draw_ocr(image, boxes, txts, scores, font_path='/path/to/PaddleOCR/doc/simfang.ttf')
|
||||||
|
im_show = Image.fromarray(im_show)
|
||||||
|
im_show.save('result.jpg')
|
||||||
|
```
|
||||||
|
命令行模式
|
||||||
|
```bash
|
||||||
|
paddleocr --image_dir http://n.sinaimg.cn/ent/transform/w630h933/20171222/o111-fypvuqf1838418.jpg --use_angle_cls=true
|
||||||
|
```
|
||||||
|
|
||||||
|
2. numpy数组
|
||||||
|
仅通过代码使用时支持numpy数组作为输入
|
||||||
|
```python
|
||||||
|
from paddleocr import PaddleOCR, draw_ocr
|
||||||
|
# Paddleocr目前支持中英文、英文、法语、德语、韩语、日语,可以通过修改lang参数进行切换
|
||||||
|
# 参数依次为`ch`, `en`, `french`, `german`, `korean`, `japan`。
|
||||||
|
ocr = PaddleOCR(use_angle_cls=True, lang="ch") # need to run only once to download and load model into memory
|
||||||
|
img_path = 'PaddleOCR/doc/imgs/11.jpg'
|
||||||
|
img = cv2.imread(img_path)
|
||||||
|
# img = cv2.cvtColor(img,cv2.COLOR_BGR2GRAY), 如果你自己训练的模型支持灰度图,可以将这句话的注释取消
|
||||||
|
result = ocr.ocr(img_path, cls=True)
|
||||||
|
for line in result:
|
||||||
|
print(line)
|
||||||
|
|
||||||
|
# 显示结果
|
||||||
|
from PIL import Image
|
||||||
|
image = Image.open(img_path).convert('RGB')
|
||||||
|
boxes = [line[0] for line in result]
|
||||||
|
txts = [line[1][0] for line in result]
|
||||||
|
scores = [line[1][1] for line in result]
|
||||||
|
im_show = draw_ocr(image, boxes, txts, scores, font_path='/path/to/PaddleOCR/doc/simfang.ttf')
|
||||||
|
im_show = Image.fromarray(im_show)
|
||||||
|
im_show.save('result.jpg')
|
||||||
|
```
|
||||||
|
|
||||||
## 参数说明
|
## 参数说明
|
||||||
|
|
||||||
| 字段 | 说明 | 默认值 |
|
| 字段 | 说明 | 默认值 |
|
||||||
|
@ -285,6 +340,7 @@ paddleocr --image_dir PaddleOCR/doc/imgs/11.jpg --det_model_dir {your_det_model_
|
||||||
| max_text_length | 识别算法能识别的最大文字长度 | 25 |
|
| max_text_length | 识别算法能识别的最大文字长度 | 25 |
|
||||||
| rec_char_dict_path | 识别模型字典路径,当rec_model_dir使用方式2传参时需要修改为自己的字典路径 | ./ppocr/utils/ppocr_keys_v1.txt |
|
| rec_char_dict_path | 识别模型字典路径,当rec_model_dir使用方式2传参时需要修改为自己的字典路径 | ./ppocr/utils/ppocr_keys_v1.txt |
|
||||||
| use_space_char | 是否识别空格 | TRUE |
|
| use_space_char | 是否识别空格 | TRUE |
|
||||||
|
| drop_score | 对输出按照分数(来自于识别模型)进行过滤,低于此分数的不返回 | 0.5 |
|
||||||
| use_angle_cls | 是否加载分类模型 | FALSE |
|
| use_angle_cls | 是否加载分类模型 | FALSE |
|
||||||
| cls_model_dir | 分类模型所在文件夹。传参方式有两种,1. None: 自动下载内置模型到 `~/.paddleocr/cls`;2.自己转换好的inference模型路径,模型路径下必须包含model和params文件 | None |
|
| cls_model_dir | 分类模型所在文件夹。传参方式有两种,1. None: 自动下载内置模型到 `~/.paddleocr/cls`;2.自己转换好的inference模型路径,模型路径下必须包含model和params文件 | None |
|
||||||
| cls_image_shape | 分类算法的输入图片尺寸 | "3, 48, 192" |
|
| cls_image_shape | 分类算法的输入图片尺寸 | "3, 48, 192" |
|
||||||
|
@ -295,4 +351,4 @@ paddleocr --image_dir PaddleOCR/doc/imgs/11.jpg --det_model_dir {your_det_model_
|
||||||
| lang | 模型语言类型,目前支持 中文(ch)和英文(en) | ch |
|
| lang | 模型语言类型,目前支持 中文(ch)和英文(en) | ch |
|
||||||
| det | 前向时使用启动检测 | TRUE |
|
| det | 前向时使用启动检测 | TRUE |
|
||||||
| rec | 前向时是否启动识别 | TRUE |
|
| rec | 前向时是否启动识别 | TRUE |
|
||||||
| cls | 前向时是否启动分类 | FALSE |
|
| cls | 前向时是否启动分类 (命令行模式下使用use_angle_cls控制前向是否启动分类) | FALSE |
|
||||||
|
|
|
@ -271,6 +271,59 @@ im_show.save('result.jpg')
|
||||||
paddleocr --image_dir PaddleOCR/doc/imgs/11.jpg --det_model_dir {your_det_model_dir} --rec_model_dir {your_rec_model_dir} --rec_char_dict_path {your_rec_char_dict_path} --cls_model_dir {your_cls_model_dir} --use_angle_cls true --cls true
|
paddleocr --image_dir PaddleOCR/doc/imgs/11.jpg --det_model_dir {your_det_model_dir} --rec_model_dir {your_rec_model_dir} --rec_char_dict_path {your_rec_char_dict_path} --cls_model_dir {your_cls_model_dir} --use_angle_cls true --cls true
|
||||||
```
|
```
|
||||||
|
|
||||||
|
### Use web images or numpy array as input
|
||||||
|
|
||||||
|
1. Web image
|
||||||
|
|
||||||
|
Use by code
|
||||||
|
```python
|
||||||
|
from paddleocr import PaddleOCR, draw_ocr
|
||||||
|
ocr = PaddleOCR(use_angle_cls=True, lang="ch") # need to run only once to download and load model into memory
|
||||||
|
img_path = 'http://n.sinaimg.cn/ent/transform/w630h933/20171222/o111-fypvuqf1838418.jpg'
|
||||||
|
result = ocr.ocr(img_path, cls=True)
|
||||||
|
for line in result:
|
||||||
|
print(line)
|
||||||
|
|
||||||
|
# show result
|
||||||
|
from PIL import Image
|
||||||
|
image = Image.open(img_path).convert('RGB')
|
||||||
|
boxes = [line[0] for line in result]
|
||||||
|
txts = [line[1][0] for line in result]
|
||||||
|
scores = [line[1][1] for line in result]
|
||||||
|
im_show = draw_ocr(image, boxes, txts, scores, font_path='/path/to/PaddleOCR/doc/simfang.ttf')
|
||||||
|
im_show = Image.fromarray(im_show)
|
||||||
|
im_show.save('result.jpg')
|
||||||
|
```
|
||||||
|
Use by command line
|
||||||
|
```bash
|
||||||
|
paddleocr --image_dir http://n.sinaimg.cn/ent/transform/w630h933/20171222/o111-fypvuqf1838418.jpg --use_angle_cls=true
|
||||||
|
```
|
||||||
|
|
||||||
|
2. Numpy array
|
||||||
|
Support numpy array as input only when used by code
|
||||||
|
|
||||||
|
```python
|
||||||
|
from paddleocr import PaddleOCR, draw_ocr
|
||||||
|
ocr = PaddleOCR(use_angle_cls=True, lang="ch") # need to run only once to download and load model into memory
|
||||||
|
img_path = 'PaddleOCR/doc/imgs/11.jpg'
|
||||||
|
img = cv2.imread(img_path)
|
||||||
|
# img = cv2.cvtColor(img,cv2.COLOR_BGR2GRAY), If your own training model supports grayscale images, you can uncomment this line
|
||||||
|
result = ocr.ocr(img_path, cls=True)
|
||||||
|
for line in result:
|
||||||
|
print(line)
|
||||||
|
|
||||||
|
# show result
|
||||||
|
from PIL import Image
|
||||||
|
image = Image.open(img_path).convert('RGB')
|
||||||
|
boxes = [line[0] for line in result]
|
||||||
|
txts = [line[1][0] for line in result]
|
||||||
|
scores = [line[1][1] for line in result]
|
||||||
|
im_show = draw_ocr(image, boxes, txts, scores, font_path='/path/to/PaddleOCR/doc/simfang.ttf')
|
||||||
|
im_show = Image.fromarray(im_show)
|
||||||
|
im_show.save('result.jpg')
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
## Parameter Description
|
## Parameter Description
|
||||||
|
|
||||||
| Parameter | Description | Default value |
|
| Parameter | Description | Default value |
|
||||||
|
@ -295,6 +348,7 @@ paddleocr --image_dir PaddleOCR/doc/imgs/11.jpg --det_model_dir {your_det_model_
|
||||||
| max_text_length | The maximum text length that the recognition algorithm can recognize | 25 |
|
| max_text_length | The maximum text length that the recognition algorithm can recognize | 25 |
|
||||||
| rec_char_dict_path | the alphabet path which needs to be modified to your own path when `rec_model_Name` use mode 2 | ./ppocr/utils/ppocr_keys_v1.txt |
|
| rec_char_dict_path | the alphabet path which needs to be modified to your own path when `rec_model_Name` use mode 2 | ./ppocr/utils/ppocr_keys_v1.txt |
|
||||||
| use_space_char | Whether to recognize spaces | TRUE |
|
| use_space_char | Whether to recognize spaces | TRUE |
|
||||||
|
| drop_score | Filter the output by score (from the recognition model), and those below this score will not be returned | 0.5 |
|
||||||
| use_angle_cls | Whether to load classification model | FALSE |
|
| use_angle_cls | Whether to load classification model | FALSE |
|
||||||
| cls_model_dir | the classification inference model folder. There are two ways to transfer parameters, 1. None: Automatically download the built-in model to `~/.paddleocr/cls`; 2. The path of the inference model converted by yourself, the model and params files must be included in the model path | None |
|
| cls_model_dir | the classification inference model folder. There are two ways to transfer parameters, 1. None: Automatically download the built-in model to `~/.paddleocr/cls`; 2. The path of the inference model converted by yourself, the model and params files must be included in the model path | None |
|
||||||
| cls_image_shape | image shape of classification algorithm | "3,48,192" |
|
| cls_image_shape | image shape of classification algorithm | "3,48,192" |
|
||||||
|
@ -305,4 +359,4 @@ paddleocr --image_dir PaddleOCR/doc/imgs/11.jpg --det_model_dir {your_det_model_
|
||||||
| lang | The support language, now only Chinese(ch)、English(en)、French(french)、German(german)、Korean(korean)、Japanese(japan) are supported | ch |
|
| lang | The support language, now only Chinese(ch)、English(en)、French(french)、German(german)、Korean(korean)、Japanese(japan) are supported | ch |
|
||||||
| det | Enable detction when `ppocr.ocr` func exec | TRUE |
|
| det | Enable detction when `ppocr.ocr` func exec | TRUE |
|
||||||
| rec | Enable recognition when `ppocr.ocr` func exec | TRUE |
|
| rec | Enable recognition when `ppocr.ocr` func exec | TRUE |
|
||||||
| cls | Enable classification when `ppocr.ocr` func exec | FALSE |
|
| cls | Enable classification when `ppocr.ocr` func exec((Use use_angle_cls in command line mode to control whether to start classification in the forward direction) | FALSE |
|
||||||
|
|
176
paddleocr.py
176
paddleocr.py
|
@ -26,17 +26,50 @@ import requests
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
from tools.infer import predict_system
|
from tools.infer import predict_system
|
||||||
from ppocr.utils.utility import initial_logger
|
from ppocr.utils.logging import get_logger
|
||||||
|
|
||||||
logger = initial_logger()
|
logger = get_logger()
|
||||||
from ppocr.utils.utility import check_and_read_gif, get_image_file_list
|
from ppocr.utils.utility import check_and_read_gif, get_image_file_list
|
||||||
|
|
||||||
__all__ = ['PaddleOCR']
|
__all__ = ['PaddleOCR']
|
||||||
|
|
||||||
model_params = {
|
model_urls = {
|
||||||
'det': 'https://paddleocr.bj.bcebos.com/ch_models/ch_det_mv3_db_infer.tar',
|
'det':
|
||||||
'rec':
|
'https://paddleocr.bj.bcebos.com/20-09-22/mobile/det/ch_ppocr_mobile_v1.1_det_infer.tar',
|
||||||
'https://paddleocr.bj.bcebos.com/ch_models/ch_rec_mv3_crnn_enhance_infer.tar',
|
'rec': {
|
||||||
|
'ch': {
|
||||||
|
'url':
|
||||||
|
'https://paddleocr.bj.bcebos.com/20-09-22/mobile/rec/ch_ppocr_mobile_v1.1_rec_infer.tar',
|
||||||
|
'dict_path': './ppocr/utils/ppocr_keys_v1.txt'
|
||||||
|
},
|
||||||
|
'en': {
|
||||||
|
'url':
|
||||||
|
'https://paddleocr.bj.bcebos.com/20-09-22/mobile/en/en_ppocr_mobile_v1.1_rec_infer.tar',
|
||||||
|
'dict_path': './ppocr/utils/ic15_dict.txt'
|
||||||
|
},
|
||||||
|
'french': {
|
||||||
|
'url':
|
||||||
|
'https://paddleocr.bj.bcebos.com/20-09-22/mobile/fr/french_ppocr_mobile_v1.1_rec_infer.tar',
|
||||||
|
'dict_path': './ppocr/utils/dict/french_dict.txt'
|
||||||
|
},
|
||||||
|
'german': {
|
||||||
|
'url':
|
||||||
|
'https://paddleocr.bj.bcebos.com/20-09-22/mobile/ge/german_ppocr_mobile_v1.1_rec_infer.tar',
|
||||||
|
'dict_path': './ppocr/utils/dict/german_dict.txt'
|
||||||
|
},
|
||||||
|
'korean': {
|
||||||
|
'url':
|
||||||
|
'https://paddleocr.bj.bcebos.com/20-09-22/mobile/kr/korean_ppocr_mobile_v1.1_rec_infer.tar',
|
||||||
|
'dict_path': './ppocr/utils/dict/korean_dict.txt'
|
||||||
|
},
|
||||||
|
'japan': {
|
||||||
|
'url':
|
||||||
|
'https://paddleocr.bj.bcebos.com/20-09-22/mobile/jp/japan_ppocr_mobile_v1.1_rec_infer.tar',
|
||||||
|
'dict_path': './ppocr/utils/dict/japan_dict.txt'
|
||||||
|
}
|
||||||
|
},
|
||||||
|
'cls':
|
||||||
|
'https://paddleocr.bj.bcebos.com/20-09-22/cls/ch_ppocr_mobile_v1.1_cls_infer.tar'
|
||||||
}
|
}
|
||||||
|
|
||||||
SUPPORT_DET_MODEL = ['DB']
|
SUPPORT_DET_MODEL = ['DB']
|
||||||
|
@ -54,8 +87,8 @@ def download_with_progressbar(url, save_path):
|
||||||
progress_bar.update(len(data))
|
progress_bar.update(len(data))
|
||||||
file.write(data)
|
file.write(data)
|
||||||
progress_bar.close()
|
progress_bar.close()
|
||||||
if total_size_in_bytes != 0 and progress_bar.n != total_size_in_bytes:
|
if total_size_in_bytes == 0 or progress_bar.n != total_size_in_bytes:
|
||||||
logger.error("ERROR, something went wrong")
|
logger.error("Something went wrong while downloading models")
|
||||||
sys.exit(0)
|
sys.exit(0)
|
||||||
|
|
||||||
|
|
||||||
|
@ -84,13 +117,14 @@ def maybe_download(model_storage_directory, url):
|
||||||
os.remove(tmp_path)
|
os.remove(tmp_path)
|
||||||
|
|
||||||
|
|
||||||
def parse_args():
|
def parse_args(mMain=True, add_help=True):
|
||||||
import argparse
|
import argparse
|
||||||
|
|
||||||
def str2bool(v):
|
def str2bool(v):
|
||||||
return v.lower() in ("true", "t", "1")
|
return v.lower() in ("true", "t", "1")
|
||||||
|
|
||||||
parser = argparse.ArgumentParser()
|
if mMain:
|
||||||
|
parser = argparse.ArgumentParser(add_help=add_help)
|
||||||
# params for prediction engine
|
# params for prediction engine
|
||||||
parser.add_argument("--use_gpu", type=str2bool, default=True)
|
parser.add_argument("--use_gpu", type=str2bool, default=True)
|
||||||
parser.add_argument("--ir_optim", type=str2bool, default=True)
|
parser.add_argument("--ir_optim", type=str2bool, default=True)
|
||||||
|
@ -101,7 +135,8 @@ def parse_args():
|
||||||
parser.add_argument("--image_dir", type=str)
|
parser.add_argument("--image_dir", type=str)
|
||||||
parser.add_argument("--det_algorithm", type=str, default='DB')
|
parser.add_argument("--det_algorithm", type=str, default='DB')
|
||||||
parser.add_argument("--det_model_dir", type=str, default=None)
|
parser.add_argument("--det_model_dir", type=str, default=None)
|
||||||
parser.add_argument("--det_max_side_len", type=float, default=960)
|
parser.add_argument("--det_limit_side_len", type=float, default=960)
|
||||||
|
parser.add_argument("--det_limit_type", type=str, default='max')
|
||||||
|
|
||||||
# DB parmas
|
# DB parmas
|
||||||
parser.add_argument("--det_db_thresh", type=float, default=0.3)
|
parser.add_argument("--det_db_thresh", type=float, default=0.3)
|
||||||
|
@ -120,17 +155,64 @@ def parse_args():
|
||||||
parser.add_argument("--rec_char_type", type=str, default='ch')
|
parser.add_argument("--rec_char_type", type=str, default='ch')
|
||||||
parser.add_argument("--rec_batch_num", type=int, default=30)
|
parser.add_argument("--rec_batch_num", type=int, default=30)
|
||||||
parser.add_argument("--max_text_length", type=int, default=25)
|
parser.add_argument("--max_text_length", type=int, default=25)
|
||||||
parser.add_argument(
|
parser.add_argument("--rec_char_dict_path", type=str, default=None)
|
||||||
"--rec_char_dict_path",
|
|
||||||
type=str,
|
|
||||||
default="./ppocr/utils/ppocr_keys_v1.txt")
|
|
||||||
parser.add_argument("--use_space_char", type=bool, default=True)
|
parser.add_argument("--use_space_char", type=bool, default=True)
|
||||||
parser.add_argument("--enable_mkldnn", type=bool, default=False)
|
parser.add_argument("--drop_score", type=float, default=0.5)
|
||||||
|
|
||||||
|
# params for text classifier
|
||||||
|
parser.add_argument("--cls_model_dir", type=str, default=None)
|
||||||
|
parser.add_argument("--cls_image_shape", type=str, default="3, 48, 192")
|
||||||
|
parser.add_argument("--label_list", type=list, default=['0', '180'])
|
||||||
|
parser.add_argument("--cls_batch_num", type=int, default=30)
|
||||||
|
parser.add_argument("--cls_thresh", type=float, default=0.9)
|
||||||
|
|
||||||
|
parser.add_argument("--enable_mkldnn", type=bool, default=False)
|
||||||
|
parser.add_argument("--use_zero_copy_run", type=bool, default=False)
|
||||||
|
parser.add_argument("--use_pdserving", type=str2bool, default=False)
|
||||||
|
|
||||||
|
parser.add_argument("--lang", type=str, default='ch')
|
||||||
parser.add_argument("--det", type=str2bool, default=True)
|
parser.add_argument("--det", type=str2bool, default=True)
|
||||||
parser.add_argument("--rec", type=str2bool, default=True)
|
parser.add_argument("--rec", type=str2bool, default=True)
|
||||||
parser.add_argument("--use_zero_copy_run", type=bool, default=False)
|
parser.add_argument("--use_angle_cls", type=str2bool, default=False)
|
||||||
return parser.parse_args()
|
return parser.parse_args()
|
||||||
|
else:
|
||||||
|
return argparse.Namespace(use_gpu=True,
|
||||||
|
ir_optim=True,
|
||||||
|
use_tensorrt=False,
|
||||||
|
gpu_mem=8000,
|
||||||
|
image_dir='',
|
||||||
|
det_algorithm='DB',
|
||||||
|
det_model_dir=None,
|
||||||
|
det_limit_side_len=960,
|
||||||
|
det_limit_type='max',
|
||||||
|
det_db_thresh=0.3,
|
||||||
|
det_db_box_thresh=0.5,
|
||||||
|
det_db_unclip_ratio=2.0,
|
||||||
|
det_east_score_thresh=0.8,
|
||||||
|
det_east_cover_thresh=0.1,
|
||||||
|
det_east_nms_thresh=0.2,
|
||||||
|
rec_algorithm='CRNN',
|
||||||
|
rec_model_dir=None,
|
||||||
|
rec_image_shape="3, 32, 320",
|
||||||
|
rec_char_type='ch',
|
||||||
|
rec_batch_num=30,
|
||||||
|
max_text_length=25,
|
||||||
|
rec_char_dict_path=None,
|
||||||
|
use_space_char=True,
|
||||||
|
drop_score=0.5,
|
||||||
|
cls_model_dir=None,
|
||||||
|
cls_image_shape="3, 48, 192",
|
||||||
|
label_list=['0', '180'],
|
||||||
|
cls_batch_num=30,
|
||||||
|
cls_thresh=0.9,
|
||||||
|
enable_mkldnn=False,
|
||||||
|
use_zero_copy_run=False,
|
||||||
|
use_pdserving=False,
|
||||||
|
lang='ch',
|
||||||
|
det=True,
|
||||||
|
rec=True,
|
||||||
|
use_angle_cls=False
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class PaddleOCR(predict_system.TextSystem):
|
class PaddleOCR(predict_system.TextSystem):
|
||||||
|
@ -140,18 +222,31 @@ class PaddleOCR(predict_system.TextSystem):
|
||||||
args:
|
args:
|
||||||
**kwargs: other params show in paddleocr --help
|
**kwargs: other params show in paddleocr --help
|
||||||
"""
|
"""
|
||||||
postprocess_params = parse_args()
|
postprocess_params = parse_args(mMain=False, add_help=False)
|
||||||
postprocess_params.__dict__.update(**kwargs)
|
postprocess_params.__dict__.update(**kwargs)
|
||||||
|
self.use_angle_cls = postprocess_params.use_angle_cls
|
||||||
|
lang = postprocess_params.lang
|
||||||
|
assert lang in model_urls[
|
||||||
|
'rec'], 'param lang must in {}, but got {}'.format(
|
||||||
|
model_urls['rec'].keys(), lang)
|
||||||
|
if postprocess_params.rec_char_dict_path is None:
|
||||||
|
postprocess_params.rec_char_dict_path = model_urls['rec'][lang][
|
||||||
|
'dict_path']
|
||||||
|
|
||||||
# init model dir
|
# init model dir
|
||||||
if postprocess_params.det_model_dir is None:
|
if postprocess_params.det_model_dir is None:
|
||||||
postprocess_params.det_model_dir = os.path.join(BASE_DIR, 'det')
|
postprocess_params.det_model_dir = os.path.join(BASE_DIR, 'det')
|
||||||
if postprocess_params.rec_model_dir is None:
|
if postprocess_params.rec_model_dir is None:
|
||||||
postprocess_params.rec_model_dir = os.path.join(BASE_DIR, 'rec')
|
postprocess_params.rec_model_dir = os.path.join(
|
||||||
|
BASE_DIR, 'rec/{}'.format(lang))
|
||||||
|
if postprocess_params.cls_model_dir is None:
|
||||||
|
postprocess_params.cls_model_dir = os.path.join(BASE_DIR, 'cls')
|
||||||
print(postprocess_params)
|
print(postprocess_params)
|
||||||
# download model
|
# download model
|
||||||
maybe_download(postprocess_params.det_model_dir, model_params['det'])
|
maybe_download(postprocess_params.det_model_dir, model_urls['det'])
|
||||||
maybe_download(postprocess_params.rec_model_dir, model_params['rec'])
|
maybe_download(postprocess_params.rec_model_dir,
|
||||||
|
model_urls['rec'][lang]['url'])
|
||||||
|
maybe_download(postprocess_params.cls_model_dir, model_urls['cls'])
|
||||||
|
|
||||||
if postprocess_params.det_algorithm not in SUPPORT_DET_MODEL:
|
if postprocess_params.det_algorithm not in SUPPORT_DET_MODEL:
|
||||||
logger.error('det_algorithm must in {}'.format(SUPPORT_DET_MODEL))
|
logger.error('det_algorithm must in {}'.format(SUPPORT_DET_MODEL))
|
||||||
|
@ -166,7 +261,7 @@ class PaddleOCR(predict_system.TextSystem):
|
||||||
# init det_model and rec_model
|
# init det_model and rec_model
|
||||||
super().__init__(postprocess_params)
|
super().__init__(postprocess_params)
|
||||||
|
|
||||||
def ocr(self, img, det=True, rec=True):
|
def ocr(self, img, det=True, rec=True, cls=False):
|
||||||
"""
|
"""
|
||||||
ocr with paddleocr
|
ocr with paddleocr
|
||||||
args:
|
args:
|
||||||
|
@ -175,7 +270,16 @@ class PaddleOCR(predict_system.TextSystem):
|
||||||
rec: use text recognition or not, if false, only det will be exec. default is True
|
rec: use text recognition or not, if false, only det will be exec. default is True
|
||||||
"""
|
"""
|
||||||
assert isinstance(img, (np.ndarray, list, str))
|
assert isinstance(img, (np.ndarray, list, str))
|
||||||
|
if isinstance(img, list) and det == True:
|
||||||
|
logger.error('When input a list of images, det must be false')
|
||||||
|
exit(0)
|
||||||
|
|
||||||
|
self.use_angle_cls = cls
|
||||||
if isinstance(img, str):
|
if isinstance(img, str):
|
||||||
|
# download net image
|
||||||
|
if img.startswith('http'):
|
||||||
|
download_with_progressbar(img, 'tmp.jpg')
|
||||||
|
img = 'tmp.jpg'
|
||||||
image_file = img
|
image_file = img
|
||||||
img, flag = check_and_read_gif(image_file)
|
img, flag = check_and_read_gif(image_file)
|
||||||
if not flag:
|
if not flag:
|
||||||
|
@ -183,6 +287,8 @@ class PaddleOCR(predict_system.TextSystem):
|
||||||
if img is None:
|
if img is None:
|
||||||
logger.error("error in loading image:{}".format(image_file))
|
logger.error("error in loading image:{}".format(image_file))
|
||||||
return None
|
return None
|
||||||
|
if isinstance(img, np.ndarray) and len(img.shape) == 2:
|
||||||
|
img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
|
||||||
if det and rec:
|
if det and rec:
|
||||||
dt_boxes, rec_res = self.__call__(img)
|
dt_boxes, rec_res = self.__call__(img)
|
||||||
return [[box.tolist(), res] for box, res in zip(dt_boxes, rec_res)]
|
return [[box.tolist(), res] for box, res in zip(dt_boxes, rec_res)]
|
||||||
|
@ -194,20 +300,34 @@ class PaddleOCR(predict_system.TextSystem):
|
||||||
else:
|
else:
|
||||||
if not isinstance(img, list):
|
if not isinstance(img, list):
|
||||||
img = [img]
|
img = [img]
|
||||||
|
if self.use_angle_cls:
|
||||||
|
img, cls_res, elapse = self.text_classifier(img)
|
||||||
|
if not rec:
|
||||||
|
return cls_res
|
||||||
rec_res, elapse = self.text_recognizer(img)
|
rec_res, elapse = self.text_recognizer(img)
|
||||||
return rec_res
|
return rec_res
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
# for com
|
# for cmd
|
||||||
args = parse_args()
|
args = parse_args(mMain=True)
|
||||||
|
image_dir = args.image_dir
|
||||||
|
if image_dir.startswith('http'):
|
||||||
|
download_with_progressbar(image_dir, 'tmp.jpg')
|
||||||
|
image_file_list = ['tmp.jpg']
|
||||||
|
else:
|
||||||
image_file_list = get_image_file_list(args.image_dir)
|
image_file_list = get_image_file_list(args.image_dir)
|
||||||
if len(image_file_list) == 0:
|
if len(image_file_list) == 0:
|
||||||
logger.error('no images find in {}'.format(args.image_dir))
|
logger.error('no images find in {}'.format(args.image_dir))
|
||||||
return
|
return
|
||||||
ocr_engine = PaddleOCR()
|
|
||||||
|
ocr_engine = PaddleOCR(**(args.__dict__))
|
||||||
for img_path in image_file_list:
|
for img_path in image_file_list:
|
||||||
print(img_path)
|
logger.info('{}{}{}'.format('*' * 10, img_path, '*' * 10))
|
||||||
result = ocr_engine.ocr(img_path, det=args.det, rec=args.rec)
|
result = ocr_engine.ocr(img_path,
|
||||||
|
det=args.det,
|
||||||
|
rec=args.rec,
|
||||||
|
cls=args.use_angle_cls)
|
||||||
|
if result is not None:
|
||||||
for line in result:
|
for line in result:
|
||||||
print(line)
|
logger.info(line)
|
||||||
|
|
|
@ -26,6 +26,9 @@ from .randaugment import RandAugment
|
||||||
from .operators import *
|
from .operators import *
|
||||||
from .label_ops import *
|
from .label_ops import *
|
||||||
|
|
||||||
|
from .east_process import *
|
||||||
|
from .sast_process import *
|
||||||
|
|
||||||
|
|
||||||
def transform(data, ops=None):
|
def transform(data, ops=None):
|
||||||
""" transform """
|
""" transform """
|
||||||
|
|
|
@ -0,0 +1,439 @@
|
||||||
|
#copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
|
||||||
|
#
|
||||||
|
#Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
#you may not use this file except in compliance with the License.
|
||||||
|
#You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
#Unless required by applicable law or agreed to in writing, software
|
||||||
|
#distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
#See the License for the specific language governing permissions and
|
||||||
|
#limitations under the License.
|
||||||
|
|
||||||
|
import math
|
||||||
|
import cv2
|
||||||
|
import numpy as np
|
||||||
|
import json
|
||||||
|
import sys
|
||||||
|
import os
|
||||||
|
|
||||||
|
__all__ = ['EASTProcessTrain']
|
||||||
|
|
||||||
|
|
||||||
|
class EASTProcessTrain(object):
|
||||||
|
def __init__(self,
|
||||||
|
image_shape = [512, 512],
|
||||||
|
background_ratio = 0.125,
|
||||||
|
min_crop_side_ratio = 0.1,
|
||||||
|
min_text_size = 10,
|
||||||
|
**kwargs):
|
||||||
|
self.input_size = image_shape[1]
|
||||||
|
self.random_scale = np.array([0.5, 1, 2.0, 3.0])
|
||||||
|
self.background_ratio = background_ratio
|
||||||
|
self.min_crop_side_ratio = min_crop_side_ratio
|
||||||
|
self.min_text_size = min_text_size
|
||||||
|
|
||||||
|
def preprocess(self, im):
|
||||||
|
input_size = self.input_size
|
||||||
|
im_shape = im.shape
|
||||||
|
im_size_min = np.min(im_shape[0:2])
|
||||||
|
im_size_max = np.max(im_shape[0:2])
|
||||||
|
im_scale = float(input_size) / float(im_size_max)
|
||||||
|
im = cv2.resize(im, None, None, fx=im_scale, fy=im_scale)
|
||||||
|
img_mean = [0.485, 0.456, 0.406]
|
||||||
|
img_std = [0.229, 0.224, 0.225]
|
||||||
|
# im = im[:, :, ::-1].astype(np.float32)
|
||||||
|
im = im / 255
|
||||||
|
im -= img_mean
|
||||||
|
im /= img_std
|
||||||
|
new_h, new_w, _ = im.shape
|
||||||
|
im_padded = np.zeros((input_size, input_size, 3), dtype=np.float32)
|
||||||
|
im_padded[:new_h, :new_w, :] = im
|
||||||
|
im_padded = im_padded.transpose((2, 0, 1))
|
||||||
|
im_padded = im_padded[np.newaxis, :]
|
||||||
|
return im_padded, im_scale
|
||||||
|
|
||||||
|
def rotate_im_poly(self, im, text_polys):
|
||||||
|
"""
|
||||||
|
rotate image with 90 / 180 / 270 degre
|
||||||
|
"""
|
||||||
|
im_w, im_h = im.shape[1], im.shape[0]
|
||||||
|
dst_im = im.copy()
|
||||||
|
dst_polys = []
|
||||||
|
rand_degree_ratio = np.random.rand()
|
||||||
|
rand_degree_cnt = 1
|
||||||
|
if 0.333 < rand_degree_ratio < 0.666:
|
||||||
|
rand_degree_cnt = 2
|
||||||
|
elif rand_degree_ratio > 0.666:
|
||||||
|
rand_degree_cnt = 3
|
||||||
|
for i in range(rand_degree_cnt):
|
||||||
|
dst_im = np.rot90(dst_im)
|
||||||
|
rot_degree = -90 * rand_degree_cnt
|
||||||
|
rot_angle = rot_degree * math.pi / 180.0
|
||||||
|
n_poly = text_polys.shape[0]
|
||||||
|
cx, cy = 0.5 * im_w, 0.5 * im_h
|
||||||
|
ncx, ncy = 0.5 * dst_im.shape[1], 0.5 * dst_im.shape[0]
|
||||||
|
for i in range(n_poly):
|
||||||
|
wordBB = text_polys[i]
|
||||||
|
poly = []
|
||||||
|
for j in range(4):
|
||||||
|
sx, sy = wordBB[j][0], wordBB[j][1]
|
||||||
|
dx = math.cos(rot_angle) * (sx - cx)\
|
||||||
|
- math.sin(rot_angle) * (sy - cy) + ncx
|
||||||
|
dy = math.sin(rot_angle) * (sx - cx)\
|
||||||
|
+ math.cos(rot_angle) * (sy - cy) + ncy
|
||||||
|
poly.append([dx, dy])
|
||||||
|
dst_polys.append(poly)
|
||||||
|
dst_polys = np.array(dst_polys, dtype=np.float32)
|
||||||
|
return dst_im, dst_polys
|
||||||
|
|
||||||
|
def polygon_area(self, poly):
|
||||||
|
"""
|
||||||
|
compute area of a polygon
|
||||||
|
:param poly:
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
edge = [(poly[1][0] - poly[0][0]) * (poly[1][1] + poly[0][1]),
|
||||||
|
(poly[2][0] - poly[1][0]) * (poly[2][1] + poly[1][1]),
|
||||||
|
(poly[3][0] - poly[2][0]) * (poly[3][1] + poly[2][1]),
|
||||||
|
(poly[0][0] - poly[3][0]) * (poly[0][1] + poly[3][1])]
|
||||||
|
return np.sum(edge) / 2.
|
||||||
|
|
||||||
|
def check_and_validate_polys(self, polys, tags, img_height, img_width):
|
||||||
|
"""
|
||||||
|
check so that the text poly is in the same direction,
|
||||||
|
and also filter some invalid polygons
|
||||||
|
:param polys:
|
||||||
|
:param tags:
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
h, w = img_height, img_width
|
||||||
|
if polys.shape[0] == 0:
|
||||||
|
return polys
|
||||||
|
polys[:, :, 0] = np.clip(polys[:, :, 0], 0, w - 1)
|
||||||
|
polys[:, :, 1] = np.clip(polys[:, :, 1], 0, h - 1)
|
||||||
|
|
||||||
|
validated_polys = []
|
||||||
|
validated_tags = []
|
||||||
|
for poly, tag in zip(polys, tags):
|
||||||
|
p_area = self.polygon_area(poly)
|
||||||
|
#invalid poly
|
||||||
|
if abs(p_area) < 1:
|
||||||
|
continue
|
||||||
|
if p_area > 0:
|
||||||
|
#'poly in wrong direction'
|
||||||
|
if not tag:
|
||||||
|
tag = True #reversed cases should be ignore
|
||||||
|
poly = poly[(0, 3, 2, 1), :]
|
||||||
|
validated_polys.append(poly)
|
||||||
|
validated_tags.append(tag)
|
||||||
|
return np.array(validated_polys), np.array(validated_tags)
|
||||||
|
|
||||||
|
def draw_img_polys(self, img, polys):
|
||||||
|
if len(img.shape) == 4:
|
||||||
|
img = np.squeeze(img, axis=0)
|
||||||
|
if img.shape[0] == 3:
|
||||||
|
img = img.transpose((1, 2, 0))
|
||||||
|
img[:, :, 2] += 123.68
|
||||||
|
img[:, :, 1] += 116.78
|
||||||
|
img[:, :, 0] += 103.94
|
||||||
|
cv2.imwrite("tmp.jpg", img)
|
||||||
|
img = cv2.imread("tmp.jpg")
|
||||||
|
for box in polys:
|
||||||
|
box = box.astype(np.int32).reshape((-1, 1, 2))
|
||||||
|
cv2.polylines(img, [box], True, color=(255, 255, 0), thickness=2)
|
||||||
|
import random
|
||||||
|
ino = random.randint(0, 100)
|
||||||
|
cv2.imwrite("tmp_%d.jpg" % ino, img)
|
||||||
|
return
|
||||||
|
|
||||||
|
def shrink_poly(self, poly, r):
|
||||||
|
"""
|
||||||
|
fit a poly inside the origin poly, maybe bugs here...
|
||||||
|
used for generate the score map
|
||||||
|
:param poly: the text poly
|
||||||
|
:param r: r in the paper
|
||||||
|
:return: the shrinked poly
|
||||||
|
"""
|
||||||
|
# shrink ratio
|
||||||
|
R = 0.3
|
||||||
|
# find the longer pair
|
||||||
|
dist0 = np.linalg.norm(poly[0] - poly[1])
|
||||||
|
dist1 = np.linalg.norm(poly[2] - poly[3])
|
||||||
|
dist2 = np.linalg.norm(poly[0] - poly[3])
|
||||||
|
dist3 = np.linalg.norm(poly[1] - poly[2])
|
||||||
|
if dist0 + dist1 > dist2 + dist3:
|
||||||
|
# first move (p0, p1), (p2, p3), then (p0, p3), (p1, p2)
|
||||||
|
## p0, p1
|
||||||
|
theta = np.arctan2((poly[1][1] - poly[0][1]),
|
||||||
|
(poly[1][0] - poly[0][0]))
|
||||||
|
poly[0][0] += R * r[0] * np.cos(theta)
|
||||||
|
poly[0][1] += R * r[0] * np.sin(theta)
|
||||||
|
poly[1][0] -= R * r[1] * np.cos(theta)
|
||||||
|
poly[1][1] -= R * r[1] * np.sin(theta)
|
||||||
|
## p2, p3
|
||||||
|
theta = np.arctan2((poly[2][1] - poly[3][1]),
|
||||||
|
(poly[2][0] - poly[3][0]))
|
||||||
|
poly[3][0] += R * r[3] * np.cos(theta)
|
||||||
|
poly[3][1] += R * r[3] * np.sin(theta)
|
||||||
|
poly[2][0] -= R * r[2] * np.cos(theta)
|
||||||
|
poly[2][1] -= R * r[2] * np.sin(theta)
|
||||||
|
## p0, p3
|
||||||
|
theta = np.arctan2((poly[3][0] - poly[0][0]),
|
||||||
|
(poly[3][1] - poly[0][1]))
|
||||||
|
poly[0][0] += R * r[0] * np.sin(theta)
|
||||||
|
poly[0][1] += R * r[0] * np.cos(theta)
|
||||||
|
poly[3][0] -= R * r[3] * np.sin(theta)
|
||||||
|
poly[3][1] -= R * r[3] * np.cos(theta)
|
||||||
|
## p1, p2
|
||||||
|
theta = np.arctan2((poly[2][0] - poly[1][0]),
|
||||||
|
(poly[2][1] - poly[1][1]))
|
||||||
|
poly[1][0] += R * r[1] * np.sin(theta)
|
||||||
|
poly[1][1] += R * r[1] * np.cos(theta)
|
||||||
|
poly[2][0] -= R * r[2] * np.sin(theta)
|
||||||
|
poly[2][1] -= R * r[2] * np.cos(theta)
|
||||||
|
else:
|
||||||
|
## p0, p3
|
||||||
|
# print poly
|
||||||
|
theta = np.arctan2((poly[3][0] - poly[0][0]),
|
||||||
|
(poly[3][1] - poly[0][1]))
|
||||||
|
poly[0][0] += R * r[0] * np.sin(theta)
|
||||||
|
poly[0][1] += R * r[0] * np.cos(theta)
|
||||||
|
poly[3][0] -= R * r[3] * np.sin(theta)
|
||||||
|
poly[3][1] -= R * r[3] * np.cos(theta)
|
||||||
|
## p1, p2
|
||||||
|
theta = np.arctan2((poly[2][0] - poly[1][0]),
|
||||||
|
(poly[2][1] - poly[1][1]))
|
||||||
|
poly[1][0] += R * r[1] * np.sin(theta)
|
||||||
|
poly[1][1] += R * r[1] * np.cos(theta)
|
||||||
|
poly[2][0] -= R * r[2] * np.sin(theta)
|
||||||
|
poly[2][1] -= R * r[2] * np.cos(theta)
|
||||||
|
## p0, p1
|
||||||
|
theta = np.arctan2((poly[1][1] - poly[0][1]),
|
||||||
|
(poly[1][0] - poly[0][0]))
|
||||||
|
poly[0][0] += R * r[0] * np.cos(theta)
|
||||||
|
poly[0][1] += R * r[0] * np.sin(theta)
|
||||||
|
poly[1][0] -= R * r[1] * np.cos(theta)
|
||||||
|
poly[1][1] -= R * r[1] * np.sin(theta)
|
||||||
|
## p2, p3
|
||||||
|
theta = np.arctan2((poly[2][1] - poly[3][1]),
|
||||||
|
(poly[2][0] - poly[3][0]))
|
||||||
|
poly[3][0] += R * r[3] * np.cos(theta)
|
||||||
|
poly[3][1] += R * r[3] * np.sin(theta)
|
||||||
|
poly[2][0] -= R * r[2] * np.cos(theta)
|
||||||
|
poly[2][1] -= R * r[2] * np.sin(theta)
|
||||||
|
return poly
|
||||||
|
|
||||||
|
def generate_quad(self, im_size, polys, tags):
|
||||||
|
"""
|
||||||
|
Generate quadrangle.
|
||||||
|
"""
|
||||||
|
h, w = im_size
|
||||||
|
poly_mask = np.zeros((h, w), dtype=np.uint8)
|
||||||
|
score_map = np.zeros((h, w), dtype=np.uint8)
|
||||||
|
# (x1, y1, ..., x4, y4, short_edge_norm)
|
||||||
|
geo_map = np.zeros((h, w, 9), dtype=np.float32)
|
||||||
|
# mask used during traning, to ignore some hard areas
|
||||||
|
training_mask = np.ones((h, w), dtype=np.uint8)
|
||||||
|
for poly_idx, poly_tag in enumerate(zip(polys, tags)):
|
||||||
|
poly = poly_tag[0]
|
||||||
|
tag = poly_tag[1]
|
||||||
|
|
||||||
|
r = [None, None, None, None]
|
||||||
|
for i in range(4):
|
||||||
|
dist1 = np.linalg.norm(poly[i] - poly[(i + 1) % 4])
|
||||||
|
dist2 = np.linalg.norm(poly[i] - poly[(i - 1) % 4])
|
||||||
|
r[i] = min(dist1, dist2)
|
||||||
|
# score map
|
||||||
|
shrinked_poly = self.shrink_poly(
|
||||||
|
poly.copy(), r).astype(np.int32)[np.newaxis, :, :]
|
||||||
|
cv2.fillPoly(score_map, shrinked_poly, 1)
|
||||||
|
cv2.fillPoly(poly_mask, shrinked_poly, poly_idx + 1)
|
||||||
|
# if the poly is too small, then ignore it during training
|
||||||
|
poly_h = min(
|
||||||
|
np.linalg.norm(poly[0] - poly[3]),
|
||||||
|
np.linalg.norm(poly[1] - poly[2]))
|
||||||
|
poly_w = min(
|
||||||
|
np.linalg.norm(poly[0] - poly[1]),
|
||||||
|
np.linalg.norm(poly[2] - poly[3]))
|
||||||
|
if min(poly_h, poly_w) < self.min_text_size:
|
||||||
|
cv2.fillPoly(training_mask,
|
||||||
|
poly.astype(np.int32)[np.newaxis, :, :], 0)
|
||||||
|
|
||||||
|
if tag:
|
||||||
|
cv2.fillPoly(training_mask,
|
||||||
|
poly.astype(np.int32)[np.newaxis, :, :], 0)
|
||||||
|
|
||||||
|
xy_in_poly = np.argwhere(poly_mask == (poly_idx + 1))
|
||||||
|
# geo map.
|
||||||
|
y_in_poly = xy_in_poly[:, 0]
|
||||||
|
x_in_poly = xy_in_poly[:, 1]
|
||||||
|
poly[:, 0] = np.minimum(np.maximum(poly[:, 0], 0), w)
|
||||||
|
poly[:, 1] = np.minimum(np.maximum(poly[:, 1], 0), h)
|
||||||
|
for pno in range(4):
|
||||||
|
geo_channel_beg = pno * 2
|
||||||
|
geo_map[y_in_poly, x_in_poly, geo_channel_beg] =\
|
||||||
|
x_in_poly - poly[pno, 0]
|
||||||
|
geo_map[y_in_poly, x_in_poly, geo_channel_beg+1] =\
|
||||||
|
y_in_poly - poly[pno, 1]
|
||||||
|
geo_map[y_in_poly, x_in_poly, 8] = \
|
||||||
|
1.0 / max(min(poly_h, poly_w), 1.0)
|
||||||
|
return score_map, geo_map, training_mask
|
||||||
|
|
||||||
|
def crop_area(self,
|
||||||
|
im,
|
||||||
|
polys,
|
||||||
|
tags,
|
||||||
|
crop_background=False,
|
||||||
|
max_tries=50):
|
||||||
|
"""
|
||||||
|
make random crop from the input image
|
||||||
|
:param im:
|
||||||
|
:param polys:
|
||||||
|
:param tags:
|
||||||
|
:param crop_background:
|
||||||
|
:param max_tries:
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
h, w, _ = im.shape
|
||||||
|
pad_h = h // 10
|
||||||
|
pad_w = w // 10
|
||||||
|
h_array = np.zeros((h + pad_h * 2), dtype=np.int32)
|
||||||
|
w_array = np.zeros((w + pad_w * 2), dtype=np.int32)
|
||||||
|
for poly in polys:
|
||||||
|
poly = np.round(poly, decimals=0).astype(np.int32)
|
||||||
|
minx = np.min(poly[:, 0])
|
||||||
|
maxx = np.max(poly[:, 0])
|
||||||
|
w_array[minx + pad_w:maxx + pad_w] = 1
|
||||||
|
miny = np.min(poly[:, 1])
|
||||||
|
maxy = np.max(poly[:, 1])
|
||||||
|
h_array[miny + pad_h:maxy + pad_h] = 1
|
||||||
|
# ensure the cropped area not across a text
|
||||||
|
h_axis = np.where(h_array == 0)[0]
|
||||||
|
w_axis = np.where(w_array == 0)[0]
|
||||||
|
if len(h_axis) == 0 or len(w_axis) == 0:
|
||||||
|
return im, polys, tags
|
||||||
|
|
||||||
|
for i in range(max_tries):
|
||||||
|
xx = np.random.choice(w_axis, size=2)
|
||||||
|
xmin = np.min(xx) - pad_w
|
||||||
|
xmax = np.max(xx) - pad_w
|
||||||
|
xmin = np.clip(xmin, 0, w - 1)
|
||||||
|
xmax = np.clip(xmax, 0, w - 1)
|
||||||
|
yy = np.random.choice(h_axis, size=2)
|
||||||
|
ymin = np.min(yy) - pad_h
|
||||||
|
ymax = np.max(yy) - pad_h
|
||||||
|
ymin = np.clip(ymin, 0, h - 1)
|
||||||
|
ymax = np.clip(ymax, 0, h - 1)
|
||||||
|
if xmax - xmin < self.min_crop_side_ratio * w or \
|
||||||
|
ymax - ymin < self.min_crop_side_ratio * h:
|
||||||
|
# area too small
|
||||||
|
continue
|
||||||
|
if polys.shape[0] != 0:
|
||||||
|
poly_axis_in_area = (polys[:, :, 0] >= xmin)\
|
||||||
|
& (polys[:, :, 0] <= xmax)\
|
||||||
|
& (polys[:, :, 1] >= ymin)\
|
||||||
|
& (polys[:, :, 1] <= ymax)
|
||||||
|
selected_polys = np.where(
|
||||||
|
np.sum(poly_axis_in_area, axis=1) == 4)[0]
|
||||||
|
else:
|
||||||
|
selected_polys = []
|
||||||
|
|
||||||
|
if len(selected_polys) == 0:
|
||||||
|
# no text in this area
|
||||||
|
if crop_background:
|
||||||
|
im = im[ymin:ymax + 1, xmin:xmax + 1, :]
|
||||||
|
polys = []
|
||||||
|
tags = []
|
||||||
|
return im, polys, tags
|
||||||
|
else:
|
||||||
|
continue
|
||||||
|
|
||||||
|
im = im[ymin:ymax + 1, xmin:xmax + 1, :]
|
||||||
|
polys = polys[selected_polys]
|
||||||
|
tags = tags[selected_polys]
|
||||||
|
polys[:, :, 0] -= xmin
|
||||||
|
polys[:, :, 1] -= ymin
|
||||||
|
return im, polys, tags
|
||||||
|
return im, polys, tags
|
||||||
|
|
||||||
|
def crop_background_infor(self, im, text_polys, text_tags):
|
||||||
|
im, text_polys, text_tags = self.crop_area(
|
||||||
|
im, text_polys, text_tags, crop_background=True)
|
||||||
|
|
||||||
|
if len(text_polys) > 0:
|
||||||
|
return None
|
||||||
|
# pad and resize image
|
||||||
|
input_size = self.input_size
|
||||||
|
im, ratio = self.preprocess(im)
|
||||||
|
score_map = np.zeros((input_size, input_size), dtype=np.float32)
|
||||||
|
geo_map = np.zeros((input_size, input_size, 9), dtype=np.float32)
|
||||||
|
training_mask = np.ones((input_size, input_size), dtype=np.float32)
|
||||||
|
return im, score_map, geo_map, training_mask
|
||||||
|
|
||||||
|
def crop_foreground_infor(self, im, text_polys, text_tags):
|
||||||
|
im, text_polys, text_tags = self.crop_area(
|
||||||
|
im, text_polys, text_tags, crop_background=False)
|
||||||
|
|
||||||
|
if text_polys.shape[0] == 0:
|
||||||
|
return None
|
||||||
|
#continue for all ignore case
|
||||||
|
if np.sum((text_tags * 1.0)) >= text_tags.size:
|
||||||
|
return None
|
||||||
|
# pad and resize image
|
||||||
|
input_size = self.input_size
|
||||||
|
im, ratio = self.preprocess(im)
|
||||||
|
text_polys[:, :, 0] *= ratio
|
||||||
|
text_polys[:, :, 1] *= ratio
|
||||||
|
_, _, new_h, new_w = im.shape
|
||||||
|
# print(im.shape)
|
||||||
|
# self.draw_img_polys(im, text_polys)
|
||||||
|
score_map, geo_map, training_mask = self.generate_quad(
|
||||||
|
(new_h, new_w), text_polys, text_tags)
|
||||||
|
return im, score_map, geo_map, training_mask
|
||||||
|
|
||||||
|
def __call__(self, data):
|
||||||
|
im = data['image']
|
||||||
|
text_polys = data['polys']
|
||||||
|
text_tags = data['ignore_tags']
|
||||||
|
if im is None:
|
||||||
|
return None
|
||||||
|
if text_polys.shape[0] == 0:
|
||||||
|
return None
|
||||||
|
|
||||||
|
#add rotate cases
|
||||||
|
if np.random.rand() < 0.5:
|
||||||
|
im, text_polys = self.rotate_im_poly(im, text_polys)
|
||||||
|
h, w, _ = im.shape
|
||||||
|
text_polys, text_tags = self.check_and_validate_polys(text_polys,
|
||||||
|
text_tags, h, w)
|
||||||
|
if text_polys.shape[0] == 0:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# random scale this image
|
||||||
|
rd_scale = np.random.choice(self.random_scale)
|
||||||
|
im = cv2.resize(im, dsize=None, fx=rd_scale, fy=rd_scale)
|
||||||
|
text_polys *= rd_scale
|
||||||
|
if np.random.rand() < self.background_ratio:
|
||||||
|
outs = self.crop_background_infor(im, text_polys, text_tags)
|
||||||
|
else:
|
||||||
|
outs = self.crop_foreground_infor(im, text_polys, text_tags)
|
||||||
|
|
||||||
|
if outs is None:
|
||||||
|
return None
|
||||||
|
im, score_map, geo_map, training_mask = outs
|
||||||
|
score_map = score_map[np.newaxis, ::4, ::4].astype(np.float32)
|
||||||
|
geo_map = np.swapaxes(geo_map, 1, 2)
|
||||||
|
geo_map = np.swapaxes(geo_map, 1, 0)
|
||||||
|
geo_map = geo_map[:, ::4, ::4].astype(np.float32)
|
||||||
|
training_mask = training_mask[np.newaxis, ::4, ::4]
|
||||||
|
training_mask = training_mask.astype(np.float32)
|
||||||
|
|
||||||
|
data['image'] = im[0]
|
||||||
|
data['score_map'] = score_map
|
||||||
|
data['geo_map'] = geo_map
|
||||||
|
data['training_mask'] = training_mask
|
||||||
|
# print(im.shape, score_map.shape, geo_map.shape, training_mask.shape)
|
||||||
|
return data
|
|
@ -52,6 +52,7 @@ class DetLabelEncode(object):
|
||||||
txt_tags.append(True)
|
txt_tags.append(True)
|
||||||
else:
|
else:
|
||||||
txt_tags.append(False)
|
txt_tags.append(False)
|
||||||
|
boxes = self.expand_points_num(boxes)
|
||||||
boxes = np.array(boxes, dtype=np.float32)
|
boxes = np.array(boxes, dtype=np.float32)
|
||||||
txt_tags = np.array(txt_tags, dtype=np.bool)
|
txt_tags = np.array(txt_tags, dtype=np.bool)
|
||||||
|
|
||||||
|
@ -70,6 +71,17 @@ class DetLabelEncode(object):
|
||||||
rect[3] = pts[np.argmax(diff)]
|
rect[3] = pts[np.argmax(diff)]
|
||||||
return rect
|
return rect
|
||||||
|
|
||||||
|
def expand_points_num(self, boxes):
|
||||||
|
max_points_num = 0
|
||||||
|
for box in boxes:
|
||||||
|
if len(box) > max_points_num:
|
||||||
|
max_points_num = len(box)
|
||||||
|
ex_boxes = []
|
||||||
|
for box in boxes:
|
||||||
|
ex_box = box + [box[-1]] * (max_points_num - len(box))
|
||||||
|
ex_boxes.append(ex_box)
|
||||||
|
return ex_boxes
|
||||||
|
|
||||||
|
|
||||||
class BaseRecLabelEncode(object):
|
class BaseRecLabelEncode(object):
|
||||||
""" Convert between text-label and text-index """
|
""" Convert between text-label and text-index """
|
||||||
|
@ -79,7 +91,9 @@ class BaseRecLabelEncode(object):
|
||||||
character_dict_path=None,
|
character_dict_path=None,
|
||||||
character_type='ch',
|
character_type='ch',
|
||||||
use_space_char=False):
|
use_space_char=False):
|
||||||
support_character_type = ['ch', 'en', 'en_sensitive']
|
support_character_type = [
|
||||||
|
'ch', 'en', 'en_sensitive', 'french', 'german', 'japan', 'korean'
|
||||||
|
]
|
||||||
assert character_type in support_character_type, "Only {} are supported now but get {}".format(
|
assert character_type in support_character_type, "Only {} are supported now but get {}".format(
|
||||||
support_character_type, self.character_str)
|
support_character_type, self.character_str)
|
||||||
|
|
||||||
|
@ -87,7 +101,7 @@ class BaseRecLabelEncode(object):
|
||||||
if character_type == "en":
|
if character_type == "en":
|
||||||
self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz"
|
self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz"
|
||||||
dict_character = list(self.character_str)
|
dict_character = list(self.character_str)
|
||||||
elif character_type == "ch":
|
elif character_type in ["ch", "french", "german", "japan", "korean"]:
|
||||||
self.character_str = ""
|
self.character_str = ""
|
||||||
assert character_dict_path is not None, "character_dict_path should not be None when character_type is ch"
|
assert character_dict_path is not None, "character_dict_path should not be None when character_type is ch"
|
||||||
with open(character_dict_path, "rb") as fin:
|
with open(character_dict_path, "rb") as fin:
|
||||||
|
|
|
@ -120,26 +120,37 @@ class DetResizeForTest(object):
|
||||||
if 'limit_side_len' in kwargs:
|
if 'limit_side_len' in kwargs:
|
||||||
self.limit_side_len = kwargs['limit_side_len']
|
self.limit_side_len = kwargs['limit_side_len']
|
||||||
self.limit_type = kwargs.get('limit_type', 'min')
|
self.limit_type = kwargs.get('limit_type', 'min')
|
||||||
|
if 'resize_long' in kwargs:
|
||||||
|
self.resize_type = 2
|
||||||
|
self.resize_long = kwargs.get('resize_long', 960)
|
||||||
else:
|
else:
|
||||||
self.limit_side_len = 736
|
self.limit_side_len = 736
|
||||||
self.limit_type = 'min'
|
self.limit_type = 'min'
|
||||||
|
|
||||||
def __call__(self, data):
|
def __call__(self, data):
|
||||||
img = data['image']
|
img = data['image']
|
||||||
|
src_h, src_w, _ = img.shape
|
||||||
|
|
||||||
if self.resize_type == 0:
|
if self.resize_type == 0:
|
||||||
img, shape = self.resize_image_type0(img)
|
# img, shape = self.resize_image_type0(img)
|
||||||
|
img, [ratio_h, ratio_w] = self.resize_image_type0(img)
|
||||||
|
elif self.resize_type == 2:
|
||||||
|
img, [ratio_h, ratio_w] = self.resize_image_type2(img)
|
||||||
else:
|
else:
|
||||||
img, shape = self.resize_image_type1(img)
|
# img, shape = self.resize_image_type1(img)
|
||||||
|
img, [ratio_h, ratio_w] = self.resize_image_type1(img)
|
||||||
data['image'] = img
|
data['image'] = img
|
||||||
data['shape'] = shape
|
data['shape'] = np.array([src_h, src_w, ratio_h, ratio_w])
|
||||||
return data
|
return data
|
||||||
|
|
||||||
def resize_image_type1(self, img):
|
def resize_image_type1(self, img):
|
||||||
resize_h, resize_w = self.image_shape
|
resize_h, resize_w = self.image_shape
|
||||||
ori_h, ori_w = img.shape[:2] # (h, w, c)
|
ori_h, ori_w = img.shape[:2] # (h, w, c)
|
||||||
|
ratio_h = float(resize_h) / ori_h
|
||||||
|
ratio_w = float(resize_w) / ori_w
|
||||||
img = cv2.resize(img, (int(resize_w), int(resize_h)))
|
img = cv2.resize(img, (int(resize_w), int(resize_h)))
|
||||||
return img, np.array([ori_h, ori_w])
|
# return img, np.array([ori_h, ori_w])
|
||||||
|
return img, [ratio_h, ratio_w]
|
||||||
|
|
||||||
def resize_image_type0(self, img):
|
def resize_image_type0(self, img):
|
||||||
"""
|
"""
|
||||||
|
@ -182,4 +193,31 @@ class DetResizeForTest(object):
|
||||||
except:
|
except:
|
||||||
print(img.shape, resize_w, resize_h)
|
print(img.shape, resize_w, resize_h)
|
||||||
sys.exit(0)
|
sys.exit(0)
|
||||||
return img, np.array([h, w])
|
ratio_h = resize_h / float(h)
|
||||||
|
ratio_w = resize_w / float(w)
|
||||||
|
# return img, np.array([h, w])
|
||||||
|
return img, [ratio_h, ratio_w]
|
||||||
|
|
||||||
|
def resize_image_type2(self, img):
|
||||||
|
h, w, _ = img.shape
|
||||||
|
|
||||||
|
resize_w = w
|
||||||
|
resize_h = h
|
||||||
|
|
||||||
|
# Fix the longer side
|
||||||
|
if resize_h > resize_w:
|
||||||
|
ratio = float(self.resize_long) / resize_h
|
||||||
|
else:
|
||||||
|
ratio = float(self.resize_long) / resize_w
|
||||||
|
|
||||||
|
resize_h = int(resize_h * ratio)
|
||||||
|
resize_w = int(resize_w * ratio)
|
||||||
|
|
||||||
|
max_stride = 128
|
||||||
|
resize_h = (resize_h + max_stride - 1) // max_stride * max_stride
|
||||||
|
resize_w = (resize_w + max_stride - 1) // max_stride * max_stride
|
||||||
|
img = cv2.resize(img, (int(resize_w), int(resize_h)))
|
||||||
|
ratio_h = resize_h / float(h)
|
||||||
|
ratio_w = resize_w / float(w)
|
||||||
|
|
||||||
|
return img, [ratio_h, ratio_w]
|
||||||
|
|
|
@ -0,0 +1,689 @@
|
||||||
|
#copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
|
||||||
|
#
|
||||||
|
#Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
#you may not use this file except in compliance with the License.
|
||||||
|
#You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
#Unless required by applicable law or agreed to in writing, software
|
||||||
|
#distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
#See the License for the specific language governing permissions and
|
||||||
|
#limitations under the License.
|
||||||
|
|
||||||
|
import math
|
||||||
|
import cv2
|
||||||
|
import numpy as np
|
||||||
|
import json
|
||||||
|
import sys
|
||||||
|
import os
|
||||||
|
|
||||||
|
__all__ = ['SASTProcessTrain']
|
||||||
|
|
||||||
|
|
||||||
|
class SASTProcessTrain(object):
|
||||||
|
def __init__(self,
|
||||||
|
image_shape = [512, 512],
|
||||||
|
min_crop_size = 24,
|
||||||
|
min_crop_side_ratio = 0.3,
|
||||||
|
min_text_size = 10,
|
||||||
|
max_text_size = 512,
|
||||||
|
**kwargs):
|
||||||
|
self.input_size = image_shape[1]
|
||||||
|
self.min_crop_size = min_crop_size
|
||||||
|
self.min_crop_side_ratio = min_crop_side_ratio
|
||||||
|
self.min_text_size = min_text_size
|
||||||
|
self.max_text_size = max_text_size
|
||||||
|
|
||||||
|
def quad_area(self, poly):
|
||||||
|
"""
|
||||||
|
compute area of a polygon
|
||||||
|
:param poly:
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
edge = [
|
||||||
|
(poly[1][0] - poly[0][0]) * (poly[1][1] + poly[0][1]),
|
||||||
|
(poly[2][0] - poly[1][0]) * (poly[2][1] + poly[1][1]),
|
||||||
|
(poly[3][0] - poly[2][0]) * (poly[3][1] + poly[2][1]),
|
||||||
|
(poly[0][0] - poly[3][0]) * (poly[0][1] + poly[3][1])
|
||||||
|
]
|
||||||
|
return np.sum(edge) / 2.
|
||||||
|
|
||||||
|
def gen_quad_from_poly(self, poly):
|
||||||
|
"""
|
||||||
|
Generate min area quad from poly.
|
||||||
|
"""
|
||||||
|
point_num = poly.shape[0]
|
||||||
|
min_area_quad = np.zeros((4, 2), dtype=np.float32)
|
||||||
|
if True:
|
||||||
|
rect = cv2.minAreaRect(poly.astype(np.int32)) # (center (x,y), (width, height), angle of rotation)
|
||||||
|
center_point = rect[0]
|
||||||
|
box = np.array(cv2.boxPoints(rect))
|
||||||
|
|
||||||
|
first_point_idx = 0
|
||||||
|
min_dist = 1e4
|
||||||
|
for i in range(4):
|
||||||
|
dist = np.linalg.norm(box[(i + 0) % 4] - poly[0]) + \
|
||||||
|
np.linalg.norm(box[(i + 1) % 4] - poly[point_num // 2 - 1]) + \
|
||||||
|
np.linalg.norm(box[(i + 2) % 4] - poly[point_num // 2]) + \
|
||||||
|
np.linalg.norm(box[(i + 3) % 4] - poly[-1])
|
||||||
|
if dist < min_dist:
|
||||||
|
min_dist = dist
|
||||||
|
first_point_idx = i
|
||||||
|
for i in range(4):
|
||||||
|
min_area_quad[i] = box[(first_point_idx + i) % 4]
|
||||||
|
|
||||||
|
return min_area_quad
|
||||||
|
|
||||||
|
def check_and_validate_polys(self, polys, tags, xxx_todo_changeme):
|
||||||
|
"""
|
||||||
|
check so that the text poly is in the same direction,
|
||||||
|
and also filter some invalid polygons
|
||||||
|
:param polys:
|
||||||
|
:param tags:
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
(h, w) = xxx_todo_changeme
|
||||||
|
if polys.shape[0] == 0:
|
||||||
|
return polys, np.array([]), np.array([])
|
||||||
|
polys[:, :, 0] = np.clip(polys[:, :, 0], 0, w - 1)
|
||||||
|
polys[:, :, 1] = np.clip(polys[:, :, 1], 0, h - 1)
|
||||||
|
|
||||||
|
validated_polys = []
|
||||||
|
validated_tags = []
|
||||||
|
hv_tags = []
|
||||||
|
for poly, tag in zip(polys, tags):
|
||||||
|
quad = self.gen_quad_from_poly(poly)
|
||||||
|
p_area = self.quad_area(quad)
|
||||||
|
if abs(p_area) < 1:
|
||||||
|
print('invalid poly')
|
||||||
|
continue
|
||||||
|
if p_area > 0:
|
||||||
|
if tag == False:
|
||||||
|
print('poly in wrong direction')
|
||||||
|
tag = True # reversed cases should be ignore
|
||||||
|
poly = poly[(0, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1), :]
|
||||||
|
quad = quad[(0, 3, 2, 1), :]
|
||||||
|
|
||||||
|
len_w = np.linalg.norm(quad[0] - quad[1]) + np.linalg.norm(quad[3] - quad[2])
|
||||||
|
len_h = np.linalg.norm(quad[0] - quad[3]) + np.linalg.norm(quad[1] - quad[2])
|
||||||
|
hv_tag = 1
|
||||||
|
|
||||||
|
if len_w * 2.0 < len_h:
|
||||||
|
hv_tag = 0
|
||||||
|
|
||||||
|
validated_polys.append(poly)
|
||||||
|
validated_tags.append(tag)
|
||||||
|
hv_tags.append(hv_tag)
|
||||||
|
return np.array(validated_polys), np.array(validated_tags), np.array(hv_tags)
|
||||||
|
|
||||||
|
def crop_area(self, im, polys, tags, hv_tags, crop_background=False, max_tries=25):
|
||||||
|
"""
|
||||||
|
make random crop from the input image
|
||||||
|
:param im:
|
||||||
|
:param polys:
|
||||||
|
:param tags:
|
||||||
|
:param crop_background:
|
||||||
|
:param max_tries: 50 -> 25
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
h, w, _ = im.shape
|
||||||
|
pad_h = h // 10
|
||||||
|
pad_w = w // 10
|
||||||
|
h_array = np.zeros((h + pad_h * 2), dtype=np.int32)
|
||||||
|
w_array = np.zeros((w + pad_w * 2), dtype=np.int32)
|
||||||
|
for poly in polys:
|
||||||
|
poly = np.round(poly, decimals=0).astype(np.int32)
|
||||||
|
minx = np.min(poly[:, 0])
|
||||||
|
maxx = np.max(poly[:, 0])
|
||||||
|
w_array[minx + pad_w: maxx + pad_w] = 1
|
||||||
|
miny = np.min(poly[:, 1])
|
||||||
|
maxy = np.max(poly[:, 1])
|
||||||
|
h_array[miny + pad_h: maxy + pad_h] = 1
|
||||||
|
# ensure the cropped area not across a text
|
||||||
|
h_axis = np.where(h_array == 0)[0]
|
||||||
|
w_axis = np.where(w_array == 0)[0]
|
||||||
|
if len(h_axis) == 0 or len(w_axis) == 0:
|
||||||
|
return im, polys, tags, hv_tags
|
||||||
|
for i in range(max_tries):
|
||||||
|
xx = np.random.choice(w_axis, size=2)
|
||||||
|
xmin = np.min(xx) - pad_w
|
||||||
|
xmax = np.max(xx) - pad_w
|
||||||
|
xmin = np.clip(xmin, 0, w - 1)
|
||||||
|
xmax = np.clip(xmax, 0, w - 1)
|
||||||
|
yy = np.random.choice(h_axis, size=2)
|
||||||
|
ymin = np.min(yy) - pad_h
|
||||||
|
ymax = np.max(yy) - pad_h
|
||||||
|
ymin = np.clip(ymin, 0, h - 1)
|
||||||
|
ymax = np.clip(ymax, 0, h - 1)
|
||||||
|
# if xmax - xmin < ARGS.min_crop_side_ratio * w or \
|
||||||
|
# ymax - ymin < ARGS.min_crop_side_ratio * h:
|
||||||
|
if xmax - xmin < self.min_crop_size or \
|
||||||
|
ymax - ymin < self.min_crop_size:
|
||||||
|
# area too small
|
||||||
|
continue
|
||||||
|
if polys.shape[0] != 0:
|
||||||
|
poly_axis_in_area = (polys[:, :, 0] >= xmin) & (polys[:, :, 0] <= xmax) \
|
||||||
|
& (polys[:, :, 1] >= ymin) & (polys[:, :, 1] <= ymax)
|
||||||
|
selected_polys = np.where(np.sum(poly_axis_in_area, axis=1) == 4)[0]
|
||||||
|
else:
|
||||||
|
selected_polys = []
|
||||||
|
if len(selected_polys) == 0:
|
||||||
|
# no text in this area
|
||||||
|
if crop_background:
|
||||||
|
return im[ymin : ymax + 1, xmin : xmax + 1, :], \
|
||||||
|
polys[selected_polys], tags[selected_polys], hv_tags[selected_polys], txts
|
||||||
|
else:
|
||||||
|
continue
|
||||||
|
im = im[ymin: ymax + 1, xmin: xmax + 1, :]
|
||||||
|
polys = polys[selected_polys]
|
||||||
|
tags = tags[selected_polys]
|
||||||
|
hv_tags = hv_tags[selected_polys]
|
||||||
|
polys[:, :, 0] -= xmin
|
||||||
|
polys[:, :, 1] -= ymin
|
||||||
|
return im, polys, tags, hv_tags
|
||||||
|
|
||||||
|
return im, polys, tags, hv_tags
|
||||||
|
|
||||||
|
def generate_direction_map(self, poly_quads, direction_map):
|
||||||
|
"""
|
||||||
|
"""
|
||||||
|
width_list = []
|
||||||
|
height_list = []
|
||||||
|
for quad in poly_quads:
|
||||||
|
quad_w = (np.linalg.norm(quad[0] - quad[1]) + np.linalg.norm(quad[2] - quad[3])) / 2.0
|
||||||
|
quad_h = (np.linalg.norm(quad[0] - quad[3]) + np.linalg.norm(quad[2] - quad[1])) / 2.0
|
||||||
|
width_list.append(quad_w)
|
||||||
|
height_list.append(quad_h)
|
||||||
|
norm_width = max(sum(width_list) / (len(width_list) + 1e-6), 1.0)
|
||||||
|
average_height = max(sum(height_list) / (len(height_list) + 1e-6), 1.0)
|
||||||
|
|
||||||
|
for quad in poly_quads:
|
||||||
|
direct_vector_full = ((quad[1] + quad[2]) - (quad[0] + quad[3])) / 2.0
|
||||||
|
direct_vector = direct_vector_full / (np.linalg.norm(direct_vector_full) + 1e-6) * norm_width
|
||||||
|
direction_label = tuple(map(float, [direct_vector[0], direct_vector[1], 1.0 / (average_height + 1e-6)]))
|
||||||
|
cv2.fillPoly(direction_map, quad.round().astype(np.int32)[np.newaxis, :, :], direction_label)
|
||||||
|
return direction_map
|
||||||
|
|
||||||
|
def calculate_average_height(self, poly_quads):
|
||||||
|
"""
|
||||||
|
"""
|
||||||
|
height_list = []
|
||||||
|
for quad in poly_quads:
|
||||||
|
quad_h = (np.linalg.norm(quad[0] - quad[3]) + np.linalg.norm(quad[2] - quad[1])) / 2.0
|
||||||
|
height_list.append(quad_h)
|
||||||
|
average_height = max(sum(height_list) / len(height_list), 1.0)
|
||||||
|
return average_height
|
||||||
|
|
||||||
|
def generate_tcl_label(self, hw, polys, tags, ds_ratio,
|
||||||
|
tcl_ratio=0.3, shrink_ratio_of_width=0.15):
|
||||||
|
"""
|
||||||
|
Generate polygon.
|
||||||
|
"""
|
||||||
|
h, w = hw
|
||||||
|
h, w = int(h * ds_ratio), int(w * ds_ratio)
|
||||||
|
polys = polys * ds_ratio
|
||||||
|
|
||||||
|
score_map = np.zeros((h, w,), dtype=np.float32)
|
||||||
|
tbo_map = np.zeros((h, w, 5), dtype=np.float32)
|
||||||
|
training_mask = np.ones((h, w,), dtype=np.float32)
|
||||||
|
direction_map = np.ones((h, w, 3)) * np.array([0, 0, 1]).reshape([1, 1, 3]).astype(np.float32)
|
||||||
|
|
||||||
|
for poly_idx, poly_tag in enumerate(zip(polys, tags)):
|
||||||
|
poly = poly_tag[0]
|
||||||
|
tag = poly_tag[1]
|
||||||
|
|
||||||
|
# generate min_area_quad
|
||||||
|
min_area_quad, center_point = self.gen_min_area_quad_from_poly(poly)
|
||||||
|
min_area_quad_h = 0.5 * (np.linalg.norm(min_area_quad[0] - min_area_quad[3]) +
|
||||||
|
np.linalg.norm(min_area_quad[1] - min_area_quad[2]))
|
||||||
|
min_area_quad_w = 0.5 * (np.linalg.norm(min_area_quad[0] - min_area_quad[1]) +
|
||||||
|
np.linalg.norm(min_area_quad[2] - min_area_quad[3]))
|
||||||
|
|
||||||
|
if min(min_area_quad_h, min_area_quad_w) < self.min_text_size * ds_ratio \
|
||||||
|
or min(min_area_quad_h, min_area_quad_w) > self.max_text_size * ds_ratio:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if tag:
|
||||||
|
# continue
|
||||||
|
cv2.fillPoly(training_mask, poly.astype(np.int32)[np.newaxis, :, :], 0.15)
|
||||||
|
else:
|
||||||
|
tcl_poly = self.poly2tcl(poly, tcl_ratio)
|
||||||
|
tcl_quads = self.poly2quads(tcl_poly)
|
||||||
|
poly_quads = self.poly2quads(poly)
|
||||||
|
# stcl map
|
||||||
|
stcl_quads, quad_index = self.shrink_poly_along_width(tcl_quads, shrink_ratio_of_width=shrink_ratio_of_width,
|
||||||
|
expand_height_ratio=1.0 / tcl_ratio)
|
||||||
|
# generate tcl map
|
||||||
|
cv2.fillPoly(score_map, np.round(stcl_quads).astype(np.int32), 1.0)
|
||||||
|
|
||||||
|
# generate tbo map
|
||||||
|
for idx, quad in enumerate(stcl_quads):
|
||||||
|
quad_mask = np.zeros((h, w), dtype=np.float32)
|
||||||
|
quad_mask = cv2.fillPoly(quad_mask, np.round(quad[np.newaxis, :, :]).astype(np.int32), 1.0)
|
||||||
|
tbo_map = self.gen_quad_tbo(poly_quads[quad_index[idx]], quad_mask, tbo_map)
|
||||||
|
return score_map, tbo_map, training_mask
|
||||||
|
|
||||||
|
def generate_tvo_and_tco(self, hw, polys, tags, tcl_ratio=0.3, ds_ratio=0.25):
|
||||||
|
"""
|
||||||
|
Generate tcl map, tvo map and tbo map.
|
||||||
|
"""
|
||||||
|
h, w = hw
|
||||||
|
h, w = int(h * ds_ratio), int(w * ds_ratio)
|
||||||
|
polys = polys * ds_ratio
|
||||||
|
poly_mask = np.zeros((h, w), dtype=np.float32)
|
||||||
|
|
||||||
|
tvo_map = np.ones((9, h, w), dtype=np.float32)
|
||||||
|
tvo_map[0:-1:2] = np.tile(np.arange(0, w), (h, 1))
|
||||||
|
tvo_map[1:-1:2] = np.tile(np.arange(0, w), (h, 1)).T
|
||||||
|
poly_tv_xy_map = np.zeros((8, h, w), dtype=np.float32)
|
||||||
|
|
||||||
|
# tco map
|
||||||
|
tco_map = np.ones((3, h, w), dtype=np.float32)
|
||||||
|
tco_map[0] = np.tile(np.arange(0, w), (h, 1))
|
||||||
|
tco_map[1] = np.tile(np.arange(0, w), (h, 1)).T
|
||||||
|
poly_tc_xy_map = np.zeros((2, h, w), dtype=np.float32)
|
||||||
|
|
||||||
|
poly_short_edge_map = np.ones((h, w), dtype=np.float32)
|
||||||
|
|
||||||
|
for poly, poly_tag in zip(polys, tags):
|
||||||
|
|
||||||
|
if poly_tag == True:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# adjust point order for vertical poly
|
||||||
|
poly = self.adjust_point(poly)
|
||||||
|
|
||||||
|
# generate min_area_quad
|
||||||
|
min_area_quad, center_point = self.gen_min_area_quad_from_poly(poly)
|
||||||
|
min_area_quad_h = 0.5 * (np.linalg.norm(min_area_quad[0] - min_area_quad[3]) +
|
||||||
|
np.linalg.norm(min_area_quad[1] - min_area_quad[2]))
|
||||||
|
min_area_quad_w = 0.5 * (np.linalg.norm(min_area_quad[0] - min_area_quad[1]) +
|
||||||
|
np.linalg.norm(min_area_quad[2] - min_area_quad[3]))
|
||||||
|
|
||||||
|
# generate tcl map and text, 128 * 128
|
||||||
|
tcl_poly = self.poly2tcl(poly, tcl_ratio)
|
||||||
|
|
||||||
|
# generate poly_tv_xy_map
|
||||||
|
for idx in range(4):
|
||||||
|
cv2.fillPoly(poly_tv_xy_map[2 * idx],
|
||||||
|
np.round(tcl_poly[np.newaxis, :, :]).astype(np.int32),
|
||||||
|
float(min(max(min_area_quad[idx, 0], 0), w)))
|
||||||
|
cv2.fillPoly(poly_tv_xy_map[2 * idx + 1],
|
||||||
|
np.round(tcl_poly[np.newaxis, :, :]).astype(np.int32),
|
||||||
|
float(min(max(min_area_quad[idx, 1], 0), h)))
|
||||||
|
|
||||||
|
# generate poly_tc_xy_map
|
||||||
|
for idx in range(2):
|
||||||
|
cv2.fillPoly(poly_tc_xy_map[idx],
|
||||||
|
np.round(tcl_poly[np.newaxis, :, :]).astype(np.int32), float(center_point[idx]))
|
||||||
|
|
||||||
|
# generate poly_short_edge_map
|
||||||
|
cv2.fillPoly(poly_short_edge_map,
|
||||||
|
np.round(tcl_poly[np.newaxis, :, :]).astype(np.int32),
|
||||||
|
float(max(min(min_area_quad_h, min_area_quad_w), 1.0)))
|
||||||
|
|
||||||
|
# generate poly_mask and training_mask
|
||||||
|
cv2.fillPoly(poly_mask, np.round(tcl_poly[np.newaxis, :, :]).astype(np.int32), 1)
|
||||||
|
|
||||||
|
tvo_map *= poly_mask
|
||||||
|
tvo_map[:8] -= poly_tv_xy_map
|
||||||
|
tvo_map[-1] /= poly_short_edge_map
|
||||||
|
tvo_map = tvo_map.transpose((1, 2, 0))
|
||||||
|
|
||||||
|
tco_map *= poly_mask
|
||||||
|
tco_map[:2] -= poly_tc_xy_map
|
||||||
|
tco_map[-1] /= poly_short_edge_map
|
||||||
|
tco_map = tco_map.transpose((1, 2, 0))
|
||||||
|
|
||||||
|
return tvo_map, tco_map
|
||||||
|
|
||||||
|
def adjust_point(self, poly):
|
||||||
|
"""
|
||||||
|
adjust point order.
|
||||||
|
"""
|
||||||
|
point_num = poly.shape[0]
|
||||||
|
if point_num == 4:
|
||||||
|
len_1 = np.linalg.norm(poly[0] - poly[1])
|
||||||
|
len_2 = np.linalg.norm(poly[1] - poly[2])
|
||||||
|
len_3 = np.linalg.norm(poly[2] - poly[3])
|
||||||
|
len_4 = np.linalg.norm(poly[3] - poly[0])
|
||||||
|
|
||||||
|
if (len_1 + len_3) * 1.5 < (len_2 + len_4):
|
||||||
|
poly = poly[[1, 2, 3, 0], :]
|
||||||
|
|
||||||
|
elif point_num > 4:
|
||||||
|
vector_1 = poly[0] - poly[1]
|
||||||
|
vector_2 = poly[1] - poly[2]
|
||||||
|
cos_theta = np.dot(vector_1, vector_2) / (np.linalg.norm(vector_1) * np.linalg.norm(vector_2) + 1e-6)
|
||||||
|
theta = np.arccos(np.round(cos_theta, decimals=4))
|
||||||
|
|
||||||
|
if abs(theta) > (70 / 180 * math.pi):
|
||||||
|
index = list(range(1, point_num)) + [0]
|
||||||
|
poly = poly[np.array(index), :]
|
||||||
|
return poly
|
||||||
|
|
||||||
|
def gen_min_area_quad_from_poly(self, poly):
|
||||||
|
"""
|
||||||
|
Generate min area quad from poly.
|
||||||
|
"""
|
||||||
|
point_num = poly.shape[0]
|
||||||
|
min_area_quad = np.zeros((4, 2), dtype=np.float32)
|
||||||
|
if point_num == 4:
|
||||||
|
min_area_quad = poly
|
||||||
|
center_point = np.sum(poly, axis=0) / 4
|
||||||
|
else:
|
||||||
|
rect = cv2.minAreaRect(poly.astype(np.int32)) # (center (x,y), (width, height), angle of rotation)
|
||||||
|
center_point = rect[0]
|
||||||
|
box = np.array(cv2.boxPoints(rect))
|
||||||
|
|
||||||
|
first_point_idx = 0
|
||||||
|
min_dist = 1e4
|
||||||
|
for i in range(4):
|
||||||
|
dist = np.linalg.norm(box[(i + 0) % 4] - poly[0]) + \
|
||||||
|
np.linalg.norm(box[(i + 1) % 4] - poly[point_num // 2 - 1]) + \
|
||||||
|
np.linalg.norm(box[(i + 2) % 4] - poly[point_num // 2]) + \
|
||||||
|
np.linalg.norm(box[(i + 3) % 4] - poly[-1])
|
||||||
|
if dist < min_dist:
|
||||||
|
min_dist = dist
|
||||||
|
first_point_idx = i
|
||||||
|
|
||||||
|
for i in range(4):
|
||||||
|
min_area_quad[i] = box[(first_point_idx + i) % 4]
|
||||||
|
|
||||||
|
return min_area_quad, center_point
|
||||||
|
|
||||||
|
def shrink_quad_along_width(self, quad, begin_width_ratio=0., end_width_ratio=1.):
|
||||||
|
"""
|
||||||
|
Generate shrink_quad_along_width.
|
||||||
|
"""
|
||||||
|
ratio_pair = np.array([[begin_width_ratio], [end_width_ratio]], dtype=np.float32)
|
||||||
|
p0_1 = quad[0] + (quad[1] - quad[0]) * ratio_pair
|
||||||
|
p3_2 = quad[3] + (quad[2] - quad[3]) * ratio_pair
|
||||||
|
return np.array([p0_1[0], p0_1[1], p3_2[1], p3_2[0]])
|
||||||
|
|
||||||
|
def shrink_poly_along_width(self, quads, shrink_ratio_of_width, expand_height_ratio=1.0):
|
||||||
|
"""
|
||||||
|
shrink poly with given length.
|
||||||
|
"""
|
||||||
|
upper_edge_list = []
|
||||||
|
|
||||||
|
def get_cut_info(edge_len_list, cut_len):
|
||||||
|
for idx, edge_len in enumerate(edge_len_list):
|
||||||
|
cut_len -= edge_len
|
||||||
|
if cut_len <= 0.000001:
|
||||||
|
ratio = (cut_len + edge_len_list[idx]) / edge_len_list[idx]
|
||||||
|
return idx, ratio
|
||||||
|
|
||||||
|
for quad in quads:
|
||||||
|
upper_edge_len = np.linalg.norm(quad[0] - quad[1])
|
||||||
|
upper_edge_list.append(upper_edge_len)
|
||||||
|
|
||||||
|
# length of left edge and right edge.
|
||||||
|
left_length = np.linalg.norm(quads[0][0] - quads[0][3]) * expand_height_ratio
|
||||||
|
right_length = np.linalg.norm(quads[-1][1] - quads[-1][2]) * expand_height_ratio
|
||||||
|
|
||||||
|
shrink_length = min(left_length, right_length, sum(upper_edge_list)) * shrink_ratio_of_width
|
||||||
|
# shrinking length
|
||||||
|
upper_len_left = shrink_length
|
||||||
|
upper_len_right = sum(upper_edge_list) - shrink_length
|
||||||
|
|
||||||
|
left_idx, left_ratio = get_cut_info(upper_edge_list, upper_len_left)
|
||||||
|
left_quad = self.shrink_quad_along_width(quads[left_idx], begin_width_ratio=left_ratio, end_width_ratio=1)
|
||||||
|
right_idx, right_ratio = get_cut_info(upper_edge_list, upper_len_right)
|
||||||
|
right_quad = self.shrink_quad_along_width(quads[right_idx], begin_width_ratio=0, end_width_ratio=right_ratio)
|
||||||
|
|
||||||
|
out_quad_list = []
|
||||||
|
if left_idx == right_idx:
|
||||||
|
out_quad_list.append([left_quad[0], right_quad[1], right_quad[2], left_quad[3]])
|
||||||
|
else:
|
||||||
|
out_quad_list.append(left_quad)
|
||||||
|
for idx in range(left_idx + 1, right_idx):
|
||||||
|
out_quad_list.append(quads[idx])
|
||||||
|
out_quad_list.append(right_quad)
|
||||||
|
|
||||||
|
return np.array(out_quad_list), list(range(left_idx, right_idx + 1))
|
||||||
|
|
||||||
|
def vector_angle(self, A, B):
|
||||||
|
"""
|
||||||
|
Calculate the angle between vector AB and x-axis positive direction.
|
||||||
|
"""
|
||||||
|
AB = np.array([B[1] - A[1], B[0] - A[0]])
|
||||||
|
return np.arctan2(*AB)
|
||||||
|
|
||||||
|
def theta_line_cross_point(self, theta, point):
|
||||||
|
"""
|
||||||
|
Calculate the line through given point and angle in ax + by + c =0 form.
|
||||||
|
"""
|
||||||
|
x, y = point
|
||||||
|
cos = np.cos(theta)
|
||||||
|
sin = np.sin(theta)
|
||||||
|
return [sin, -cos, cos * y - sin * x]
|
||||||
|
|
||||||
|
def line_cross_two_point(self, A, B):
|
||||||
|
"""
|
||||||
|
Calculate the line through given point A and B in ax + by + c =0 form.
|
||||||
|
"""
|
||||||
|
angle = self.vector_angle(A, B)
|
||||||
|
return self.theta_line_cross_point(angle, A)
|
||||||
|
|
||||||
|
def average_angle(self, poly):
|
||||||
|
"""
|
||||||
|
Calculate the average angle between left and right edge in given poly.
|
||||||
|
"""
|
||||||
|
p0, p1, p2, p3 = poly
|
||||||
|
angle30 = self.vector_angle(p3, p0)
|
||||||
|
angle21 = self.vector_angle(p2, p1)
|
||||||
|
return (angle30 + angle21) / 2
|
||||||
|
|
||||||
|
def line_cross_point(self, line1, line2):
|
||||||
|
"""
|
||||||
|
line1 and line2 in 0=ax+by+c form, compute the cross point of line1 and line2
|
||||||
|
"""
|
||||||
|
a1, b1, c1 = line1
|
||||||
|
a2, b2, c2 = line2
|
||||||
|
d = a1 * b2 - a2 * b1
|
||||||
|
|
||||||
|
if d == 0:
|
||||||
|
#print("line1", line1)
|
||||||
|
#print("line2", line2)
|
||||||
|
print('Cross point does not exist')
|
||||||
|
return np.array([0, 0], dtype=np.float32)
|
||||||
|
else:
|
||||||
|
x = (b1 * c2 - b2 * c1) / d
|
||||||
|
y = (a2 * c1 - a1 * c2) / d
|
||||||
|
|
||||||
|
return np.array([x, y], dtype=np.float32)
|
||||||
|
|
||||||
|
def quad2tcl(self, poly, ratio):
|
||||||
|
"""
|
||||||
|
Generate center line by poly clock-wise point. (4, 2)
|
||||||
|
"""
|
||||||
|
ratio_pair = np.array([[0.5 - ratio / 2], [0.5 + ratio / 2]], dtype=np.float32)
|
||||||
|
p0_3 = poly[0] + (poly[3] - poly[0]) * ratio_pair
|
||||||
|
p1_2 = poly[1] + (poly[2] - poly[1]) * ratio_pair
|
||||||
|
return np.array([p0_3[0], p1_2[0], p1_2[1], p0_3[1]])
|
||||||
|
|
||||||
|
def poly2tcl(self, poly, ratio):
|
||||||
|
"""
|
||||||
|
Generate center line by poly clock-wise point.
|
||||||
|
"""
|
||||||
|
ratio_pair = np.array([[0.5 - ratio / 2], [0.5 + ratio / 2]], dtype=np.float32)
|
||||||
|
tcl_poly = np.zeros_like(poly)
|
||||||
|
point_num = poly.shape[0]
|
||||||
|
|
||||||
|
for idx in range(point_num // 2):
|
||||||
|
point_pair = poly[idx] + (poly[point_num - 1 - idx] - poly[idx]) * ratio_pair
|
||||||
|
tcl_poly[idx] = point_pair[0]
|
||||||
|
tcl_poly[point_num - 1 - idx] = point_pair[1]
|
||||||
|
return tcl_poly
|
||||||
|
|
||||||
|
def gen_quad_tbo(self, quad, tcl_mask, tbo_map):
|
||||||
|
"""
|
||||||
|
Generate tbo_map for give quad.
|
||||||
|
"""
|
||||||
|
# upper and lower line function: ax + by + c = 0;
|
||||||
|
up_line = self.line_cross_two_point(quad[0], quad[1])
|
||||||
|
lower_line = self.line_cross_two_point(quad[3], quad[2])
|
||||||
|
|
||||||
|
quad_h = 0.5 * (np.linalg.norm(quad[0] - quad[3]) + np.linalg.norm(quad[1] - quad[2]))
|
||||||
|
quad_w = 0.5 * (np.linalg.norm(quad[0] - quad[1]) + np.linalg.norm(quad[2] - quad[3]))
|
||||||
|
|
||||||
|
# average angle of left and right line.
|
||||||
|
angle = self.average_angle(quad)
|
||||||
|
|
||||||
|
xy_in_poly = np.argwhere(tcl_mask == 1)
|
||||||
|
for y, x in xy_in_poly:
|
||||||
|
point = (x, y)
|
||||||
|
line = self.theta_line_cross_point(angle, point)
|
||||||
|
cross_point_upper = self.line_cross_point(up_line, line)
|
||||||
|
cross_point_lower = self.line_cross_point(lower_line, line)
|
||||||
|
##FIX, offset reverse
|
||||||
|
upper_offset_x, upper_offset_y = cross_point_upper - point
|
||||||
|
lower_offset_x, lower_offset_y = cross_point_lower - point
|
||||||
|
tbo_map[y, x, 0] = upper_offset_y
|
||||||
|
tbo_map[y, x, 1] = upper_offset_x
|
||||||
|
tbo_map[y, x, 2] = lower_offset_y
|
||||||
|
tbo_map[y, x, 3] = lower_offset_x
|
||||||
|
tbo_map[y, x, 4] = 1.0 / max(min(quad_h, quad_w), 1.0) * 2
|
||||||
|
return tbo_map
|
||||||
|
|
||||||
|
def poly2quads(self, poly):
|
||||||
|
"""
|
||||||
|
Split poly into quads.
|
||||||
|
"""
|
||||||
|
quad_list = []
|
||||||
|
point_num = poly.shape[0]
|
||||||
|
|
||||||
|
# point pair
|
||||||
|
point_pair_list = []
|
||||||
|
for idx in range(point_num // 2):
|
||||||
|
point_pair = [poly[idx], poly[point_num - 1 - idx]]
|
||||||
|
point_pair_list.append(point_pair)
|
||||||
|
|
||||||
|
quad_num = point_num // 2 - 1
|
||||||
|
for idx in range(quad_num):
|
||||||
|
# reshape and adjust to clock-wise
|
||||||
|
quad_list.append((np.array(point_pair_list)[[idx, idx + 1]]).reshape(4, 2)[[0, 2, 3, 1]])
|
||||||
|
|
||||||
|
return np.array(quad_list)
|
||||||
|
|
||||||
|
def __call__(self, data):
|
||||||
|
im = data['image']
|
||||||
|
text_polys = data['polys']
|
||||||
|
text_tags = data['ignore_tags']
|
||||||
|
if im is None:
|
||||||
|
return None
|
||||||
|
if text_polys.shape[0] == 0:
|
||||||
|
return None
|
||||||
|
|
||||||
|
h, w, _ = im.shape
|
||||||
|
text_polys, text_tags, hv_tags = self.check_and_validate_polys(text_polys, text_tags, (h, w))
|
||||||
|
|
||||||
|
if text_polys.shape[0] == 0:
|
||||||
|
return None
|
||||||
|
|
||||||
|
#set aspect ratio and keep area fix
|
||||||
|
asp_scales = np.arange(1.0, 1.55, 0.1)
|
||||||
|
asp_scale = np.random.choice(asp_scales)
|
||||||
|
|
||||||
|
if np.random.rand() < 0.5:
|
||||||
|
asp_scale = 1.0 / asp_scale
|
||||||
|
asp_scale = math.sqrt(asp_scale)
|
||||||
|
|
||||||
|
asp_wx = asp_scale
|
||||||
|
asp_hy = 1.0 / asp_scale
|
||||||
|
im = cv2.resize(im, dsize=None, fx=asp_wx, fy=asp_hy)
|
||||||
|
text_polys[:, :, 0] *= asp_wx
|
||||||
|
text_polys[:, :, 1] *= asp_hy
|
||||||
|
|
||||||
|
h, w, _ = im.shape
|
||||||
|
if max(h, w) > 2048:
|
||||||
|
rd_scale = 2048.0 / max(h, w)
|
||||||
|
im = cv2.resize(im, dsize=None, fx=rd_scale, fy=rd_scale)
|
||||||
|
text_polys *= rd_scale
|
||||||
|
h, w, _ = im.shape
|
||||||
|
if min(h, w) < 16:
|
||||||
|
return None
|
||||||
|
|
||||||
|
#no background
|
||||||
|
im, text_polys, text_tags, hv_tags = self.crop_area(im, \
|
||||||
|
text_polys, text_tags, hv_tags, crop_background=False)
|
||||||
|
|
||||||
|
if text_polys.shape[0] == 0:
|
||||||
|
return None
|
||||||
|
#continue for all ignore case
|
||||||
|
if np.sum((text_tags * 1.0)) >= text_tags.size:
|
||||||
|
return None
|
||||||
|
new_h, new_w, _ = im.shape
|
||||||
|
if (new_h is None) or (new_w is None):
|
||||||
|
return None
|
||||||
|
#resize image
|
||||||
|
std_ratio = float(self.input_size) / max(new_w, new_h)
|
||||||
|
rand_scales = np.array([0.25, 0.375, 0.5, 0.625, 0.75, 0.875, 1.0, 1.0, 1.0, 1.0, 1.0])
|
||||||
|
rz_scale = std_ratio * np.random.choice(rand_scales)
|
||||||
|
im = cv2.resize(im, dsize=None, fx=rz_scale, fy=rz_scale)
|
||||||
|
text_polys[:, :, 0] *= rz_scale
|
||||||
|
text_polys[:, :, 1] *= rz_scale
|
||||||
|
|
||||||
|
#add gaussian blur
|
||||||
|
if np.random.rand() < 0.1 * 0.5:
|
||||||
|
ks = np.random.permutation(5)[0] + 1
|
||||||
|
ks = int(ks/2)*2 + 1
|
||||||
|
im = cv2.GaussianBlur(im, ksize=(ks, ks), sigmaX=0, sigmaY=0)
|
||||||
|
#add brighter
|
||||||
|
if np.random.rand() < 0.1 * 0.5:
|
||||||
|
im = im * (1.0 + np.random.rand() * 0.5)
|
||||||
|
im = np.clip(im, 0.0, 255.0)
|
||||||
|
#add darker
|
||||||
|
if np.random.rand() < 0.1 * 0.5:
|
||||||
|
im = im * (1.0 - np.random.rand() * 0.5)
|
||||||
|
im = np.clip(im, 0.0, 255.0)
|
||||||
|
|
||||||
|
# Padding the im to [input_size, input_size]
|
||||||
|
new_h, new_w, _ = im.shape
|
||||||
|
if min(new_w, new_h) < self.input_size * 0.5:
|
||||||
|
return None
|
||||||
|
|
||||||
|
im_padded = np.ones((self.input_size, self.input_size, 3), dtype=np.float32)
|
||||||
|
im_padded[:, :, 2] = 0.485 * 255
|
||||||
|
im_padded[:, :, 1] = 0.456 * 255
|
||||||
|
im_padded[:, :, 0] = 0.406 * 255
|
||||||
|
|
||||||
|
# Random the start position
|
||||||
|
del_h = self.input_size - new_h
|
||||||
|
del_w = self.input_size - new_w
|
||||||
|
sh, sw = 0, 0
|
||||||
|
if del_h > 1:
|
||||||
|
sh = int(np.random.rand() * del_h)
|
||||||
|
if del_w > 1:
|
||||||
|
sw = int(np.random.rand() * del_w)
|
||||||
|
|
||||||
|
# Padding
|
||||||
|
im_padded[sh: sh + new_h, sw: sw + new_w, :] = im.copy()
|
||||||
|
text_polys[:, :, 0] += sw
|
||||||
|
text_polys[:, :, 1] += sh
|
||||||
|
|
||||||
|
score_map, border_map, training_mask = self.generate_tcl_label((self.input_size, self.input_size),
|
||||||
|
text_polys, text_tags, 0.25)
|
||||||
|
|
||||||
|
# SAST head
|
||||||
|
tvo_map, tco_map = self.generate_tvo_and_tco((self.input_size, self.input_size), text_polys, text_tags, tcl_ratio=0.3, ds_ratio=0.25)
|
||||||
|
# print("test--------tvo_map shape:", tvo_map.shape)
|
||||||
|
|
||||||
|
im_padded[:, :, 2] -= 0.485 * 255
|
||||||
|
im_padded[:, :, 1] -= 0.456 * 255
|
||||||
|
im_padded[:, :, 0] -= 0.406 * 255
|
||||||
|
im_padded[:, :, 2] /= (255.0 * 0.229)
|
||||||
|
im_padded[:, :, 1] /= (255.0 * 0.224)
|
||||||
|
im_padded[:, :, 0] /= (255.0 * 0.225)
|
||||||
|
im_padded = im_padded.transpose((2, 0, 1))
|
||||||
|
|
||||||
|
data['image'] = im_padded[::-1, :, :]
|
||||||
|
data['score_map'] = score_map[np.newaxis, :, :]
|
||||||
|
data['border_map'] = border_map.transpose((2, 0, 1))
|
||||||
|
data['training_mask'] = training_mask[np.newaxis, :, :]
|
||||||
|
data['tvo_map'] = tvo_map.transpose((2, 0, 1))
|
||||||
|
data['tco_map'] = tco_map.transpose((2, 0, 1))
|
||||||
|
return data
|
|
@ -18,6 +18,8 @@ import copy
|
||||||
def build_loss(config):
|
def build_loss(config):
|
||||||
# det loss
|
# det loss
|
||||||
from .det_db_loss import DBLoss
|
from .det_db_loss import DBLoss
|
||||||
|
from .det_east_loss import EASTLoss
|
||||||
|
from .det_sast_loss import SASTLoss
|
||||||
|
|
||||||
# rec loss
|
# rec loss
|
||||||
from .rec_ctc_loss import CTCLoss
|
from .rec_ctc_loss import CTCLoss
|
||||||
|
@ -25,7 +27,7 @@ def build_loss(config):
|
||||||
# cls loss
|
# cls loss
|
||||||
from .cls_loss import ClsLoss
|
from .cls_loss import ClsLoss
|
||||||
|
|
||||||
support_dict = ['DBLoss', 'CTCLoss', 'ClsLoss']
|
support_dict = ['DBLoss', 'EASTLoss', 'SASTLoss', 'CTCLoss', 'ClsLoss']
|
||||||
|
|
||||||
config = copy.deepcopy(config)
|
config = copy.deepcopy(config)
|
||||||
module_name = config.pop('name')
|
module_name = config.pop('name')
|
||||||
|
|
|
@ -0,0 +1,63 @@
|
||||||
|
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import paddle
|
||||||
|
from paddle import nn
|
||||||
|
from .det_basic_loss import DiceLoss
|
||||||
|
|
||||||
|
|
||||||
|
class EASTLoss(nn.Layer):
|
||||||
|
"""
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
eps=1e-6,
|
||||||
|
**kwargs):
|
||||||
|
super(EASTLoss, self).__init__()
|
||||||
|
self.dice_loss = DiceLoss(eps=eps)
|
||||||
|
|
||||||
|
def forward(self, predicts, labels):
|
||||||
|
l_score, l_geo, l_mask = labels[1:]
|
||||||
|
f_score = predicts['f_score']
|
||||||
|
f_geo = predicts['f_geo']
|
||||||
|
|
||||||
|
dice_loss = self.dice_loss(f_score, l_score, l_mask)
|
||||||
|
|
||||||
|
#smoooth_l1_loss
|
||||||
|
channels = 8
|
||||||
|
l_geo_split = paddle.split(
|
||||||
|
l_geo, num_or_sections=channels + 1, axis=1)
|
||||||
|
f_geo_split = paddle.split(f_geo, num_or_sections=channels, axis=1)
|
||||||
|
smooth_l1 = 0
|
||||||
|
for i in range(0, channels):
|
||||||
|
geo_diff = l_geo_split[i] - f_geo_split[i]
|
||||||
|
abs_geo_diff = paddle.abs(geo_diff)
|
||||||
|
smooth_l1_sign = paddle.less_than(abs_geo_diff, l_score)
|
||||||
|
smooth_l1_sign = paddle.cast(smooth_l1_sign, dtype='float32')
|
||||||
|
in_loss = abs_geo_diff * abs_geo_diff * smooth_l1_sign + \
|
||||||
|
(abs_geo_diff - 0.5) * (1.0 - smooth_l1_sign)
|
||||||
|
out_loss = l_geo_split[-1] / channels * in_loss * l_score
|
||||||
|
smooth_l1 += out_loss
|
||||||
|
smooth_l1_loss = paddle.mean(smooth_l1 * l_score)
|
||||||
|
|
||||||
|
dice_loss = dice_loss * 0.01
|
||||||
|
total_loss = dice_loss + smooth_l1_loss
|
||||||
|
losses = {"loss":total_loss, \
|
||||||
|
"dice_loss":dice_loss,\
|
||||||
|
"smooth_l1_loss":smooth_l1_loss}
|
||||||
|
return losses
|
|
@ -0,0 +1,121 @@
|
||||||
|
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import paddle
|
||||||
|
from paddle import nn
|
||||||
|
from .det_basic_loss import DiceLoss
|
||||||
|
import paddle.fluid as fluid
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
class SASTLoss(nn.Layer):
|
||||||
|
"""
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
eps=1e-6,
|
||||||
|
**kwargs):
|
||||||
|
super(SASTLoss, self).__init__()
|
||||||
|
self.dice_loss = DiceLoss(eps=eps)
|
||||||
|
|
||||||
|
def forward(self, predicts, labels):
|
||||||
|
"""
|
||||||
|
tcl_pos: N x 128 x 3
|
||||||
|
tcl_mask: N x 128 x 1
|
||||||
|
tcl_label: N x X list or LoDTensor
|
||||||
|
"""
|
||||||
|
|
||||||
|
f_score = predicts['f_score']
|
||||||
|
f_border = predicts['f_border']
|
||||||
|
f_tvo = predicts['f_tvo']
|
||||||
|
f_tco = predicts['f_tco']
|
||||||
|
|
||||||
|
l_score, l_border, l_mask, l_tvo, l_tco = labels[1:]
|
||||||
|
|
||||||
|
#score_loss
|
||||||
|
intersection = paddle.sum(f_score * l_score * l_mask)
|
||||||
|
union = paddle.sum(f_score * l_mask) + paddle.sum(l_score * l_mask)
|
||||||
|
score_loss = 1.0 - 2 * intersection / (union + 1e-5)
|
||||||
|
|
||||||
|
#border loss
|
||||||
|
l_border_split, l_border_norm = paddle.split(l_border, num_or_sections=[4, 1], axis=1)
|
||||||
|
f_border_split = f_border
|
||||||
|
border_ex_shape = l_border_norm.shape * np.array([1, 4, 1, 1])
|
||||||
|
l_border_norm_split = paddle.expand(x=l_border_norm, shape=border_ex_shape)
|
||||||
|
l_border_score = paddle.expand(x=l_score, shape=border_ex_shape)
|
||||||
|
l_border_mask = paddle.expand(x=l_mask, shape=border_ex_shape)
|
||||||
|
|
||||||
|
border_diff = l_border_split - f_border_split
|
||||||
|
abs_border_diff = paddle.abs(border_diff)
|
||||||
|
border_sign = abs_border_diff < 1.0
|
||||||
|
border_sign = paddle.cast(border_sign, dtype='float32')
|
||||||
|
border_sign.stop_gradient = True
|
||||||
|
border_in_loss = 0.5 * abs_border_diff * abs_border_diff * border_sign + \
|
||||||
|
(abs_border_diff - 0.5) * (1.0 - border_sign)
|
||||||
|
border_out_loss = l_border_norm_split * border_in_loss
|
||||||
|
border_loss = paddle.sum(border_out_loss * l_border_score * l_border_mask) / \
|
||||||
|
(paddle.sum(l_border_score * l_border_mask) + 1e-5)
|
||||||
|
|
||||||
|
#tvo_loss
|
||||||
|
l_tvo_split, l_tvo_norm = paddle.split(l_tvo, num_or_sections=[8, 1], axis=1)
|
||||||
|
f_tvo_split = f_tvo
|
||||||
|
tvo_ex_shape = l_tvo_norm.shape * np.array([1, 8, 1, 1])
|
||||||
|
l_tvo_norm_split = paddle.expand(x=l_tvo_norm, shape=tvo_ex_shape)
|
||||||
|
l_tvo_score = paddle.expand(x=l_score, shape=tvo_ex_shape)
|
||||||
|
l_tvo_mask = paddle.expand(x=l_mask, shape=tvo_ex_shape)
|
||||||
|
#
|
||||||
|
tvo_geo_diff = l_tvo_split - f_tvo_split
|
||||||
|
abs_tvo_geo_diff = paddle.abs(tvo_geo_diff)
|
||||||
|
tvo_sign = abs_tvo_geo_diff < 1.0
|
||||||
|
tvo_sign = paddle.cast(tvo_sign, dtype='float32')
|
||||||
|
tvo_sign.stop_gradient = True
|
||||||
|
tvo_in_loss = 0.5 * abs_tvo_geo_diff * abs_tvo_geo_diff * tvo_sign + \
|
||||||
|
(abs_tvo_geo_diff - 0.5) * (1.0 - tvo_sign)
|
||||||
|
tvo_out_loss = l_tvo_norm_split * tvo_in_loss
|
||||||
|
tvo_loss = paddle.sum(tvo_out_loss * l_tvo_score * l_tvo_mask) / \
|
||||||
|
(paddle.sum(l_tvo_score * l_tvo_mask) + 1e-5)
|
||||||
|
|
||||||
|
#tco_loss
|
||||||
|
l_tco_split, l_tco_norm = paddle.split(l_tco, num_or_sections=[2, 1], axis=1)
|
||||||
|
f_tco_split = f_tco
|
||||||
|
tco_ex_shape = l_tco_norm.shape * np.array([1, 2, 1, 1])
|
||||||
|
l_tco_norm_split = paddle.expand(x=l_tco_norm, shape=tco_ex_shape)
|
||||||
|
l_tco_score = paddle.expand(x=l_score, shape=tco_ex_shape)
|
||||||
|
l_tco_mask = paddle.expand(x=l_mask, shape=tco_ex_shape)
|
||||||
|
|
||||||
|
tco_geo_diff = l_tco_split - f_tco_split
|
||||||
|
abs_tco_geo_diff = paddle.abs(tco_geo_diff)
|
||||||
|
tco_sign = abs_tco_geo_diff < 1.0
|
||||||
|
tco_sign = paddle.cast(tco_sign, dtype='float32')
|
||||||
|
tco_sign.stop_gradient = True
|
||||||
|
tco_in_loss = 0.5 * abs_tco_geo_diff * abs_tco_geo_diff * tco_sign + \
|
||||||
|
(abs_tco_geo_diff - 0.5) * (1.0 - tco_sign)
|
||||||
|
tco_out_loss = l_tco_norm_split * tco_in_loss
|
||||||
|
tco_loss = paddle.sum(tco_out_loss * l_tco_score * l_tco_mask) / \
|
||||||
|
(paddle.sum(l_tco_score * l_tco_mask) + 1e-5)
|
||||||
|
|
||||||
|
|
||||||
|
# total loss
|
||||||
|
tvo_lw, tco_lw = 1.5, 1.5
|
||||||
|
score_lw, border_lw = 1.0, 1.0
|
||||||
|
total_loss = score_loss * score_lw + border_loss * border_lw + \
|
||||||
|
tvo_loss * tvo_lw + tco_loss * tco_lw
|
||||||
|
|
||||||
|
losses = {'loss':total_loss, "score_loss":score_loss,\
|
||||||
|
"border_loss":border_loss, 'tvo_loss':tvo_loss, 'tco_loss':tco_loss}
|
||||||
|
return losses
|
|
@ -19,6 +19,7 @@ def build_backbone(config, model_type):
|
||||||
if model_type == 'det':
|
if model_type == 'det':
|
||||||
from .det_mobilenet_v3 import MobileNetV3
|
from .det_mobilenet_v3 import MobileNetV3
|
||||||
from .det_resnet_vd import ResNet
|
from .det_resnet_vd import ResNet
|
||||||
|
from .det_resnet_vd_sast import ResNet_SAST
|
||||||
support_dict = ['MobileNetV3', 'ResNet', 'ResNet_SAST']
|
support_dict = ['MobileNetV3', 'ResNet', 'ResNet_SAST']
|
||||||
elif model_type == 'rec' or model_type == 'cls':
|
elif model_type == 'rec' or model_type == 'cls':
|
||||||
from .rec_mobilenet_v3 import MobileNetV3
|
from .rec_mobilenet_v3 import MobileNetV3
|
||||||
|
|
|
@ -0,0 +1,285 @@
|
||||||
|
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import paddle
|
||||||
|
from paddle import ParamAttr
|
||||||
|
import paddle.nn as nn
|
||||||
|
import paddle.nn.functional as F
|
||||||
|
|
||||||
|
__all__ = ["ResNet_SAST"]
|
||||||
|
|
||||||
|
|
||||||
|
class ConvBNLayer(nn.Layer):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
in_channels,
|
||||||
|
out_channels,
|
||||||
|
kernel_size,
|
||||||
|
stride=1,
|
||||||
|
groups=1,
|
||||||
|
is_vd_mode=False,
|
||||||
|
act=None,
|
||||||
|
name=None, ):
|
||||||
|
super(ConvBNLayer, self).__init__()
|
||||||
|
|
||||||
|
self.is_vd_mode = is_vd_mode
|
||||||
|
self._pool2d_avg = nn.AvgPool2D(
|
||||||
|
kernel_size=2, stride=2, padding=0, ceil_mode=True)
|
||||||
|
self._conv = nn.Conv2D(
|
||||||
|
in_channels=in_channels,
|
||||||
|
out_channels=out_channels,
|
||||||
|
kernel_size=kernel_size,
|
||||||
|
stride=stride,
|
||||||
|
padding=(kernel_size - 1) // 2,
|
||||||
|
groups=groups,
|
||||||
|
weight_attr=ParamAttr(name=name + "_weights"),
|
||||||
|
bias_attr=False)
|
||||||
|
if name == "conv1":
|
||||||
|
bn_name = "bn_" + name
|
||||||
|
else:
|
||||||
|
bn_name = "bn" + name[3:]
|
||||||
|
self._batch_norm = nn.BatchNorm(
|
||||||
|
out_channels,
|
||||||
|
act=act,
|
||||||
|
param_attr=ParamAttr(name=bn_name + '_scale'),
|
||||||
|
bias_attr=ParamAttr(bn_name + '_offset'),
|
||||||
|
moving_mean_name=bn_name + '_mean',
|
||||||
|
moving_variance_name=bn_name + '_variance')
|
||||||
|
|
||||||
|
def forward(self, inputs):
|
||||||
|
if self.is_vd_mode:
|
||||||
|
inputs = self._pool2d_avg(inputs)
|
||||||
|
y = self._conv(inputs)
|
||||||
|
y = self._batch_norm(y)
|
||||||
|
return y
|
||||||
|
|
||||||
|
|
||||||
|
class BottleneckBlock(nn.Layer):
|
||||||
|
def __init__(self,
|
||||||
|
in_channels,
|
||||||
|
out_channels,
|
||||||
|
stride,
|
||||||
|
shortcut=True,
|
||||||
|
if_first=False,
|
||||||
|
name=None):
|
||||||
|
super(BottleneckBlock, self).__init__()
|
||||||
|
|
||||||
|
self.conv0 = ConvBNLayer(
|
||||||
|
in_channels=in_channels,
|
||||||
|
out_channels=out_channels,
|
||||||
|
kernel_size=1,
|
||||||
|
act='relu',
|
||||||
|
name=name + "_branch2a")
|
||||||
|
self.conv1 = ConvBNLayer(
|
||||||
|
in_channels=out_channels,
|
||||||
|
out_channels=out_channels,
|
||||||
|
kernel_size=3,
|
||||||
|
stride=stride,
|
||||||
|
act='relu',
|
||||||
|
name=name + "_branch2b")
|
||||||
|
self.conv2 = ConvBNLayer(
|
||||||
|
in_channels=out_channels,
|
||||||
|
out_channels=out_channels * 4,
|
||||||
|
kernel_size=1,
|
||||||
|
act=None,
|
||||||
|
name=name + "_branch2c")
|
||||||
|
|
||||||
|
if not shortcut:
|
||||||
|
self.short = ConvBNLayer(
|
||||||
|
in_channels=in_channels,
|
||||||
|
out_channels=out_channels * 4,
|
||||||
|
kernel_size=1,
|
||||||
|
stride=1,
|
||||||
|
is_vd_mode=False if if_first else True,
|
||||||
|
name=name + "_branch1")
|
||||||
|
|
||||||
|
self.shortcut = shortcut
|
||||||
|
|
||||||
|
def forward(self, inputs):
|
||||||
|
y = self.conv0(inputs)
|
||||||
|
conv1 = self.conv1(y)
|
||||||
|
conv2 = self.conv2(conv1)
|
||||||
|
|
||||||
|
if self.shortcut:
|
||||||
|
short = inputs
|
||||||
|
else:
|
||||||
|
short = self.short(inputs)
|
||||||
|
y = paddle.add(x=short, y=conv2)
|
||||||
|
y = F.relu(y)
|
||||||
|
return y
|
||||||
|
|
||||||
|
|
||||||
|
class BasicBlock(nn.Layer):
|
||||||
|
def __init__(self,
|
||||||
|
in_channels,
|
||||||
|
out_channels,
|
||||||
|
stride,
|
||||||
|
shortcut=True,
|
||||||
|
if_first=False,
|
||||||
|
name=None):
|
||||||
|
super(BasicBlock, self).__init__()
|
||||||
|
self.stride = stride
|
||||||
|
self.conv0 = ConvBNLayer(
|
||||||
|
in_channels=in_channels,
|
||||||
|
out_channels=out_channels,
|
||||||
|
kernel_size=3,
|
||||||
|
stride=stride,
|
||||||
|
act='relu',
|
||||||
|
name=name + "_branch2a")
|
||||||
|
self.conv1 = ConvBNLayer(
|
||||||
|
in_channels=out_channels,
|
||||||
|
out_channels=out_channels,
|
||||||
|
kernel_size=3,
|
||||||
|
act=None,
|
||||||
|
name=name + "_branch2b")
|
||||||
|
|
||||||
|
if not shortcut:
|
||||||
|
self.short = ConvBNLayer(
|
||||||
|
in_channels=in_channels,
|
||||||
|
out_channels=out_channels,
|
||||||
|
kernel_size=1,
|
||||||
|
stride=1,
|
||||||
|
is_vd_mode=False if if_first else True,
|
||||||
|
name=name + "_branch1")
|
||||||
|
|
||||||
|
self.shortcut = shortcut
|
||||||
|
|
||||||
|
def forward(self, inputs):
|
||||||
|
y = self.conv0(inputs)
|
||||||
|
conv1 = self.conv1(y)
|
||||||
|
|
||||||
|
if self.shortcut:
|
||||||
|
short = inputs
|
||||||
|
else:
|
||||||
|
short = self.short(inputs)
|
||||||
|
y = paddle.add(x=short, y=conv1)
|
||||||
|
y = F.relu(y)
|
||||||
|
return y
|
||||||
|
|
||||||
|
|
||||||
|
class ResNet_SAST(nn.Layer):
|
||||||
|
def __init__(self, in_channels=3, layers=50, **kwargs):
|
||||||
|
super(ResNet_SAST, self).__init__()
|
||||||
|
|
||||||
|
self.layers = layers
|
||||||
|
supported_layers = [18, 34, 50, 101, 152, 200]
|
||||||
|
assert layers in supported_layers, \
|
||||||
|
"supported layers are {} but input layer is {}".format(
|
||||||
|
supported_layers, layers)
|
||||||
|
|
||||||
|
if layers == 18:
|
||||||
|
depth = [2, 2, 2, 2]
|
||||||
|
elif layers == 34 or layers == 50:
|
||||||
|
# depth = [3, 4, 6, 3]
|
||||||
|
depth = [3, 4, 6, 3, 3]
|
||||||
|
elif layers == 101:
|
||||||
|
depth = [3, 4, 23, 3]
|
||||||
|
elif layers == 152:
|
||||||
|
depth = [3, 8, 36, 3]
|
||||||
|
elif layers == 200:
|
||||||
|
depth = [3, 12, 48, 3]
|
||||||
|
# num_channels = [64, 256, 512,
|
||||||
|
# 1024] if layers >= 50 else [64, 64, 128, 256]
|
||||||
|
# num_filters = [64, 128, 256, 512]
|
||||||
|
num_channels = [64, 256, 512,
|
||||||
|
1024, 2048] if layers >= 50 else [64, 64, 128, 256]
|
||||||
|
num_filters = [64, 128, 256, 512, 512]
|
||||||
|
|
||||||
|
self.conv1_1 = ConvBNLayer(
|
||||||
|
in_channels=in_channels,
|
||||||
|
out_channels=32,
|
||||||
|
kernel_size=3,
|
||||||
|
stride=2,
|
||||||
|
act='relu',
|
||||||
|
name="conv1_1")
|
||||||
|
self.conv1_2 = ConvBNLayer(
|
||||||
|
in_channels=32,
|
||||||
|
out_channels=32,
|
||||||
|
kernel_size=3,
|
||||||
|
stride=1,
|
||||||
|
act='relu',
|
||||||
|
name="conv1_2")
|
||||||
|
self.conv1_3 = ConvBNLayer(
|
||||||
|
in_channels=32,
|
||||||
|
out_channels=64,
|
||||||
|
kernel_size=3,
|
||||||
|
stride=1,
|
||||||
|
act='relu',
|
||||||
|
name="conv1_3")
|
||||||
|
self.pool2d_max = nn.MaxPool2D(kernel_size=3, stride=2, padding=1)
|
||||||
|
|
||||||
|
self.stages = []
|
||||||
|
self.out_channels = [3, 64]
|
||||||
|
if layers >= 50:
|
||||||
|
for block in range(len(depth)):
|
||||||
|
block_list = []
|
||||||
|
shortcut = False
|
||||||
|
for i in range(depth[block]):
|
||||||
|
if layers in [101, 152] and block == 2:
|
||||||
|
if i == 0:
|
||||||
|
conv_name = "res" + str(block + 2) + "a"
|
||||||
|
else:
|
||||||
|
conv_name = "res" + str(block + 2) + "b" + str(i)
|
||||||
|
else:
|
||||||
|
conv_name = "res" + str(block + 2) + chr(97 + i)
|
||||||
|
bottleneck_block = self.add_sublayer(
|
||||||
|
'bb_%d_%d' % (block, i),
|
||||||
|
BottleneckBlock(
|
||||||
|
in_channels=num_channels[block]
|
||||||
|
if i == 0 else num_filters[block] * 4,
|
||||||
|
out_channels=num_filters[block],
|
||||||
|
stride=2 if i == 0 and block != 0 else 1,
|
||||||
|
shortcut=shortcut,
|
||||||
|
if_first=block == i == 0,
|
||||||
|
name=conv_name))
|
||||||
|
shortcut = True
|
||||||
|
block_list.append(bottleneck_block)
|
||||||
|
self.out_channels.append(num_filters[block] * 4)
|
||||||
|
self.stages.append(nn.Sequential(*block_list))
|
||||||
|
else:
|
||||||
|
for block in range(len(depth)):
|
||||||
|
block_list = []
|
||||||
|
shortcut = False
|
||||||
|
for i in range(depth[block]):
|
||||||
|
conv_name = "res" + str(block + 2) + chr(97 + i)
|
||||||
|
basic_block = self.add_sublayer(
|
||||||
|
'bb_%d_%d' % (block, i),
|
||||||
|
BasicBlock(
|
||||||
|
in_channels=num_channels[block]
|
||||||
|
if i == 0 else num_filters[block],
|
||||||
|
out_channels=num_filters[block],
|
||||||
|
stride=2 if i == 0 and block != 0 else 1,
|
||||||
|
shortcut=shortcut,
|
||||||
|
if_first=block == i == 0,
|
||||||
|
name=conv_name))
|
||||||
|
shortcut = True
|
||||||
|
block_list.append(basic_block)
|
||||||
|
self.out_channels.append(num_filters[block])
|
||||||
|
self.stages.append(nn.Sequential(*block_list))
|
||||||
|
|
||||||
|
def forward(self, inputs):
|
||||||
|
out = [inputs]
|
||||||
|
y = self.conv1_1(inputs)
|
||||||
|
y = self.conv1_2(y)
|
||||||
|
y = self.conv1_3(y)
|
||||||
|
out.append(y)
|
||||||
|
y = self.pool2d_max(y)
|
||||||
|
for block in self.stages:
|
||||||
|
y = block(y)
|
||||||
|
out.append(y)
|
||||||
|
return out
|
|
@ -18,13 +18,15 @@ __all__ = ['build_head']
|
||||||
def build_head(config):
|
def build_head(config):
|
||||||
# det head
|
# det head
|
||||||
from .det_db_head import DBHead
|
from .det_db_head import DBHead
|
||||||
|
from .det_east_head import EASTHead
|
||||||
|
from .det_sast_head import SASTHead
|
||||||
|
|
||||||
# rec head
|
# rec head
|
||||||
from .rec_ctc_head import CTCHead
|
from .rec_ctc_head import CTCHead
|
||||||
|
|
||||||
# cls head
|
# cls head
|
||||||
from .cls_head import ClsHead
|
from .cls_head import ClsHead
|
||||||
support_dict = ['DBHead', 'CTCHead', 'ClsHead']
|
support_dict = ['DBHead', 'EASTHead', 'SASTHead', 'CTCHead', 'ClsHead']
|
||||||
|
|
||||||
module_name = config.pop('name')
|
module_name = config.pop('name')
|
||||||
assert module_name in support_dict, Exception('head only support {}'.format(
|
assert module_name in support_dict, Exception('head only support {}'.format(
|
||||||
|
|
|
@ -0,0 +1,121 @@
|
||||||
|
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import math
|
||||||
|
import paddle
|
||||||
|
from paddle import nn
|
||||||
|
import paddle.nn.functional as F
|
||||||
|
from paddle import ParamAttr
|
||||||
|
|
||||||
|
|
||||||
|
class ConvBNLayer(nn.Layer):
|
||||||
|
def __init__(self,
|
||||||
|
in_channels,
|
||||||
|
out_channels,
|
||||||
|
kernel_size,
|
||||||
|
stride,
|
||||||
|
padding,
|
||||||
|
groups=1,
|
||||||
|
if_act=True,
|
||||||
|
act=None,
|
||||||
|
name=None):
|
||||||
|
super(ConvBNLayer, self).__init__()
|
||||||
|
self.if_act = if_act
|
||||||
|
self.act = act
|
||||||
|
self.conv = nn.Conv2D(
|
||||||
|
in_channels=in_channels,
|
||||||
|
out_channels=out_channels,
|
||||||
|
kernel_size=kernel_size,
|
||||||
|
stride=stride,
|
||||||
|
padding=padding,
|
||||||
|
groups=groups,
|
||||||
|
weight_attr=ParamAttr(name=name + '_weights'),
|
||||||
|
bias_attr=False)
|
||||||
|
|
||||||
|
self.bn = nn.BatchNorm(
|
||||||
|
num_channels=out_channels,
|
||||||
|
act=act,
|
||||||
|
param_attr=ParamAttr(name="bn_" + name + "_scale"),
|
||||||
|
bias_attr=ParamAttr(name="bn_" + name + "_offset"),
|
||||||
|
moving_mean_name="bn_" + name + "_mean",
|
||||||
|
moving_variance_name="bn_" + name + "_variance")
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = self.conv(x)
|
||||||
|
x = self.bn(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class EASTHead(nn.Layer):
|
||||||
|
"""
|
||||||
|
"""
|
||||||
|
def __init__(self, in_channels, model_name, **kwargs):
|
||||||
|
super(EASTHead, self).__init__()
|
||||||
|
self.model_name = model_name
|
||||||
|
if self.model_name == "large":
|
||||||
|
num_outputs = [128, 64, 1, 8]
|
||||||
|
else:
|
||||||
|
num_outputs = [64, 32, 1, 8]
|
||||||
|
|
||||||
|
self.det_conv1 = ConvBNLayer(
|
||||||
|
in_channels=in_channels,
|
||||||
|
out_channels=num_outputs[0],
|
||||||
|
kernel_size=3,
|
||||||
|
stride=1,
|
||||||
|
padding=1,
|
||||||
|
if_act=True,
|
||||||
|
act='relu',
|
||||||
|
name="det_head1")
|
||||||
|
self.det_conv2 = ConvBNLayer(
|
||||||
|
in_channels=num_outputs[0],
|
||||||
|
out_channels=num_outputs[1],
|
||||||
|
kernel_size=3,
|
||||||
|
stride=1,
|
||||||
|
padding=1,
|
||||||
|
if_act=True,
|
||||||
|
act='relu',
|
||||||
|
name="det_head2")
|
||||||
|
self.score_conv = ConvBNLayer(
|
||||||
|
in_channels=num_outputs[1],
|
||||||
|
out_channels=num_outputs[2],
|
||||||
|
kernel_size=1,
|
||||||
|
stride=1,
|
||||||
|
padding=0,
|
||||||
|
if_act=False,
|
||||||
|
act=None,
|
||||||
|
name="f_score")
|
||||||
|
self.geo_conv = ConvBNLayer(
|
||||||
|
in_channels=num_outputs[1],
|
||||||
|
out_channels=num_outputs[3],
|
||||||
|
kernel_size=1,
|
||||||
|
stride=1,
|
||||||
|
padding=0,
|
||||||
|
if_act=False,
|
||||||
|
act=None,
|
||||||
|
name="f_geo")
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
f_det = self.det_conv1(x)
|
||||||
|
f_det = self.det_conv2(f_det)
|
||||||
|
f_score = self.score_conv(f_det)
|
||||||
|
f_score = F.sigmoid(f_score)
|
||||||
|
f_geo = self.geo_conv(f_det)
|
||||||
|
f_geo = (F.sigmoid(f_geo) - 0.5) * 2 * 800
|
||||||
|
|
||||||
|
pred = {'f_score': f_score, 'f_geo': f_geo}
|
||||||
|
return pred
|
|
@ -0,0 +1,128 @@
|
||||||
|
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import math
|
||||||
|
import paddle
|
||||||
|
from paddle import nn
|
||||||
|
import paddle.nn.functional as F
|
||||||
|
from paddle import ParamAttr
|
||||||
|
|
||||||
|
|
||||||
|
class ConvBNLayer(nn.Layer):
|
||||||
|
def __init__(self,
|
||||||
|
in_channels,
|
||||||
|
out_channels,
|
||||||
|
kernel_size,
|
||||||
|
stride,
|
||||||
|
groups=1,
|
||||||
|
if_act=True,
|
||||||
|
act=None,
|
||||||
|
name=None):
|
||||||
|
super(ConvBNLayer, self).__init__()
|
||||||
|
self.if_act = if_act
|
||||||
|
self.act = act
|
||||||
|
self.conv = nn.Conv2D(
|
||||||
|
in_channels=in_channels,
|
||||||
|
out_channels=out_channels,
|
||||||
|
kernel_size=kernel_size,
|
||||||
|
stride=stride,
|
||||||
|
padding=(kernel_size - 1) // 2,
|
||||||
|
groups=groups,
|
||||||
|
weight_attr=ParamAttr(name=name + '_weights'),
|
||||||
|
bias_attr=False)
|
||||||
|
|
||||||
|
self.bn = nn.BatchNorm(
|
||||||
|
num_channels=out_channels,
|
||||||
|
act=act,
|
||||||
|
param_attr=ParamAttr(name="bn_" + name + "_scale"),
|
||||||
|
bias_attr=ParamAttr(name="bn_" + name + "_offset"),
|
||||||
|
moving_mean_name="bn_" + name + "_mean",
|
||||||
|
moving_variance_name="bn_" + name + "_variance")
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = self.conv(x)
|
||||||
|
x = self.bn(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class SAST_Header1(nn.Layer):
|
||||||
|
def __init__(self, in_channels, **kwargs):
|
||||||
|
super(SAST_Header1, self).__init__()
|
||||||
|
out_channels = [64, 64, 128]
|
||||||
|
self.score_conv = nn.Sequential(
|
||||||
|
ConvBNLayer(in_channels, out_channels[0], 1, 1, act='relu', name='f_score1'),
|
||||||
|
ConvBNLayer(out_channels[0], out_channels[1], 3, 1, act='relu', name='f_score2'),
|
||||||
|
ConvBNLayer(out_channels[1], out_channels[2], 1, 1, act='relu', name='f_score3'),
|
||||||
|
ConvBNLayer(out_channels[2], 1, 3, 1, act=None, name='f_score4')
|
||||||
|
)
|
||||||
|
self.border_conv = nn.Sequential(
|
||||||
|
ConvBNLayer(in_channels, out_channels[0], 1, 1, act='relu', name='f_border1'),
|
||||||
|
ConvBNLayer(out_channels[0], out_channels[1], 3, 1, act='relu', name='f_border2'),
|
||||||
|
ConvBNLayer(out_channels[1], out_channels[2], 1, 1, act='relu', name='f_border3'),
|
||||||
|
ConvBNLayer(out_channels[2], 4, 3, 1, act=None, name='f_border4')
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
f_score = self.score_conv(x)
|
||||||
|
f_score = F.sigmoid(f_score)
|
||||||
|
f_border = self.border_conv(x)
|
||||||
|
return f_score, f_border
|
||||||
|
|
||||||
|
|
||||||
|
class SAST_Header2(nn.Layer):
|
||||||
|
def __init__(self, in_channels, **kwargs):
|
||||||
|
super(SAST_Header2, self).__init__()
|
||||||
|
out_channels = [64, 64, 128]
|
||||||
|
self.tvo_conv = nn.Sequential(
|
||||||
|
ConvBNLayer(in_channels, out_channels[0], 1, 1, act='relu', name='f_tvo1'),
|
||||||
|
ConvBNLayer(out_channels[0], out_channels[1], 3, 1, act='relu', name='f_tvo2'),
|
||||||
|
ConvBNLayer(out_channels[1], out_channels[2], 1, 1, act='relu', name='f_tvo3'),
|
||||||
|
ConvBNLayer(out_channels[2], 8, 3, 1, act=None, name='f_tvo4')
|
||||||
|
)
|
||||||
|
self.tco_conv = nn.Sequential(
|
||||||
|
ConvBNLayer(in_channels, out_channels[0], 1, 1, act='relu', name='f_tco1'),
|
||||||
|
ConvBNLayer(out_channels[0], out_channels[1], 3, 1, act='relu', name='f_tco2'),
|
||||||
|
ConvBNLayer(out_channels[1], out_channels[2], 1, 1, act='relu', name='f_tco3'),
|
||||||
|
ConvBNLayer(out_channels[2], 2, 3, 1, act=None, name='f_tco4')
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
f_tvo = self.tvo_conv(x)
|
||||||
|
f_tco = self.tco_conv(x)
|
||||||
|
return f_tvo, f_tco
|
||||||
|
|
||||||
|
|
||||||
|
class SASTHead(nn.Layer):
|
||||||
|
"""
|
||||||
|
"""
|
||||||
|
def __init__(self, in_channels, **kwargs):
|
||||||
|
super(SASTHead, self).__init__()
|
||||||
|
|
||||||
|
self.head1 = SAST_Header1(in_channels)
|
||||||
|
self.head2 = SAST_Header2(in_channels)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
f_score, f_border = self.head1(x)
|
||||||
|
f_tvo, f_tco = self.head2(x)
|
||||||
|
|
||||||
|
predicts = {}
|
||||||
|
predicts['f_score'] = f_score
|
||||||
|
predicts['f_border'] = f_border
|
||||||
|
predicts['f_tvo'] = f_tvo
|
||||||
|
predicts['f_tco'] = f_tco
|
||||||
|
return predicts
|
|
@ -16,8 +16,10 @@ __all__ = ['build_neck']
|
||||||
|
|
||||||
def build_neck(config):
|
def build_neck(config):
|
||||||
from .db_fpn import DBFPN
|
from .db_fpn import DBFPN
|
||||||
|
from .east_fpn import EASTFPN
|
||||||
|
from .sast_fpn import SASTFPN
|
||||||
from .rnn import SequenceEncoder
|
from .rnn import SequenceEncoder
|
||||||
support_dict = ['DBFPN', 'SequenceEncoder']
|
support_dict = ['DBFPN', 'EASTFPN', 'SASTFPN', 'SequenceEncoder']
|
||||||
|
|
||||||
module_name = config.pop('name')
|
module_name = config.pop('name')
|
||||||
assert module_name in support_dict, Exception('neck only support {}'.format(
|
assert module_name in support_dict, Exception('neck only support {}'.format(
|
||||||
|
|
|
@ -0,0 +1,188 @@
|
||||||
|
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import paddle
|
||||||
|
from paddle import nn
|
||||||
|
import paddle.nn.functional as F
|
||||||
|
from paddle import ParamAttr
|
||||||
|
|
||||||
|
|
||||||
|
class ConvBNLayer(nn.Layer):
|
||||||
|
def __init__(self,
|
||||||
|
in_channels,
|
||||||
|
out_channels,
|
||||||
|
kernel_size,
|
||||||
|
stride,
|
||||||
|
padding,
|
||||||
|
groups=1,
|
||||||
|
if_act=True,
|
||||||
|
act=None,
|
||||||
|
name=None):
|
||||||
|
super(ConvBNLayer, self).__init__()
|
||||||
|
self.if_act = if_act
|
||||||
|
self.act = act
|
||||||
|
self.conv = nn.Conv2D(
|
||||||
|
in_channels=in_channels,
|
||||||
|
out_channels=out_channels,
|
||||||
|
kernel_size=kernel_size,
|
||||||
|
stride=stride,
|
||||||
|
padding=padding,
|
||||||
|
groups=groups,
|
||||||
|
weight_attr=ParamAttr(name=name + '_weights'),
|
||||||
|
bias_attr=False)
|
||||||
|
|
||||||
|
self.bn = nn.BatchNorm(
|
||||||
|
num_channels=out_channels,
|
||||||
|
act=act,
|
||||||
|
param_attr=ParamAttr(name="bn_" + name + "_scale"),
|
||||||
|
bias_attr=ParamAttr(name="bn_" + name + "_offset"),
|
||||||
|
moving_mean_name="bn_" + name + "_mean",
|
||||||
|
moving_variance_name="bn_" + name + "_variance")
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = self.conv(x)
|
||||||
|
x = self.bn(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class DeConvBNLayer(nn.Layer):
|
||||||
|
def __init__(self,
|
||||||
|
in_channels,
|
||||||
|
out_channels,
|
||||||
|
kernel_size,
|
||||||
|
stride,
|
||||||
|
padding,
|
||||||
|
groups=1,
|
||||||
|
if_act=True,
|
||||||
|
act=None,
|
||||||
|
name=None):
|
||||||
|
super(DeConvBNLayer, self).__init__()
|
||||||
|
self.if_act = if_act
|
||||||
|
self.act = act
|
||||||
|
self.deconv = nn.Conv2DTranspose(
|
||||||
|
in_channels=in_channels,
|
||||||
|
out_channels=out_channels,
|
||||||
|
kernel_size=kernel_size,
|
||||||
|
stride=stride,
|
||||||
|
padding=padding,
|
||||||
|
groups=groups,
|
||||||
|
weight_attr=ParamAttr(name=name + '_weights'),
|
||||||
|
bias_attr=False)
|
||||||
|
self.bn = nn.BatchNorm(
|
||||||
|
num_channels=out_channels,
|
||||||
|
act=act,
|
||||||
|
param_attr=ParamAttr(name="bn_" + name + "_scale"),
|
||||||
|
bias_attr=ParamAttr(name="bn_" + name + "_offset"),
|
||||||
|
moving_mean_name="bn_" + name + "_mean",
|
||||||
|
moving_variance_name="bn_" + name + "_variance")
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = self.deconv(x)
|
||||||
|
x = self.bn(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class EASTFPN(nn.Layer):
|
||||||
|
def __init__(self, in_channels, model_name, **kwargs):
|
||||||
|
super(EASTFPN, self).__init__()
|
||||||
|
self.model_name = model_name
|
||||||
|
if self.model_name == "large":
|
||||||
|
self.out_channels = 128
|
||||||
|
else:
|
||||||
|
self.out_channels = 64
|
||||||
|
self.in_channels = in_channels[::-1]
|
||||||
|
self.h1_conv = ConvBNLayer(
|
||||||
|
in_channels=self.out_channels+self.in_channels[1],
|
||||||
|
out_channels=self.out_channels,
|
||||||
|
kernel_size=3,
|
||||||
|
stride=1,
|
||||||
|
padding=1,
|
||||||
|
if_act=True,
|
||||||
|
act='relu',
|
||||||
|
name="unet_h_1")
|
||||||
|
self.h2_conv = ConvBNLayer(
|
||||||
|
in_channels=self.out_channels+self.in_channels[2],
|
||||||
|
out_channels=self.out_channels,
|
||||||
|
kernel_size=3,
|
||||||
|
stride=1,
|
||||||
|
padding=1,
|
||||||
|
if_act=True,
|
||||||
|
act='relu',
|
||||||
|
name="unet_h_2")
|
||||||
|
self.h3_conv = ConvBNLayer(
|
||||||
|
in_channels=self.out_channels+self.in_channels[3],
|
||||||
|
out_channels=self.out_channels,
|
||||||
|
kernel_size=3,
|
||||||
|
stride=1,
|
||||||
|
padding=1,
|
||||||
|
if_act=True,
|
||||||
|
act='relu',
|
||||||
|
name="unet_h_3")
|
||||||
|
self.g0_deconv = DeConvBNLayer(
|
||||||
|
in_channels=self.in_channels[0],
|
||||||
|
out_channels=self.out_channels,
|
||||||
|
kernel_size=4,
|
||||||
|
stride=2,
|
||||||
|
padding=1,
|
||||||
|
if_act=True,
|
||||||
|
act='relu',
|
||||||
|
name="unet_g_0")
|
||||||
|
self.g1_deconv = DeConvBNLayer(
|
||||||
|
in_channels=self.out_channels,
|
||||||
|
out_channels=self.out_channels,
|
||||||
|
kernel_size=4,
|
||||||
|
stride=2,
|
||||||
|
padding=1,
|
||||||
|
if_act=True,
|
||||||
|
act='relu',
|
||||||
|
name="unet_g_1")
|
||||||
|
self.g2_deconv = DeConvBNLayer(
|
||||||
|
in_channels=self.out_channels,
|
||||||
|
out_channels=self.out_channels,
|
||||||
|
kernel_size=4,
|
||||||
|
stride=2,
|
||||||
|
padding=1,
|
||||||
|
if_act=True,
|
||||||
|
act='relu',
|
||||||
|
name="unet_g_2")
|
||||||
|
self.g3_conv = ConvBNLayer(
|
||||||
|
in_channels=self.out_channels,
|
||||||
|
out_channels=self.out_channels,
|
||||||
|
kernel_size=3,
|
||||||
|
stride=1,
|
||||||
|
padding=1,
|
||||||
|
if_act=True,
|
||||||
|
act='relu',
|
||||||
|
name="unet_g_3")
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
f = x[::-1]
|
||||||
|
|
||||||
|
h = f[0]
|
||||||
|
g = self.g0_deconv(h)
|
||||||
|
h = paddle.concat([g, f[1]], axis=1)
|
||||||
|
h = self.h1_conv(h)
|
||||||
|
g = self.g1_deconv(h)
|
||||||
|
h = paddle.concat([g, f[2]], axis=1)
|
||||||
|
h = self.h2_conv(h)
|
||||||
|
g = self.g2_deconv(h)
|
||||||
|
h = paddle.concat([g, f[3]], axis=1)
|
||||||
|
h = self.h3_conv(h)
|
||||||
|
g = self.g3_conv(h)
|
||||||
|
|
||||||
|
return g
|
|
@ -0,0 +1,284 @@
|
||||||
|
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import paddle
|
||||||
|
from paddle import nn
|
||||||
|
import paddle.nn.functional as F
|
||||||
|
from paddle import ParamAttr
|
||||||
|
|
||||||
|
|
||||||
|
class ConvBNLayer(nn.Layer):
|
||||||
|
def __init__(self,
|
||||||
|
in_channels,
|
||||||
|
out_channels,
|
||||||
|
kernel_size,
|
||||||
|
stride,
|
||||||
|
groups=1,
|
||||||
|
if_act=True,
|
||||||
|
act=None,
|
||||||
|
name=None):
|
||||||
|
super(ConvBNLayer, self).__init__()
|
||||||
|
self.if_act = if_act
|
||||||
|
self.act = act
|
||||||
|
self.conv = nn.Conv2D(
|
||||||
|
in_channels=in_channels,
|
||||||
|
out_channels=out_channels,
|
||||||
|
kernel_size=kernel_size,
|
||||||
|
stride=stride,
|
||||||
|
padding=(kernel_size - 1) // 2,
|
||||||
|
groups=groups,
|
||||||
|
weight_attr=ParamAttr(name=name + '_weights'),
|
||||||
|
bias_attr=False)
|
||||||
|
|
||||||
|
self.bn = nn.BatchNorm(
|
||||||
|
num_channels=out_channels,
|
||||||
|
act=act,
|
||||||
|
param_attr=ParamAttr(name="bn_" + name + "_scale"),
|
||||||
|
bias_attr=ParamAttr(name="bn_" + name + "_offset"),
|
||||||
|
moving_mean_name="bn_" + name + "_mean",
|
||||||
|
moving_variance_name="bn_" + name + "_variance")
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = self.conv(x)
|
||||||
|
x = self.bn(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class DeConvBNLayer(nn.Layer):
|
||||||
|
def __init__(self,
|
||||||
|
in_channels,
|
||||||
|
out_channels,
|
||||||
|
kernel_size,
|
||||||
|
stride,
|
||||||
|
groups=1,
|
||||||
|
if_act=True,
|
||||||
|
act=None,
|
||||||
|
name=None):
|
||||||
|
super(DeConvBNLayer, self).__init__()
|
||||||
|
self.if_act = if_act
|
||||||
|
self.act = act
|
||||||
|
self.deconv = nn.Conv2DTranspose(
|
||||||
|
in_channels=in_channels,
|
||||||
|
out_channels=out_channels,
|
||||||
|
kernel_size=kernel_size,
|
||||||
|
stride=stride,
|
||||||
|
padding=(kernel_size - 1) // 2,
|
||||||
|
groups=groups,
|
||||||
|
weight_attr=ParamAttr(name=name + '_weights'),
|
||||||
|
bias_attr=False)
|
||||||
|
self.bn = nn.BatchNorm(
|
||||||
|
num_channels=out_channels,
|
||||||
|
act=act,
|
||||||
|
param_attr=ParamAttr(name="bn_" + name + "_scale"),
|
||||||
|
bias_attr=ParamAttr(name="bn_" + name + "_offset"),
|
||||||
|
moving_mean_name="bn_" + name + "_mean",
|
||||||
|
moving_variance_name="bn_" + name + "_variance")
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = self.deconv(x)
|
||||||
|
x = self.bn(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class FPN_Up_Fusion(nn.Layer):
|
||||||
|
def __init__(self, in_channels):
|
||||||
|
super(FPN_Up_Fusion, self).__init__()
|
||||||
|
in_channels = in_channels[::-1]
|
||||||
|
out_channels = [256, 256, 192, 192, 128]
|
||||||
|
|
||||||
|
self.h0_conv = ConvBNLayer(in_channels[0], out_channels[0], 1, 1, act=None, name='fpn_up_h0')
|
||||||
|
self.h1_conv = ConvBNLayer(in_channels[1], out_channels[1], 1, 1, act=None, name='fpn_up_h1')
|
||||||
|
self.h2_conv = ConvBNLayer(in_channels[2], out_channels[2], 1, 1, act=None, name='fpn_up_h2')
|
||||||
|
self.h3_conv = ConvBNLayer(in_channels[3], out_channels[3], 1, 1, act=None, name='fpn_up_h3')
|
||||||
|
self.h4_conv = ConvBNLayer(in_channels[4], out_channels[4], 1, 1, act=None, name='fpn_up_h4')
|
||||||
|
|
||||||
|
self.g0_conv = DeConvBNLayer(out_channels[0], out_channels[1], 4, 2, act=None, name='fpn_up_g0')
|
||||||
|
|
||||||
|
self.g1_conv = nn.Sequential(
|
||||||
|
ConvBNLayer(out_channels[1], out_channels[1], 3, 1, act='relu', name='fpn_up_g1_1'),
|
||||||
|
DeConvBNLayer(out_channels[1], out_channels[2], 4, 2, act=None, name='fpn_up_g1_2')
|
||||||
|
)
|
||||||
|
self.g2_conv = nn.Sequential(
|
||||||
|
ConvBNLayer(out_channels[2], out_channels[2], 3, 1, act='relu', name='fpn_up_g2_1'),
|
||||||
|
DeConvBNLayer(out_channels[2], out_channels[3], 4, 2, act=None, name='fpn_up_g2_2')
|
||||||
|
)
|
||||||
|
self.g3_conv = nn.Sequential(
|
||||||
|
ConvBNLayer(out_channels[3], out_channels[3], 3, 1, act='relu', name='fpn_up_g3_1'),
|
||||||
|
DeConvBNLayer(out_channels[3], out_channels[4], 4, 2, act=None, name='fpn_up_g3_2')
|
||||||
|
)
|
||||||
|
|
||||||
|
self.g4_conv = nn.Sequential(
|
||||||
|
ConvBNLayer(out_channels[4], out_channels[4], 3, 1, act='relu', name='fpn_up_fusion_1'),
|
||||||
|
ConvBNLayer(out_channels[4], out_channels[4], 1, 1, act=None, name='fpn_up_fusion_2')
|
||||||
|
)
|
||||||
|
|
||||||
|
def _add_relu(self, x1, x2):
|
||||||
|
x = paddle.add(x=x1, y=x2)
|
||||||
|
x = F.relu(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
f = x[2:][::-1]
|
||||||
|
h0 = self.h0_conv(f[0])
|
||||||
|
h1 = self.h1_conv(f[1])
|
||||||
|
h2 = self.h2_conv(f[2])
|
||||||
|
h3 = self.h3_conv(f[3])
|
||||||
|
h4 = self.h4_conv(f[4])
|
||||||
|
|
||||||
|
g0 = self.g0_conv(h0)
|
||||||
|
g1 = self._add_relu(g0, h1)
|
||||||
|
g1 = self.g1_conv(g1)
|
||||||
|
g2 = self.g2_conv(self._add_relu(g1, h2))
|
||||||
|
g3 = self.g3_conv(self._add_relu(g2, h3))
|
||||||
|
g4 = self.g4_conv(self._add_relu(g3, h4))
|
||||||
|
|
||||||
|
return g4
|
||||||
|
|
||||||
|
|
||||||
|
class FPN_Down_Fusion(nn.Layer):
|
||||||
|
def __init__(self, in_channels):
|
||||||
|
super(FPN_Down_Fusion, self).__init__()
|
||||||
|
out_channels = [32, 64, 128]
|
||||||
|
|
||||||
|
self.h0_conv = ConvBNLayer(in_channels[0], out_channels[0], 3, 1, act=None, name='fpn_down_h0')
|
||||||
|
self.h1_conv = ConvBNLayer(in_channels[1], out_channels[1], 3, 1, act=None, name='fpn_down_h1')
|
||||||
|
self.h2_conv = ConvBNLayer(in_channels[2], out_channels[2], 3, 1, act=None, name='fpn_down_h2')
|
||||||
|
|
||||||
|
self.g0_conv = ConvBNLayer(out_channels[0], out_channels[1], 3, 2, act=None, name='fpn_down_g0')
|
||||||
|
|
||||||
|
self.g1_conv = nn.Sequential(
|
||||||
|
ConvBNLayer(out_channels[1], out_channels[1], 3, 1, act='relu', name='fpn_down_g1_1'),
|
||||||
|
ConvBNLayer(out_channels[1], out_channels[2], 3, 2, act=None, name='fpn_down_g1_2')
|
||||||
|
)
|
||||||
|
|
||||||
|
self.g2_conv = nn.Sequential(
|
||||||
|
ConvBNLayer(out_channels[2], out_channels[2], 3, 1, act='relu', name='fpn_down_fusion_1'),
|
||||||
|
ConvBNLayer(out_channels[2], out_channels[2], 1, 1, act=None, name='fpn_down_fusion_2')
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
f = x[:3]
|
||||||
|
h0 = self.h0_conv(f[0])
|
||||||
|
h1 = self.h1_conv(f[1])
|
||||||
|
h2 = self.h2_conv(f[2])
|
||||||
|
g0 = self.g0_conv(h0)
|
||||||
|
g1 = paddle.add(x=g0, y=h1)
|
||||||
|
g1 = F.relu(g1)
|
||||||
|
g1 = self.g1_conv(g1)
|
||||||
|
g2 = paddle.add(x=g1, y=h2)
|
||||||
|
g2 = F.relu(g2)
|
||||||
|
g2 = self.g2_conv(g2)
|
||||||
|
return g2
|
||||||
|
|
||||||
|
|
||||||
|
class Cross_Attention(nn.Layer):
|
||||||
|
def __init__(self, in_channels):
|
||||||
|
super(Cross_Attention, self).__init__()
|
||||||
|
self.theta_conv = ConvBNLayer(in_channels, in_channels, 1, 1, act='relu', name='f_theta')
|
||||||
|
self.phi_conv = ConvBNLayer(in_channels, in_channels, 1, 1, act='relu', name='f_phi')
|
||||||
|
self.g_conv = ConvBNLayer(in_channels, in_channels, 1, 1, act='relu', name='f_g')
|
||||||
|
|
||||||
|
self.fh_weight_conv = ConvBNLayer(in_channels, in_channels, 1, 1, act=None, name='fh_weight')
|
||||||
|
self.fh_sc_conv = ConvBNLayer(in_channels, in_channels, 1, 1, act=None, name='fh_sc')
|
||||||
|
|
||||||
|
self.fv_weight_conv = ConvBNLayer(in_channels, in_channels, 1, 1, act=None, name='fv_weight')
|
||||||
|
self.fv_sc_conv = ConvBNLayer(in_channels, in_channels, 1, 1, act=None, name='fv_sc')
|
||||||
|
|
||||||
|
self.f_attn_conv = ConvBNLayer(in_channels * 2, in_channels, 1, 1, act='relu', name='f_attn')
|
||||||
|
|
||||||
|
def _cal_fweight(self, f, shape):
|
||||||
|
f_theta, f_phi, f_g = f
|
||||||
|
#flatten
|
||||||
|
f_theta = paddle.transpose(f_theta, [0, 2, 3, 1])
|
||||||
|
f_theta = paddle.reshape(f_theta, [shape[0] * shape[1], shape[2], 128])
|
||||||
|
f_phi = paddle.transpose(f_phi, [0, 2, 3, 1])
|
||||||
|
f_phi = paddle.reshape(f_phi, [shape[0] * shape[1], shape[2], 128])
|
||||||
|
f_g = paddle.transpose(f_g, [0, 2, 3, 1])
|
||||||
|
f_g = paddle.reshape(f_g, [shape[0] * shape[1], shape[2], 128])
|
||||||
|
#correlation
|
||||||
|
f_attn = paddle.matmul(f_theta, paddle.transpose(f_phi, [0, 2, 1]))
|
||||||
|
#scale
|
||||||
|
f_attn = f_attn / (128**0.5)
|
||||||
|
f_attn = F.softmax(f_attn)
|
||||||
|
#weighted sum
|
||||||
|
f_weight = paddle.matmul(f_attn, f_g)
|
||||||
|
f_weight = paddle.reshape(
|
||||||
|
f_weight, [shape[0], shape[1], shape[2], 128])
|
||||||
|
return f_weight
|
||||||
|
|
||||||
|
def forward(self, f_common):
|
||||||
|
f_shape = paddle.shape(f_common)
|
||||||
|
# print('f_shape: ', f_shape)
|
||||||
|
|
||||||
|
f_theta = self.theta_conv(f_common)
|
||||||
|
f_phi = self.phi_conv(f_common)
|
||||||
|
f_g = self.g_conv(f_common)
|
||||||
|
|
||||||
|
######## horizon ########
|
||||||
|
fh_weight = self._cal_fweight([f_theta, f_phi, f_g],
|
||||||
|
[f_shape[0], f_shape[2], f_shape[3]])
|
||||||
|
fh_weight = paddle.transpose(fh_weight, [0, 3, 1, 2])
|
||||||
|
fh_weight = self.fh_weight_conv(fh_weight)
|
||||||
|
#short cut
|
||||||
|
fh_sc = self.fh_sc_conv(f_common)
|
||||||
|
f_h = F.relu(fh_weight + fh_sc)
|
||||||
|
|
||||||
|
######## vertical ########
|
||||||
|
fv_theta = paddle.transpose(f_theta, [0, 1, 3, 2])
|
||||||
|
fv_phi = paddle.transpose(f_phi, [0, 1, 3, 2])
|
||||||
|
fv_g = paddle.transpose(f_g, [0, 1, 3, 2])
|
||||||
|
fv_weight = self._cal_fweight([fv_theta, fv_phi, fv_g],
|
||||||
|
[f_shape[0], f_shape[3], f_shape[2]])
|
||||||
|
fv_weight = paddle.transpose(fv_weight, [0, 3, 2, 1])
|
||||||
|
fv_weight = self.fv_weight_conv(fv_weight)
|
||||||
|
#short cut
|
||||||
|
fv_sc = self.fv_sc_conv(f_common)
|
||||||
|
f_v = F.relu(fv_weight + fv_sc)
|
||||||
|
|
||||||
|
######## merge ########
|
||||||
|
f_attn = paddle.concat([f_h, f_v], axis=1)
|
||||||
|
f_attn = self.f_attn_conv(f_attn)
|
||||||
|
return f_attn
|
||||||
|
|
||||||
|
|
||||||
|
class SASTFPN(nn.Layer):
|
||||||
|
def __init__(self, in_channels, with_cab=False, **kwargs):
|
||||||
|
super(SASTFPN, self).__init__()
|
||||||
|
self.in_channels = in_channels
|
||||||
|
self.with_cab = with_cab
|
||||||
|
self.FPN_Down_Fusion = FPN_Down_Fusion(self.in_channels)
|
||||||
|
self.FPN_Up_Fusion = FPN_Up_Fusion(self.in_channels)
|
||||||
|
self.out_channels = 128
|
||||||
|
self.cross_attention = Cross_Attention(self.out_channels)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
#down fpn
|
||||||
|
f_down = self.FPN_Down_Fusion(x)
|
||||||
|
|
||||||
|
#up fpn
|
||||||
|
f_up = self.FPN_Up_Fusion(x)
|
||||||
|
|
||||||
|
#fusion
|
||||||
|
f_common = paddle.add(x=f_down, y=f_up)
|
||||||
|
f_common = F.relu(f_common)
|
||||||
|
|
||||||
|
if self.with_cab:
|
||||||
|
# print('enhence f_common with CAB.')
|
||||||
|
f_common = self.cross_attention(f_common)
|
||||||
|
|
||||||
|
return f_common
|
|
@ -24,11 +24,13 @@ __all__ = ['build_post_process']
|
||||||
|
|
||||||
def build_post_process(config, global_config=None):
|
def build_post_process(config, global_config=None):
|
||||||
from .db_postprocess import DBPostProcess
|
from .db_postprocess import DBPostProcess
|
||||||
|
from .east_postprocess import EASTPostProcess
|
||||||
|
from .sast_postprocess import SASTPostProcess
|
||||||
from .rec_postprocess import CTCLabelDecode, AttnLabelDecode
|
from .rec_postprocess import CTCLabelDecode, AttnLabelDecode
|
||||||
from .cls_postprocess import ClsPostProcess
|
from .cls_postprocess import ClsPostProcess
|
||||||
|
|
||||||
support_dict = [
|
support_dict = [
|
||||||
'DBPostProcess', 'CTCLabelDecode', 'AttnLabelDecode', 'ClsPostProcess'
|
'DBPostProcess', 'EASTPostProcess', 'SASTPostProcess', 'CTCLabelDecode', 'AttnLabelDecode', 'ClsPostProcess'
|
||||||
]
|
]
|
||||||
|
|
||||||
config = copy.deepcopy(config)
|
config = copy.deepcopy(config)
|
||||||
|
|
|
@ -0,0 +1,141 @@
|
||||||
|
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from .locality_aware_nms import nms_locality
|
||||||
|
import cv2
|
||||||
|
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
# __dir__ = os.path.dirname(os.path.abspath(__file__))
|
||||||
|
# sys.path.append(__dir__)
|
||||||
|
# sys.path.append(os.path.abspath(os.path.join(__dir__, '..')))
|
||||||
|
|
||||||
|
|
||||||
|
class EASTPostProcess(object):
|
||||||
|
"""
|
||||||
|
The post process for EAST.
|
||||||
|
"""
|
||||||
|
def __init__(self,
|
||||||
|
score_thresh=0.8,
|
||||||
|
cover_thresh=0.1,
|
||||||
|
nms_thresh=0.2,
|
||||||
|
**kwargs):
|
||||||
|
|
||||||
|
self.score_thresh = score_thresh
|
||||||
|
self.cover_thresh = cover_thresh
|
||||||
|
self.nms_thresh = nms_thresh
|
||||||
|
|
||||||
|
# c++ la-nms is faster, but only support python 3.5
|
||||||
|
self.is_python35 = False
|
||||||
|
if sys.version_info.major == 3 and sys.version_info.minor == 5:
|
||||||
|
self.is_python35 = True
|
||||||
|
|
||||||
|
def restore_rectangle_quad(self, origin, geometry):
|
||||||
|
"""
|
||||||
|
Restore rectangle from quadrangle.
|
||||||
|
"""
|
||||||
|
# quad
|
||||||
|
origin_concat = np.concatenate(
|
||||||
|
(origin, origin, origin, origin), axis=1) # (n, 8)
|
||||||
|
pred_quads = origin_concat - geometry
|
||||||
|
pred_quads = pred_quads.reshape((-1, 4, 2)) # (n, 4, 2)
|
||||||
|
return pred_quads
|
||||||
|
|
||||||
|
def detect(self,
|
||||||
|
score_map,
|
||||||
|
geo_map,
|
||||||
|
score_thresh=0.8,
|
||||||
|
cover_thresh=0.1,
|
||||||
|
nms_thresh=0.2):
|
||||||
|
"""
|
||||||
|
restore text boxes from score map and geo map
|
||||||
|
"""
|
||||||
|
score_map = score_map[0]
|
||||||
|
geo_map = np.swapaxes(geo_map, 1, 0)
|
||||||
|
geo_map = np.swapaxes(geo_map, 1, 2)
|
||||||
|
# filter the score map
|
||||||
|
xy_text = np.argwhere(score_map > score_thresh)
|
||||||
|
if len(xy_text) == 0:
|
||||||
|
return []
|
||||||
|
# sort the text boxes via the y axis
|
||||||
|
xy_text = xy_text[np.argsort(xy_text[:, 0])]
|
||||||
|
#restore quad proposals
|
||||||
|
text_box_restored = self.restore_rectangle_quad(
|
||||||
|
xy_text[:, ::-1] * 4, geo_map[xy_text[:, 0], xy_text[:, 1], :])
|
||||||
|
boxes = np.zeros((text_box_restored.shape[0], 9), dtype=np.float32)
|
||||||
|
boxes[:, :8] = text_box_restored.reshape((-1, 8))
|
||||||
|
boxes[:, 8] = score_map[xy_text[:, 0], xy_text[:, 1]]
|
||||||
|
if self.is_python35:
|
||||||
|
import lanms
|
||||||
|
boxes = lanms.merge_quadrangle_n9(boxes, nms_thresh)
|
||||||
|
else:
|
||||||
|
boxes = nms_locality(boxes.astype(np.float64), nms_thresh)
|
||||||
|
if boxes.shape[0] == 0:
|
||||||
|
return []
|
||||||
|
# Here we filter some low score boxes by the average score map,
|
||||||
|
# this is different from the orginal paper.
|
||||||
|
for i, box in enumerate(boxes):
|
||||||
|
mask = np.zeros_like(score_map, dtype=np.uint8)
|
||||||
|
cv2.fillPoly(mask, box[:8].reshape(
|
||||||
|
(-1, 4, 2)).astype(np.int32) // 4, 1)
|
||||||
|
boxes[i, 8] = cv2.mean(score_map, mask)[0]
|
||||||
|
boxes = boxes[boxes[:, 8] > cover_thresh]
|
||||||
|
return boxes
|
||||||
|
|
||||||
|
def sort_poly(self, p):
|
||||||
|
"""
|
||||||
|
Sort polygons.
|
||||||
|
"""
|
||||||
|
min_axis = np.argmin(np.sum(p, axis=1))
|
||||||
|
p = p[[min_axis, (min_axis + 1) % 4,\
|
||||||
|
(min_axis + 2) % 4, (min_axis + 3) % 4]]
|
||||||
|
if abs(p[0, 0] - p[1, 0]) > abs(p[0, 1] - p[1, 1]):
|
||||||
|
return p
|
||||||
|
else:
|
||||||
|
return p[[0, 3, 2, 1]]
|
||||||
|
|
||||||
|
def __call__(self, outs_dict, shape_list):
|
||||||
|
score_list = outs_dict['f_score']
|
||||||
|
geo_list = outs_dict['f_geo']
|
||||||
|
img_num = len(shape_list)
|
||||||
|
dt_boxes_list = []
|
||||||
|
for ino in range(img_num):
|
||||||
|
score = score_list[ino].numpy()
|
||||||
|
geo = geo_list[ino].numpy()
|
||||||
|
boxes = self.detect(
|
||||||
|
score_map=score,
|
||||||
|
geo_map=geo,
|
||||||
|
score_thresh=self.score_thresh,
|
||||||
|
cover_thresh=self.cover_thresh,
|
||||||
|
nms_thresh=self.nms_thresh)
|
||||||
|
boxes_norm = []
|
||||||
|
if len(boxes) > 0:
|
||||||
|
h, w = score.shape[1:]
|
||||||
|
src_h, src_w, ratio_h, ratio_w = shape_list[ino]
|
||||||
|
boxes = boxes[:, :8].reshape((-1, 4, 2))
|
||||||
|
boxes[:, :, 0] /= ratio_w
|
||||||
|
boxes[:, :, 1] /= ratio_h
|
||||||
|
for i_box, box in enumerate(boxes):
|
||||||
|
box = self.sort_poly(box.astype(np.int32))
|
||||||
|
if np.linalg.norm(box[0] - box[1]) < 5 \
|
||||||
|
or np.linalg.norm(box[3] - box[0]) < 5:
|
||||||
|
continue
|
||||||
|
boxes_norm.append(box)
|
||||||
|
dt_boxes_list.append({'points': np.array(boxes_norm)})
|
||||||
|
return dt_boxes_list
|
|
@ -0,0 +1,199 @@
|
||||||
|
"""
|
||||||
|
Locality aware nms.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from shapely.geometry import Polygon
|
||||||
|
|
||||||
|
|
||||||
|
def intersection(g, p):
|
||||||
|
"""
|
||||||
|
Intersection.
|
||||||
|
"""
|
||||||
|
g = Polygon(g[:8].reshape((4, 2)))
|
||||||
|
p = Polygon(p[:8].reshape((4, 2)))
|
||||||
|
g = g.buffer(0)
|
||||||
|
p = p.buffer(0)
|
||||||
|
if not g.is_valid or not p.is_valid:
|
||||||
|
return 0
|
||||||
|
inter = Polygon(g).intersection(Polygon(p)).area
|
||||||
|
union = g.area + p.area - inter
|
||||||
|
if union == 0:
|
||||||
|
return 0
|
||||||
|
else:
|
||||||
|
return inter / union
|
||||||
|
|
||||||
|
|
||||||
|
def intersection_iog(g, p):
|
||||||
|
"""
|
||||||
|
Intersection_iog.
|
||||||
|
"""
|
||||||
|
g = Polygon(g[:8].reshape((4, 2)))
|
||||||
|
p = Polygon(p[:8].reshape((4, 2)))
|
||||||
|
if not g.is_valid or not p.is_valid:
|
||||||
|
return 0
|
||||||
|
inter = Polygon(g).intersection(Polygon(p)).area
|
||||||
|
#union = g.area + p.area - inter
|
||||||
|
union = p.area
|
||||||
|
if union == 0:
|
||||||
|
print("p_area is very small")
|
||||||
|
return 0
|
||||||
|
else:
|
||||||
|
return inter / union
|
||||||
|
|
||||||
|
|
||||||
|
def weighted_merge(g, p):
|
||||||
|
"""
|
||||||
|
Weighted merge.
|
||||||
|
"""
|
||||||
|
g[:8] = (g[8] * g[:8] + p[8] * p[:8]) / (g[8] + p[8])
|
||||||
|
g[8] = (g[8] + p[8])
|
||||||
|
return g
|
||||||
|
|
||||||
|
|
||||||
|
def standard_nms(S, thres):
|
||||||
|
"""
|
||||||
|
Standard nms.
|
||||||
|
"""
|
||||||
|
order = np.argsort(S[:, 8])[::-1]
|
||||||
|
keep = []
|
||||||
|
while order.size > 0:
|
||||||
|
i = order[0]
|
||||||
|
keep.append(i)
|
||||||
|
ovr = np.array([intersection(S[i], S[t]) for t in order[1:]])
|
||||||
|
|
||||||
|
inds = np.where(ovr <= thres)[0]
|
||||||
|
order = order[inds + 1]
|
||||||
|
|
||||||
|
return S[keep]
|
||||||
|
|
||||||
|
|
||||||
|
def standard_nms_inds(S, thres):
|
||||||
|
"""
|
||||||
|
Standard nms, retun inds.
|
||||||
|
"""
|
||||||
|
order = np.argsort(S[:, 8])[::-1]
|
||||||
|
keep = []
|
||||||
|
while order.size > 0:
|
||||||
|
i = order[0]
|
||||||
|
keep.append(i)
|
||||||
|
ovr = np.array([intersection(S[i], S[t]) for t in order[1:]])
|
||||||
|
|
||||||
|
inds = np.where(ovr <= thres)[0]
|
||||||
|
order = order[inds + 1]
|
||||||
|
|
||||||
|
return keep
|
||||||
|
|
||||||
|
|
||||||
|
def nms(S, thres):
|
||||||
|
"""
|
||||||
|
nms.
|
||||||
|
"""
|
||||||
|
order = np.argsort(S[:, 8])[::-1]
|
||||||
|
keep = []
|
||||||
|
while order.size > 0:
|
||||||
|
i = order[0]
|
||||||
|
keep.append(i)
|
||||||
|
ovr = np.array([intersection(S[i], S[t]) for t in order[1:]])
|
||||||
|
|
||||||
|
inds = np.where(ovr <= thres)[0]
|
||||||
|
order = order[inds + 1]
|
||||||
|
|
||||||
|
return keep
|
||||||
|
|
||||||
|
|
||||||
|
def soft_nms(boxes_in, Nt_thres=0.3, threshold=0.8, sigma=0.5, method=2):
|
||||||
|
"""
|
||||||
|
soft_nms
|
||||||
|
:para boxes_in, N x 9 (coords + score)
|
||||||
|
:para threshould, eliminate cases min score(0.001)
|
||||||
|
:para Nt_thres, iou_threshi
|
||||||
|
:para sigma, gaussian weght
|
||||||
|
:method, linear or gaussian
|
||||||
|
"""
|
||||||
|
boxes = boxes_in.copy()
|
||||||
|
N = boxes.shape[0]
|
||||||
|
if N is None or N < 1:
|
||||||
|
return np.array([])
|
||||||
|
pos, maxpos = 0, 0
|
||||||
|
weight = 0.0
|
||||||
|
inds = np.arange(N)
|
||||||
|
tbox, sbox = boxes[0].copy(), boxes[0].copy()
|
||||||
|
for i in range(N):
|
||||||
|
maxscore = boxes[i, 8]
|
||||||
|
maxpos = i
|
||||||
|
tbox = boxes[i].copy()
|
||||||
|
ti = inds[i]
|
||||||
|
pos = i + 1
|
||||||
|
#get max box
|
||||||
|
while pos < N:
|
||||||
|
if maxscore < boxes[pos, 8]:
|
||||||
|
maxscore = boxes[pos, 8]
|
||||||
|
maxpos = pos
|
||||||
|
pos = pos + 1
|
||||||
|
#add max box as a detection
|
||||||
|
boxes[i, :] = boxes[maxpos, :]
|
||||||
|
inds[i] = inds[maxpos]
|
||||||
|
#swap
|
||||||
|
boxes[maxpos, :] = tbox
|
||||||
|
inds[maxpos] = ti
|
||||||
|
tbox = boxes[i].copy()
|
||||||
|
pos = i + 1
|
||||||
|
#NMS iteration
|
||||||
|
while pos < N:
|
||||||
|
sbox = boxes[pos].copy()
|
||||||
|
ts_iou_val = intersection(tbox, sbox)
|
||||||
|
if ts_iou_val > 0:
|
||||||
|
if method == 1:
|
||||||
|
if ts_iou_val > Nt_thres:
|
||||||
|
weight = 1 - ts_iou_val
|
||||||
|
else:
|
||||||
|
weight = 1
|
||||||
|
elif method == 2:
|
||||||
|
weight = np.exp(-1.0 * ts_iou_val**2 / sigma)
|
||||||
|
else:
|
||||||
|
if ts_iou_val > Nt_thres:
|
||||||
|
weight = 0
|
||||||
|
else:
|
||||||
|
weight = 1
|
||||||
|
boxes[pos, 8] = weight * boxes[pos, 8]
|
||||||
|
#if box score falls below thresold, discard the box by
|
||||||
|
#swaping last box update N
|
||||||
|
if boxes[pos, 8] < threshold:
|
||||||
|
boxes[pos, :] = boxes[N - 1, :]
|
||||||
|
inds[pos] = inds[N - 1]
|
||||||
|
N = N - 1
|
||||||
|
pos = pos - 1
|
||||||
|
pos = pos + 1
|
||||||
|
|
||||||
|
return boxes[:N]
|
||||||
|
|
||||||
|
|
||||||
|
def nms_locality(polys, thres=0.3):
|
||||||
|
"""
|
||||||
|
locality aware nms of EAST
|
||||||
|
:param polys: a N*9 numpy array. first 8 coordinates, then prob
|
||||||
|
:return: boxes after nms
|
||||||
|
"""
|
||||||
|
S = []
|
||||||
|
p = None
|
||||||
|
for g in polys:
|
||||||
|
if p is not None and intersection(g, p) > thres:
|
||||||
|
p = weighted_merge(g, p)
|
||||||
|
else:
|
||||||
|
if p is not None:
|
||||||
|
S.append(p)
|
||||||
|
p = g
|
||||||
|
if p is not None:
|
||||||
|
S.append(p)
|
||||||
|
|
||||||
|
if len(S) == 0:
|
||||||
|
return np.array([])
|
||||||
|
return standard_nms(np.array(S), thres)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
# 343,350,448,135,474,143,369,359
|
||||||
|
print(
|
||||||
|
Polygon(np.array([[343, 350], [448, 135], [474, 143], [369, 359]]))
|
||||||
|
.area)
|
|
@ -23,14 +23,16 @@ class BaseRecLabelDecode(object):
|
||||||
character_dict_path=None,
|
character_dict_path=None,
|
||||||
character_type='ch',
|
character_type='ch',
|
||||||
use_space_char=False):
|
use_space_char=False):
|
||||||
support_character_type = ['ch', 'en', 'en_sensitive']
|
support_character_type = [
|
||||||
|
'ch', 'en', 'en_sensitive', 'french', 'german', 'japan', 'korean'
|
||||||
|
]
|
||||||
assert character_type in support_character_type, "Only {} are supported now but get {}".format(
|
assert character_type in support_character_type, "Only {} are supported now but get {}".format(
|
||||||
support_character_type, self.character_str)
|
support_character_type, self.character_str)
|
||||||
|
|
||||||
if character_type == "en":
|
if character_type == "en":
|
||||||
self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz"
|
self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz"
|
||||||
dict_character = list(self.character_str)
|
dict_character = list(self.character_str)
|
||||||
elif character_type == "ch":
|
elif character_type in ["ch", "french", "german", "japan", "korean"]:
|
||||||
self.character_str = ""
|
self.character_str = ""
|
||||||
assert character_dict_path is not None, "character_dict_path should not be None when character_type is ch"
|
assert character_dict_path is not None, "character_dict_path should not be None when character_type is ch"
|
||||||
with open(character_dict_path, "rb") as fin:
|
with open(character_dict_path, "rb") as fin:
|
||||||
|
|
|
@ -0,0 +1,295 @@
|
||||||
|
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
__dir__ = os.path.dirname(__file__)
|
||||||
|
sys.path.append(__dir__)
|
||||||
|
sys.path.append(os.path.join(__dir__, '..'))
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from .locality_aware_nms import nms_locality
|
||||||
|
# import lanms
|
||||||
|
import cv2
|
||||||
|
import time
|
||||||
|
|
||||||
|
|
||||||
|
class SASTPostProcess(object):
|
||||||
|
"""
|
||||||
|
The post process for SAST.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
score_thresh=0.5,
|
||||||
|
nms_thresh=0.2,
|
||||||
|
sample_pts_num=2,
|
||||||
|
shrink_ratio_of_width=0.3,
|
||||||
|
expand_scale=1.0,
|
||||||
|
tcl_map_thresh=0.5,
|
||||||
|
**kwargs):
|
||||||
|
|
||||||
|
self.score_thresh = score_thresh
|
||||||
|
self.nms_thresh = nms_thresh
|
||||||
|
self.sample_pts_num = sample_pts_num
|
||||||
|
self.shrink_ratio_of_width = shrink_ratio_of_width
|
||||||
|
self.expand_scale = expand_scale
|
||||||
|
self.tcl_map_thresh = tcl_map_thresh
|
||||||
|
|
||||||
|
# c++ la-nms is faster, but only support python 3.5
|
||||||
|
self.is_python35 = False
|
||||||
|
if sys.version_info.major == 3 and sys.version_info.minor == 5:
|
||||||
|
self.is_python35 = True
|
||||||
|
|
||||||
|
def point_pair2poly(self, point_pair_list):
|
||||||
|
"""
|
||||||
|
Transfer vertical point_pairs into poly point in clockwise.
|
||||||
|
"""
|
||||||
|
# constract poly
|
||||||
|
point_num = len(point_pair_list) * 2
|
||||||
|
point_list = [0] * point_num
|
||||||
|
for idx, point_pair in enumerate(point_pair_list):
|
||||||
|
point_list[idx] = point_pair[0]
|
||||||
|
point_list[point_num - 1 - idx] = point_pair[1]
|
||||||
|
return np.array(point_list).reshape(-1, 2)
|
||||||
|
|
||||||
|
def shrink_quad_along_width(self, quad, begin_width_ratio=0., end_width_ratio=1.):
|
||||||
|
"""
|
||||||
|
Generate shrink_quad_along_width.
|
||||||
|
"""
|
||||||
|
ratio_pair = np.array([[begin_width_ratio], [end_width_ratio]], dtype=np.float32)
|
||||||
|
p0_1 = quad[0] + (quad[1] - quad[0]) * ratio_pair
|
||||||
|
p3_2 = quad[3] + (quad[2] - quad[3]) * ratio_pair
|
||||||
|
return np.array([p0_1[0], p0_1[1], p3_2[1], p3_2[0]])
|
||||||
|
|
||||||
|
def expand_poly_along_width(self, poly, shrink_ratio_of_width=0.3):
|
||||||
|
"""
|
||||||
|
expand poly along width.
|
||||||
|
"""
|
||||||
|
point_num = poly.shape[0]
|
||||||
|
left_quad = np.array([poly[0], poly[1], poly[-2], poly[-1]], dtype=np.float32)
|
||||||
|
left_ratio = -shrink_ratio_of_width * np.linalg.norm(left_quad[0] - left_quad[3]) / \
|
||||||
|
(np.linalg.norm(left_quad[0] - left_quad[1]) + 1e-6)
|
||||||
|
left_quad_expand = self.shrink_quad_along_width(left_quad, left_ratio, 1.0)
|
||||||
|
right_quad = np.array([poly[point_num // 2 - 2], poly[point_num // 2 - 1],
|
||||||
|
poly[point_num // 2], poly[point_num // 2 + 1]], dtype=np.float32)
|
||||||
|
right_ratio = 1.0 + \
|
||||||
|
shrink_ratio_of_width * np.linalg.norm(right_quad[0] - right_quad[3]) / \
|
||||||
|
(np.linalg.norm(right_quad[0] - right_quad[1]) + 1e-6)
|
||||||
|
right_quad_expand = self.shrink_quad_along_width(right_quad, 0.0, right_ratio)
|
||||||
|
poly[0] = left_quad_expand[0]
|
||||||
|
poly[-1] = left_quad_expand[-1]
|
||||||
|
poly[point_num // 2 - 1] = right_quad_expand[1]
|
||||||
|
poly[point_num // 2] = right_quad_expand[2]
|
||||||
|
return poly
|
||||||
|
|
||||||
|
def restore_quad(self, tcl_map, tcl_map_thresh, tvo_map):
|
||||||
|
"""Restore quad."""
|
||||||
|
xy_text = np.argwhere(tcl_map[:, :, 0] > tcl_map_thresh)
|
||||||
|
xy_text = xy_text[:, ::-1] # (n, 2)
|
||||||
|
|
||||||
|
# Sort the text boxes via the y axis
|
||||||
|
xy_text = xy_text[np.argsort(xy_text[:, 1])]
|
||||||
|
|
||||||
|
scores = tcl_map[xy_text[:, 1], xy_text[:, 0], 0]
|
||||||
|
scores = scores[:, np.newaxis]
|
||||||
|
|
||||||
|
# Restore
|
||||||
|
point_num = int(tvo_map.shape[-1] / 2)
|
||||||
|
assert point_num == 4
|
||||||
|
tvo_map = tvo_map[xy_text[:, 1], xy_text[:, 0], :]
|
||||||
|
xy_text_tile = np.tile(xy_text, (1, point_num)) # (n, point_num * 2)
|
||||||
|
quads = xy_text_tile - tvo_map
|
||||||
|
|
||||||
|
return scores, quads, xy_text
|
||||||
|
|
||||||
|
def quad_area(self, quad):
|
||||||
|
"""
|
||||||
|
compute area of a quad.
|
||||||
|
"""
|
||||||
|
edge = [
|
||||||
|
(quad[1][0] - quad[0][0]) * (quad[1][1] + quad[0][1]),
|
||||||
|
(quad[2][0] - quad[1][0]) * (quad[2][1] + quad[1][1]),
|
||||||
|
(quad[3][0] - quad[2][0]) * (quad[3][1] + quad[2][1]),
|
||||||
|
(quad[0][0] - quad[3][0]) * (quad[0][1] + quad[3][1])
|
||||||
|
]
|
||||||
|
return np.sum(edge) / 2.
|
||||||
|
|
||||||
|
def nms(self, dets):
|
||||||
|
if self.is_python35:
|
||||||
|
import lanms
|
||||||
|
dets = lanms.merge_quadrangle_n9(dets, self.nms_thresh)
|
||||||
|
else:
|
||||||
|
dets = nms_locality(dets, self.nms_thresh)
|
||||||
|
return dets
|
||||||
|
|
||||||
|
def cluster_by_quads_tco(self, tcl_map, tcl_map_thresh, quads, tco_map):
|
||||||
|
"""
|
||||||
|
Cluster pixels in tcl_map based on quads.
|
||||||
|
"""
|
||||||
|
instance_count = quads.shape[0] + 1 # contain background
|
||||||
|
instance_label_map = np.zeros(tcl_map.shape[:2], dtype=np.int32)
|
||||||
|
if instance_count == 1:
|
||||||
|
return instance_count, instance_label_map
|
||||||
|
|
||||||
|
# predict text center
|
||||||
|
xy_text = np.argwhere(tcl_map[:, :, 0] > tcl_map_thresh)
|
||||||
|
n = xy_text.shape[0]
|
||||||
|
xy_text = xy_text[:, ::-1] # (n, 2)
|
||||||
|
tco = tco_map[xy_text[:, 1], xy_text[:, 0], :] # (n, 2)
|
||||||
|
pred_tc = xy_text - tco
|
||||||
|
|
||||||
|
# get gt text center
|
||||||
|
m = quads.shape[0]
|
||||||
|
gt_tc = np.mean(quads, axis=1) # (m, 2)
|
||||||
|
|
||||||
|
pred_tc_tile = np.tile(pred_tc[:, np.newaxis, :], (1, m, 1)) # (n, m, 2)
|
||||||
|
gt_tc_tile = np.tile(gt_tc[np.newaxis, :, :], (n, 1, 1)) # (n, m, 2)
|
||||||
|
dist_mat = np.linalg.norm(pred_tc_tile - gt_tc_tile, axis=2) # (n, m)
|
||||||
|
xy_text_assign = np.argmin(dist_mat, axis=1) + 1 # (n,)
|
||||||
|
|
||||||
|
instance_label_map[xy_text[:, 1], xy_text[:, 0]] = xy_text_assign
|
||||||
|
return instance_count, instance_label_map
|
||||||
|
|
||||||
|
def estimate_sample_pts_num(self, quad, xy_text):
|
||||||
|
"""
|
||||||
|
Estimate sample points number.
|
||||||
|
"""
|
||||||
|
eh = (np.linalg.norm(quad[0] - quad[3]) + np.linalg.norm(quad[1] - quad[2])) / 2.0
|
||||||
|
ew = (np.linalg.norm(quad[0] - quad[1]) + np.linalg.norm(quad[2] - quad[3])) / 2.0
|
||||||
|
|
||||||
|
dense_sample_pts_num = max(2, int(ew))
|
||||||
|
dense_xy_center_line = xy_text[np.linspace(0, xy_text.shape[0] - 1, dense_sample_pts_num,
|
||||||
|
endpoint=True, dtype=np.float32).astype(np.int32)]
|
||||||
|
|
||||||
|
dense_xy_center_line_diff = dense_xy_center_line[1:] - dense_xy_center_line[:-1]
|
||||||
|
estimate_arc_len = np.sum(np.linalg.norm(dense_xy_center_line_diff, axis=1))
|
||||||
|
|
||||||
|
sample_pts_num = max(2, int(estimate_arc_len / eh))
|
||||||
|
return sample_pts_num
|
||||||
|
|
||||||
|
def detect_sast(self, tcl_map, tvo_map, tbo_map, tco_map, ratio_w, ratio_h, src_w, src_h,
|
||||||
|
shrink_ratio_of_width=0.3, tcl_map_thresh=0.5, offset_expand=1.0, out_strid=4.0):
|
||||||
|
"""
|
||||||
|
first resize the tcl_map, tvo_map and tbo_map to the input_size, then restore the polys
|
||||||
|
"""
|
||||||
|
# restore quad
|
||||||
|
scores, quads, xy_text = self.restore_quad(tcl_map, tcl_map_thresh, tvo_map)
|
||||||
|
dets = np.hstack((quads, scores)).astype(np.float32, copy=False)
|
||||||
|
dets = self.nms(dets)
|
||||||
|
if dets.shape[0] == 0:
|
||||||
|
return []
|
||||||
|
quads = dets[:, :-1].reshape(-1, 4, 2)
|
||||||
|
|
||||||
|
# Compute quad area
|
||||||
|
quad_areas = []
|
||||||
|
for quad in quads:
|
||||||
|
quad_areas.append(-self.quad_area(quad))
|
||||||
|
|
||||||
|
# instance segmentation
|
||||||
|
# instance_count, instance_label_map = cv2.connectedComponents(tcl_map.astype(np.uint8), connectivity=8)
|
||||||
|
instance_count, instance_label_map = self.cluster_by_quads_tco(tcl_map, tcl_map_thresh, quads, tco_map)
|
||||||
|
|
||||||
|
# restore single poly with tcl instance.
|
||||||
|
poly_list = []
|
||||||
|
for instance_idx in range(1, instance_count):
|
||||||
|
xy_text = np.argwhere(instance_label_map == instance_idx)[:, ::-1]
|
||||||
|
quad = quads[instance_idx - 1]
|
||||||
|
q_area = quad_areas[instance_idx - 1]
|
||||||
|
if q_area < 5:
|
||||||
|
continue
|
||||||
|
|
||||||
|
#
|
||||||
|
len1 = float(np.linalg.norm(quad[0] -quad[1]))
|
||||||
|
len2 = float(np.linalg.norm(quad[1] -quad[2]))
|
||||||
|
min_len = min(len1, len2)
|
||||||
|
if min_len < 3:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# filter small CC
|
||||||
|
if xy_text.shape[0] <= 0:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# filter low confidence instance
|
||||||
|
xy_text_scores = tcl_map[xy_text[:, 1], xy_text[:, 0], 0]
|
||||||
|
if np.sum(xy_text_scores) / quad_areas[instance_idx - 1] < 0.1:
|
||||||
|
# if np.sum(xy_text_scores) / quad_areas[instance_idx - 1] < 0.05:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# sort xy_text
|
||||||
|
left_center_pt = np.array([[(quad[0, 0] + quad[-1, 0]) / 2.0,
|
||||||
|
(quad[0, 1] + quad[-1, 1]) / 2.0]]) # (1, 2)
|
||||||
|
right_center_pt = np.array([[(quad[1, 0] + quad[2, 0]) / 2.0,
|
||||||
|
(quad[1, 1] + quad[2, 1]) / 2.0]]) # (1, 2)
|
||||||
|
proj_unit_vec = (right_center_pt - left_center_pt) / \
|
||||||
|
(np.linalg.norm(right_center_pt - left_center_pt) + 1e-6)
|
||||||
|
proj_value = np.sum(xy_text * proj_unit_vec, axis=1)
|
||||||
|
xy_text = xy_text[np.argsort(proj_value)]
|
||||||
|
|
||||||
|
# Sample pts in tcl map
|
||||||
|
if self.sample_pts_num == 0:
|
||||||
|
sample_pts_num = self.estimate_sample_pts_num(quad, xy_text)
|
||||||
|
else:
|
||||||
|
sample_pts_num = self.sample_pts_num
|
||||||
|
xy_center_line = xy_text[np.linspace(0, xy_text.shape[0] - 1, sample_pts_num,
|
||||||
|
endpoint=True, dtype=np.float32).astype(np.int32)]
|
||||||
|
|
||||||
|
point_pair_list = []
|
||||||
|
for x, y in xy_center_line:
|
||||||
|
# get corresponding offset
|
||||||
|
offset = tbo_map[y, x, :].reshape(2, 2)
|
||||||
|
if offset_expand != 1.0:
|
||||||
|
offset_length = np.linalg.norm(offset, axis=1, keepdims=True)
|
||||||
|
expand_length = np.clip(offset_length * (offset_expand - 1), a_min=0.5, a_max=3.0)
|
||||||
|
offset_detal = offset / offset_length * expand_length
|
||||||
|
offset = offset + offset_detal
|
||||||
|
# original point
|
||||||
|
ori_yx = np.array([y, x], dtype=np.float32)
|
||||||
|
point_pair = (ori_yx + offset)[:, ::-1]* out_strid / np.array([ratio_w, ratio_h]).reshape(-1, 2)
|
||||||
|
point_pair_list.append(point_pair)
|
||||||
|
|
||||||
|
# ndarry: (x, 2), expand poly along width
|
||||||
|
detected_poly = self.point_pair2poly(point_pair_list)
|
||||||
|
detected_poly = self.expand_poly_along_width(detected_poly, shrink_ratio_of_width)
|
||||||
|
detected_poly[:, 0] = np.clip(detected_poly[:, 0], a_min=0, a_max=src_w)
|
||||||
|
detected_poly[:, 1] = np.clip(detected_poly[:, 1], a_min=0, a_max=src_h)
|
||||||
|
poly_list.append(detected_poly)
|
||||||
|
|
||||||
|
return poly_list
|
||||||
|
|
||||||
|
def __call__(self, outs_dict, shape_list):
|
||||||
|
score_list = outs_dict['f_score']
|
||||||
|
border_list = outs_dict['f_border']
|
||||||
|
tvo_list = outs_dict['f_tvo']
|
||||||
|
tco_list = outs_dict['f_tco']
|
||||||
|
|
||||||
|
img_num = len(shape_list)
|
||||||
|
poly_lists = []
|
||||||
|
for ino in range(img_num):
|
||||||
|
p_score = score_list[ino].transpose((1,2,0)).numpy()
|
||||||
|
p_border = border_list[ino].transpose((1,2,0)).numpy()
|
||||||
|
p_tvo = tvo_list[ino].transpose((1,2,0)).numpy()
|
||||||
|
p_tco = tco_list[ino].transpose((1,2,0)).numpy()
|
||||||
|
src_h, src_w, ratio_h, ratio_w = shape_list[ino]
|
||||||
|
|
||||||
|
poly_list = self.detect_sast(p_score, p_tvo, p_border, p_tco, ratio_w, ratio_h, src_w, src_h,
|
||||||
|
shrink_ratio_of_width=self.shrink_ratio_of_width,
|
||||||
|
tcl_map_thresh=self.tcl_map_thresh, offset_expand=self.expand_scale)
|
||||||
|
poly_lists.append({'points': np.array(poly_list)})
|
||||||
|
|
||||||
|
return poly_lists
|
||||||
|
|
2
setup.py
2
setup.py
|
@ -32,7 +32,7 @@ setup(
|
||||||
package_dir={'paddleocr': ''},
|
package_dir={'paddleocr': ''},
|
||||||
include_package_data=True,
|
include_package_data=True,
|
||||||
entry_points={"console_scripts": ["paddleocr= paddleocr.paddleocr:main"]},
|
entry_points={"console_scripts": ["paddleocr= paddleocr.paddleocr:main"]},
|
||||||
version='0.0.3',
|
version='2.0',
|
||||||
install_requires=requirements,
|
install_requires=requirements,
|
||||||
license='Apache License 2.0',
|
license='Apache License 2.0',
|
||||||
description='Awesome OCR toolkits based on PaddlePaddle (8.6M ultra-lightweight pre-trained model, support training and deployment among server, mobile, embeded and IoT devices',
|
description='Awesome OCR toolkits based on PaddlePaddle (8.6M ultra-lightweight pre-trained model, support training and deployment among server, mobile, embeded and IoT devices',
|
||||||
|
|
|
@ -13,6 +13,7 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
__dir__ = os.path.dirname(os.path.abspath(__file__))
|
__dir__ = os.path.dirname(os.path.abspath(__file__))
|
||||||
sys.path.append(__dir__)
|
sys.path.append(__dir__)
|
||||||
sys.path.append(os.path.abspath(os.path.join(__dir__, '../..')))
|
sys.path.append(os.path.abspath(os.path.join(__dir__, '../..')))
|
||||||
|
@ -30,12 +31,15 @@ from ppocr.utils.utility import get_image_file_list, check_and_read_gif
|
||||||
from ppocr.utils.logging import get_logger
|
from ppocr.utils.logging import get_logger
|
||||||
from tools.infer.utility import draw_ocr_box_txt
|
from tools.infer.utility import draw_ocr_box_txt
|
||||||
|
|
||||||
|
logger = get_logger()
|
||||||
|
|
||||||
|
|
||||||
class TextSystem(object):
|
class TextSystem(object):
|
||||||
def __init__(self, args):
|
def __init__(self, args):
|
||||||
self.text_detector = predict_det.TextDetector(args)
|
self.text_detector = predict_det.TextDetector(args)
|
||||||
self.text_recognizer = predict_rec.TextRecognizer(args)
|
self.text_recognizer = predict_rec.TextRecognizer(args)
|
||||||
self.use_angle_cls = args.use_angle_cls
|
self.use_angle_cls = args.use_angle_cls
|
||||||
|
self.drop_score = args.drop_score
|
||||||
if self.use_angle_cls:
|
if self.use_angle_cls:
|
||||||
self.text_classifier = predict_cls.TextClassifier(args)
|
self.text_classifier = predict_cls.TextClassifier(args)
|
||||||
|
|
||||||
|
@ -81,7 +85,8 @@ class TextSystem(object):
|
||||||
def __call__(self, img):
|
def __call__(self, img):
|
||||||
ori_im = img.copy()
|
ori_im = img.copy()
|
||||||
dt_boxes, elapse = self.text_detector(img)
|
dt_boxes, elapse = self.text_detector(img)
|
||||||
logger.info("dt_boxes num : {}, elapse : {}".format(len(dt_boxes), elapse))
|
logger.info("dt_boxes num : {}, elapse : {}".format(
|
||||||
|
len(dt_boxes), elapse))
|
||||||
if dt_boxes is None:
|
if dt_boxes is None:
|
||||||
return None, None
|
return None, None
|
||||||
img_crop_list = []
|
img_crop_list = []
|
||||||
|
@ -99,9 +104,16 @@ class TextSystem(object):
|
||||||
len(img_crop_list), elapse))
|
len(img_crop_list), elapse))
|
||||||
|
|
||||||
rec_res, elapse = self.text_recognizer(img_crop_list)
|
rec_res, elapse = self.text_recognizer(img_crop_list)
|
||||||
logger.info("rec_res num : {}, elapse : {}".format(len(rec_res), elapse))
|
logger.info("rec_res num : {}, elapse : {}".format(
|
||||||
|
len(rec_res), elapse))
|
||||||
# self.print_draw_crop_rec_res(img_crop_list, rec_res)
|
# self.print_draw_crop_rec_res(img_crop_list, rec_res)
|
||||||
return dt_boxes, rec_res
|
filter_boxes, filter_rec_res = [], []
|
||||||
|
for box, rec_reuslt in zip(dt_boxes, rec_res):
|
||||||
|
text, score = rec_reuslt
|
||||||
|
if score >= self.drop_score:
|
||||||
|
filter_boxes.append(box)
|
||||||
|
filter_rec_res.append(rec_reuslt)
|
||||||
|
return filter_boxes, filter_rec_res
|
||||||
|
|
||||||
|
|
||||||
def sorted_boxes(dt_boxes):
|
def sorted_boxes(dt_boxes):
|
||||||
|
@ -143,12 +155,8 @@ def main(args):
|
||||||
elapse = time.time() - starttime
|
elapse = time.time() - starttime
|
||||||
logger.info("Predict time of %s: %.3fs" % (image_file, elapse))
|
logger.info("Predict time of %s: %.3fs" % (image_file, elapse))
|
||||||
|
|
||||||
dt_num = len(dt_boxes)
|
for text, score in rec_res:
|
||||||
for dno in range(dt_num):
|
logger.info("{}, {:.3f}".format(text, score))
|
||||||
text, score = rec_res[dno]
|
|
||||||
if score >= drop_score:
|
|
||||||
text_str = "%s, %.3f" % (text, score)
|
|
||||||
logger.info(text_str)
|
|
||||||
|
|
||||||
if is_visualize:
|
if is_visualize:
|
||||||
image = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
|
image = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
|
||||||
|
@ -174,5 +182,4 @@ def main(args):
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
logger = get_logger()
|
|
||||||
main(utility.parse_args())
|
main(utility.parse_args())
|
Loading…
Reference in New Issue