Merge remote-tracking branch 'origin/dygraph' into dygraph
# Conflicts: # PPOCRLabel/libs/resources.py
This commit is contained in:
commit
b063c417ea
|
@ -401,6 +401,7 @@ class MainWindow(QMainWindow, WindowMixin):
|
|||
help = action(getStr('tutorial'), self.showTutorialDialog, None, 'help', getStr('tutorialDetail'))
|
||||
showInfo = action(getStr('info'), self.showInfoDialog, None, 'help', getStr('info'))
|
||||
showSteps = action(getStr('steps'), self.showStepsDialog, None, 'help', getStr('steps'))
|
||||
showKeys = action(getStr('keys'), self.showKeysDialog, None, 'help', getStr('keys'))
|
||||
|
||||
zoom = QWidgetAction(self)
|
||||
zoom.setDefaultWidget(self.zoomWidget)
|
||||
|
@ -568,7 +569,7 @@ class MainWindow(QMainWindow, WindowMixin):
|
|||
addActions(self.menus.file,
|
||||
(opendir, open_dataset_dir, None, saveLabel, saveRec, self.autoSaveOption, None, resetAll, deleteImg, quit))
|
||||
|
||||
addActions(self.menus.help, (showSteps, showInfo))
|
||||
addActions(self.menus.help, (showKeys,showSteps, showInfo))
|
||||
addActions(self.menus.view, (
|
||||
self.displayLabelOption, self.labelDialogOption,
|
||||
None,
|
||||
|
@ -763,6 +764,10 @@ class MainWindow(QMainWindow, WindowMixin):
|
|||
msg = stepsInfo(self.lang)
|
||||
QMessageBox.information(self, u'Information', msg)
|
||||
|
||||
def showKeysDialog(self):
|
||||
msg = keysInfo(self.lang)
|
||||
QMessageBox.information(self, u'Information', msg)
|
||||
|
||||
def createShape(self):
|
||||
assert self.beginner()
|
||||
self.canvas.setEditing(False)
|
||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -174,6 +174,7 @@ def stepsInfo(lang='en'):
|
|||
"10. 标注结果:关闭应用程序或切换文件路径后,手动保存过的标签将会被存放在所打开图片文件夹下的" \
|
||||
"*Label.txt*中。在菜单栏点击 “PaddleOCR” - 保存识别结果后,会将此类图片的识别训练数据保存在*crop_img*文件夹下," \
|
||||
"识别标签保存在*rec_gt.txt*中。\n"
|
||||
|
||||
else:
|
||||
msg = "1. Build and launch using the instructions above.\n" \
|
||||
"2. Click 'Open Dir' in Menu/File to select the folder of the picture.\n"\
|
||||
|
@ -188,4 +189,56 @@ def stepsInfo(lang='en'):
|
|||
"9. Click 'Delete Image' and the image will be deleted to the recycle bin.\n"\
|
||||
"10. Labeling result: After closing the application or switching the file path, the manually saved label will be stored in *Label.txt* under the opened picture folder.\n"\
|
||||
" Click PaddleOCR-Save Recognition Results in the menu bar, the recognition training data of such pictures will be saved in the *crop_img* folder, and the recognition label will be saved in *rec_gt.txt*.\n"
|
||||
|
||||
return msg
|
||||
|
||||
def keysInfo(lang='en'):
|
||||
if lang == 'ch':
|
||||
msg = "快捷键\t\t\t说明\n" \
|
||||
"———————————————————————\n"\
|
||||
"Ctrl + shift + R\t\t对当前图片的所有标记重新识别\n" \
|
||||
"W\t\t\t新建矩形框\n" \
|
||||
"Q\t\t\t新建四点框\n" \
|
||||
"Ctrl + E\t\t编辑所选框标签\n" \
|
||||
"Ctrl + R\t\t重新识别所选标记\n" \
|
||||
"Ctrl + C\t\t复制并粘贴选中的标记框\n" \
|
||||
"Ctrl + 鼠标左键\t\t多选标记框\n" \
|
||||
"Backspace\t\t删除所选框\n" \
|
||||
"Ctrl + V\t\t确认本张图片标记\n" \
|
||||
"Ctrl + Shift + d\t删除本张图片\n" \
|
||||
"D\t\t\t下一张图片\n" \
|
||||
"A\t\t\t上一张图片\n" \
|
||||
"Ctrl++\t\t\t缩小\n" \
|
||||
"Ctrl--\t\t\t放大\n" \
|
||||
"↑→↓←\t\t\t移动标记框\n" \
|
||||
"———————————————————————\n" \
|
||||
"注:Mac用户Command键替换上述Ctrl键"
|
||||
|
||||
else:
|
||||
msg = "Shortcut Keys\t\tDescription\n" \
|
||||
"———————————————————————\n" \
|
||||
"Ctrl + shift + R\t\tRe-recognize all the labels\n" \
|
||||
"\t\t\tof the current image\n" \
|
||||
"\n"\
|
||||
"W\t\t\tCreate a rect box\n" \
|
||||
"Q\t\t\tCreate a four-points box\n" \
|
||||
"Ctrl + E\t\tEdit label of the selected box\n" \
|
||||
"Ctrl + R\t\tRe-recognize the selected box\n" \
|
||||
"Ctrl + C\t\tCopy and paste the selected\n" \
|
||||
"\t\t\tbox\n" \
|
||||
"\n"\
|
||||
"Ctrl + Left Mouse\tMulti select the label\n" \
|
||||
"Button\t\t\tbox\n" \
|
||||
"\n"\
|
||||
"Backspace\t\tDelete the selected box\n" \
|
||||
"Ctrl + V\t\tCheck image\n" \
|
||||
"Ctrl + Shift + d\tDelete image\n" \
|
||||
"D\t\t\tNext image\n" \
|
||||
"A\t\t\tPrevious image\n" \
|
||||
"Ctrl++\t\t\tZoom in\n" \
|
||||
"Ctrl--\t\t\tZoom out\n" \
|
||||
"↑→↓←\t\t\tMove selected box" \
|
||||
"———————————————————————\n" \
|
||||
"Notice:For Mac users, use the 'Command' key instead of the 'Ctrl' key"
|
||||
|
||||
return msg
|
|
@ -90,6 +90,7 @@ saveRec=保存识别结果
|
|||
tempLabel=待识别
|
||||
nullLabel=无法识别
|
||||
steps=操作步骤
|
||||
keys=快捷键
|
||||
choseModelLg=选择模型语言
|
||||
cancel=取消
|
||||
ok=确认
|
||||
|
|
|
@ -90,6 +90,7 @@ saveRec=Save Recognition Result
|
|||
tempLabel=TEMPORARY
|
||||
nullLabel=NULL
|
||||
steps=Steps
|
||||
keys=Shortcut Keys
|
||||
choseModelLg=Choose Model Language
|
||||
cancel=Cancel
|
||||
ok=OK
|
||||
|
|
|
@ -0,0 +1,202 @@
|
|||
Global:
|
||||
use_gpu: true
|
||||
epoch_num: 1200
|
||||
log_smooth_window: 20
|
||||
print_batch_step: 2
|
||||
save_model_dir: ./output/ch_db_mv3/
|
||||
save_epoch_step: 1200
|
||||
# evaluation is run every 5000 iterations after the 4000th iteration
|
||||
eval_batch_step: [3000, 2000]
|
||||
cal_metric_during_train: False
|
||||
pretrained_model: ./pretrain_models/MobileNetV3_large_x0_5_pretrained
|
||||
checkpoints:
|
||||
save_inference_dir:
|
||||
use_visualdl: False
|
||||
infer_img: doc/imgs_en/img_10.jpg
|
||||
save_res_path: ./output/det_db/predicts_db.txt
|
||||
|
||||
Architecture:
|
||||
name: DistillationModel
|
||||
algorithm: Distillation
|
||||
Models:
|
||||
Student:
|
||||
pretrained: ./pretrain_models/MobileNetV3_large_x0_5_pretrained
|
||||
freeze_params: false
|
||||
return_all_feats: false
|
||||
model_type: det
|
||||
algorithm: DB
|
||||
Backbone:
|
||||
name: MobileNetV3
|
||||
scale: 0.5
|
||||
model_name: large
|
||||
disable_se: True
|
||||
Neck:
|
||||
name: DBFPN
|
||||
out_channels: 96
|
||||
Head:
|
||||
name: DBHead
|
||||
k: 50
|
||||
Student2:
|
||||
pretrained: ./pretrain_models/MobileNetV3_large_x0_5_pretrained
|
||||
freeze_params: false
|
||||
return_all_feats: false
|
||||
model_type: det
|
||||
algorithm: DB
|
||||
Transform:
|
||||
Backbone:
|
||||
name: MobileNetV3
|
||||
scale: 0.5
|
||||
model_name: large
|
||||
disable_se: True
|
||||
Neck:
|
||||
name: DBFPN
|
||||
out_channels: 96
|
||||
Head:
|
||||
name: DBHead
|
||||
k: 50
|
||||
Teacher:
|
||||
pretrained: ./pretrain_models/ch_ppocr_server_v2.0_det_train/best_accuracy
|
||||
freeze_params: true
|
||||
return_all_feats: false
|
||||
model_type: det
|
||||
algorithm: DB
|
||||
Transform:
|
||||
Backbone:
|
||||
name: ResNet
|
||||
layers: 18
|
||||
Neck:
|
||||
name: DBFPN
|
||||
out_channels: 256
|
||||
Head:
|
||||
name: DBHead
|
||||
k: 50
|
||||
|
||||
Loss:
|
||||
name: CombinedLoss
|
||||
loss_config_list:
|
||||
- DistillationDilaDBLoss:
|
||||
weight: 1.0
|
||||
model_name_pairs:
|
||||
- ["Student", "Teacher"]
|
||||
- ["Student2", "Teacher"]
|
||||
key: maps
|
||||
balance_loss: true
|
||||
main_loss_type: DiceLoss
|
||||
alpha: 5
|
||||
beta: 10
|
||||
ohem_ratio: 3
|
||||
- DistillationDMLLoss:
|
||||
model_name_pairs:
|
||||
- ["Student", "Student2"]
|
||||
maps_name: "thrink_maps"
|
||||
weight: 1.0
|
||||
# act: None
|
||||
model_name_pairs: ["Student", "Student2"]
|
||||
key: maps
|
||||
- DistillationDBLoss:
|
||||
weight: 1.0
|
||||
model_name_list: ["Student", "Student2"]
|
||||
# key: maps
|
||||
# name: DBLoss
|
||||
balance_loss: true
|
||||
main_loss_type: DiceLoss
|
||||
alpha: 5
|
||||
beta: 10
|
||||
ohem_ratio: 3
|
||||
|
||||
|
||||
Optimizer:
|
||||
name: Adam
|
||||
beta1: 0.9
|
||||
beta2: 0.999
|
||||
lr:
|
||||
name: Cosine
|
||||
learning_rate: 0.001
|
||||
warmup_epoch: 2
|
||||
regularizer:
|
||||
name: 'L2'
|
||||
factor: 0
|
||||
|
||||
PostProcess:
|
||||
name: DistillationDBPostProcess
|
||||
model_name: ["Student", "Student2", "Teacher"]
|
||||
# key: maps
|
||||
thresh: 0.3
|
||||
box_thresh: 0.6
|
||||
max_candidates: 1000
|
||||
unclip_ratio: 1.5
|
||||
|
||||
Metric:
|
||||
name: DistillationMetric
|
||||
base_metric_name: DetMetric
|
||||
main_indicator: hmean
|
||||
key: "Student"
|
||||
|
||||
Train:
|
||||
dataset:
|
||||
name: SimpleDataSet
|
||||
data_dir: ./train_data/icdar2015/text_localization/
|
||||
label_file_list:
|
||||
- ./train_data/icdar2015/text_localization/train_icdar2015_label.txt
|
||||
ratio_list: [1.0]
|
||||
transforms:
|
||||
- DecodeImage: # load image
|
||||
img_mode: BGR
|
||||
channel_first: False
|
||||
- DetLabelEncode: # Class handling label
|
||||
- IaaAugment:
|
||||
augmenter_args:
|
||||
- { 'type': Fliplr, 'args': { 'p': 0.5 } }
|
||||
- { 'type': Affine, 'args': { 'rotate': [-10, 10] } }
|
||||
- { 'type': Resize, 'args': { 'size': [0.5, 3] } }
|
||||
- EastRandomCropData:
|
||||
size: [960, 960]
|
||||
max_tries: 50
|
||||
keep_ratio: true
|
||||
- MakeBorderMap:
|
||||
shrink_ratio: 0.4
|
||||
thresh_min: 0.3
|
||||
thresh_max: 0.7
|
||||
- MakeShrinkMap:
|
||||
shrink_ratio: 0.4
|
||||
min_text_size: 8
|
||||
- NormalizeImage:
|
||||
scale: 1./255.
|
||||
mean: [0.485, 0.456, 0.406]
|
||||
std: [0.229, 0.224, 0.225]
|
||||
order: 'hwc'
|
||||
- ToCHWImage:
|
||||
- KeepKeys:
|
||||
keep_keys: ['image', 'threshold_map', 'threshold_mask', 'shrink_map', 'shrink_mask'] # the order of the dataloader list
|
||||
loader:
|
||||
shuffle: True
|
||||
drop_last: False
|
||||
batch_size_per_card: 8
|
||||
num_workers: 4
|
||||
|
||||
Eval:
|
||||
dataset:
|
||||
name: SimpleDataSet
|
||||
data_dir: ./train_data/icdar2015/text_localization/
|
||||
label_file_list:
|
||||
- ./train_data/icdar2015/text_localization/test_icdar2015_label.txt
|
||||
transforms:
|
||||
- DecodeImage: # load image
|
||||
img_mode: BGR
|
||||
channel_first: False
|
||||
- DetLabelEncode: # Class handling label
|
||||
- DetResizeForTest:
|
||||
# image_shape: [736, 1280]
|
||||
- NormalizeImage:
|
||||
scale: 1./255.
|
||||
mean: [0.485, 0.456, 0.406]
|
||||
std: [0.229, 0.224, 0.225]
|
||||
order: 'hwc'
|
||||
- ToCHWImage:
|
||||
- KeepKeys:
|
||||
keep_keys: ['image', 'shape', 'polys', 'ignore_tags']
|
||||
loader:
|
||||
shuffle: False
|
||||
drop_last: False
|
||||
batch_size_per_card: 1 # must be 1
|
||||
num_workers: 2
|
|
@ -0,0 +1,174 @@
|
|||
Global:
|
||||
use_gpu: true
|
||||
epoch_num: 1200
|
||||
log_smooth_window: 20
|
||||
print_batch_step: 2
|
||||
save_model_dir: ./output/ch_db_mv3/
|
||||
save_epoch_step: 1200
|
||||
# evaluation is run every 5000 iterations after the 4000th iteration
|
||||
eval_batch_step: [3000, 2000]
|
||||
cal_metric_during_train: False
|
||||
pretrained_model: ./pretrain_models/MobileNetV3_large_x0_5_pretrained
|
||||
checkpoints:
|
||||
save_inference_dir:
|
||||
use_visualdl: False
|
||||
infer_img: doc/imgs_en/img_10.jpg
|
||||
save_res_path: ./output/det_db/predicts_db.txt
|
||||
|
||||
Architecture:
|
||||
name: DistillationModel
|
||||
algorithm: Distillation
|
||||
Models:
|
||||
Student:
|
||||
pretrained: ./pretrain_models/MobileNetV3_large_x0_5_pretrained
|
||||
freeze_params: false
|
||||
return_all_feats: false
|
||||
model_type: det
|
||||
algorithm: DB
|
||||
Backbone:
|
||||
name: MobileNetV3
|
||||
scale: 0.5
|
||||
model_name: large
|
||||
disable_se: True
|
||||
Neck:
|
||||
name: DBFPN
|
||||
out_channels: 96
|
||||
Head:
|
||||
name: DBHead
|
||||
k: 50
|
||||
Teacher:
|
||||
pretrained: ./pretrain_models/ch_ppocr_server_v2.0_det_train/best_accuracy
|
||||
freeze_params: true
|
||||
return_all_feats: false
|
||||
model_type: det
|
||||
algorithm: DB
|
||||
Transform:
|
||||
Backbone:
|
||||
name: ResNet
|
||||
layers: 18
|
||||
Neck:
|
||||
name: DBFPN
|
||||
out_channels: 256
|
||||
Head:
|
||||
name: DBHead
|
||||
k: 50
|
||||
|
||||
Loss:
|
||||
name: CombinedLoss
|
||||
loss_config_list:
|
||||
- DistillationDilaDBLoss:
|
||||
weight: 1.0
|
||||
model_name_pairs:
|
||||
- ["Student", "Teacher"]
|
||||
key: maps
|
||||
balance_loss: true
|
||||
main_loss_type: DiceLoss
|
||||
alpha: 5
|
||||
beta: 10
|
||||
ohem_ratio: 3
|
||||
- DistillationDBLoss:
|
||||
weight: 1.0
|
||||
model_name_list: ["Student", "Teacher"]
|
||||
# key: maps
|
||||
name: DBLoss
|
||||
balance_loss: true
|
||||
main_loss_type: DiceLoss
|
||||
alpha: 5
|
||||
beta: 10
|
||||
ohem_ratio: 3
|
||||
|
||||
Optimizer:
|
||||
name: Adam
|
||||
beta1: 0.9
|
||||
beta2: 0.999
|
||||
lr:
|
||||
name: Cosine
|
||||
learning_rate: 0.001
|
||||
warmup_epoch: 2
|
||||
regularizer:
|
||||
name: 'L2'
|
||||
factor: 0
|
||||
|
||||
PostProcess:
|
||||
name: DistillationDBPostProcess
|
||||
model_name: ["Student", "Student2"]
|
||||
key: head_out
|
||||
thresh: 0.3
|
||||
box_thresh: 0.6
|
||||
max_candidates: 1000
|
||||
unclip_ratio: 1.5
|
||||
|
||||
Metric:
|
||||
name: DistillationMetric
|
||||
base_metric_name: DetMetric
|
||||
main_indicator: hmean
|
||||
key: "Student"
|
||||
|
||||
Train:
|
||||
dataset:
|
||||
name: SimpleDataSet
|
||||
data_dir: ./train_data/icdar2015/text_localization/
|
||||
label_file_list:
|
||||
- ./train_data/icdar2015/text_localization/train_icdar2015_label.txt
|
||||
ratio_list: [1.0]
|
||||
transforms:
|
||||
- DecodeImage: # load image
|
||||
img_mode: BGR
|
||||
channel_first: False
|
||||
- DetLabelEncode: # Class handling label
|
||||
- IaaAugment:
|
||||
augmenter_args:
|
||||
- { 'type': Fliplr, 'args': { 'p': 0.5 } }
|
||||
- { 'type': Affine, 'args': { 'rotate': [-10, 10] } }
|
||||
- { 'type': Resize, 'args': { 'size': [0.5, 3] } }
|
||||
- EastRandomCropData:
|
||||
size: [960, 960]
|
||||
max_tries: 50
|
||||
keep_ratio: true
|
||||
- MakeBorderMap:
|
||||
shrink_ratio: 0.4
|
||||
thresh_min: 0.3
|
||||
thresh_max: 0.7
|
||||
- MakeShrinkMap:
|
||||
shrink_ratio: 0.4
|
||||
min_text_size: 8
|
||||
- NormalizeImage:
|
||||
scale: 1./255.
|
||||
mean: [0.485, 0.456, 0.406]
|
||||
std: [0.229, 0.224, 0.225]
|
||||
order: 'hwc'
|
||||
- ToCHWImage:
|
||||
- KeepKeys:
|
||||
keep_keys: ['image', 'threshold_map', 'threshold_mask', 'shrink_map', 'shrink_mask'] # the order of the dataloader list
|
||||
loader:
|
||||
shuffle: True
|
||||
drop_last: False
|
||||
batch_size_per_card: 8
|
||||
num_workers: 4
|
||||
|
||||
Eval:
|
||||
dataset:
|
||||
name: SimpleDataSet
|
||||
data_dir: ./train_data/icdar2015/text_localization/
|
||||
label_file_list:
|
||||
- ./train_data/icdar2015/text_localization/test_icdar2015_label.txt
|
||||
transforms:
|
||||
- DecodeImage: # load image
|
||||
img_mode: BGR
|
||||
channel_first: False
|
||||
- DetLabelEncode: # Class handling label
|
||||
- DetResizeForTest:
|
||||
# image_shape: [736, 1280]
|
||||
- NormalizeImage:
|
||||
scale: 1./255.
|
||||
mean: [0.485, 0.456, 0.406]
|
||||
std: [0.229, 0.224, 0.225]
|
||||
order: 'hwc'
|
||||
- ToCHWImage:
|
||||
- KeepKeys:
|
||||
keep_keys: ['image', 'shape', 'polys', 'ignore_tags']
|
||||
loader:
|
||||
shuffle: False
|
||||
drop_last: False
|
||||
batch_size_per_card: 1 # must be 1
|
||||
num_workers: 2
|
|
@ -0,0 +1,176 @@
|
|||
Global:
|
||||
use_gpu: true
|
||||
epoch_num: 1200
|
||||
log_smooth_window: 20
|
||||
print_batch_step: 2
|
||||
save_model_dir: ./output/ch_db_mv3/
|
||||
save_epoch_step: 1200
|
||||
# evaluation is run every 5000 iterations after the 4000th iteration
|
||||
eval_batch_step: [3000, 2000]
|
||||
cal_metric_during_train: False
|
||||
pretrained_model: ./pretrain_models/MobileNetV3_large_x0_5_pretrained
|
||||
checkpoints:
|
||||
save_inference_dir:
|
||||
use_visualdl: False
|
||||
infer_img: doc/imgs_en/img_10.jpg
|
||||
save_res_path: ./output/det_db/predicts_db.txt
|
||||
|
||||
Architecture:
|
||||
name: DistillationModel
|
||||
algorithm: Distillation
|
||||
Models:
|
||||
Student:
|
||||
pretrained: ./pretrain_models/MobileNetV3_large_x0_5_pretrained
|
||||
freeze_params: false
|
||||
return_all_feats: false
|
||||
model_type: det
|
||||
algorithm: DB
|
||||
Backbone:
|
||||
name: MobileNetV3
|
||||
scale: 0.5
|
||||
model_name: large
|
||||
disable_se: True
|
||||
Neck:
|
||||
name: DBFPN
|
||||
out_channels: 96
|
||||
Head:
|
||||
name: DBHead
|
||||
k: 50
|
||||
Student2:
|
||||
pretrained: ./pretrain_models/MobileNetV3_large_x0_5_pretrained
|
||||
freeze_params: false
|
||||
return_all_feats: false
|
||||
model_type: det
|
||||
algorithm: DB
|
||||
Transform:
|
||||
Backbone:
|
||||
name: MobileNetV3
|
||||
scale: 0.5
|
||||
model_name: large
|
||||
disable_se: True
|
||||
Neck:
|
||||
name: DBFPN
|
||||
out_channels: 96
|
||||
Head:
|
||||
name: DBHead
|
||||
k: 50
|
||||
|
||||
|
||||
Loss:
|
||||
name: CombinedLoss
|
||||
loss_config_list:
|
||||
- DistillationDMLLoss:
|
||||
model_name_pairs:
|
||||
- ["Student", "Student2"]
|
||||
maps_name: "thrink_maps"
|
||||
weight: 1.0
|
||||
act: "softmax"
|
||||
model_name_pairs: ["Student", "Student2"]
|
||||
key: maps
|
||||
- DistillationDBLoss:
|
||||
weight: 1.0
|
||||
model_name_list: ["Student", "Student2"]
|
||||
# key: maps
|
||||
name: DBLoss
|
||||
balance_loss: true
|
||||
main_loss_type: DiceLoss
|
||||
alpha: 5
|
||||
beta: 10
|
||||
ohem_ratio: 3
|
||||
|
||||
|
||||
Optimizer:
|
||||
name: Adam
|
||||
beta1: 0.9
|
||||
beta2: 0.999
|
||||
lr:
|
||||
name: Cosine
|
||||
learning_rate: 0.001
|
||||
warmup_epoch: 2
|
||||
regularizer:
|
||||
name: 'L2'
|
||||
factor: 0
|
||||
|
||||
PostProcess:
|
||||
name: DistillationDBPostProcess
|
||||
model_name: ["Student", "Student2"]
|
||||
key: head_out
|
||||
thresh: 0.3
|
||||
box_thresh: 0.6
|
||||
max_candidates: 1000
|
||||
unclip_ratio: 1.5
|
||||
|
||||
Metric:
|
||||
name: DistillationMetric
|
||||
base_metric_name: DetMetric
|
||||
main_indicator: hmean
|
||||
key: "Student"
|
||||
|
||||
Train:
|
||||
dataset:
|
||||
name: SimpleDataSet
|
||||
data_dir: ./train_data/icdar2015/text_localization/
|
||||
label_file_list:
|
||||
- ./train_data/icdar2015/text_localization/train_icdar2015_label.txt
|
||||
ratio_list: [1.0]
|
||||
transforms:
|
||||
- DecodeImage: # load image
|
||||
img_mode: BGR
|
||||
channel_first: False
|
||||
- DetLabelEncode: # Class handling label
|
||||
- IaaAugment:
|
||||
augmenter_args:
|
||||
- { 'type': Fliplr, 'args': { 'p': 0.5 } }
|
||||
- { 'type': Affine, 'args': { 'rotate': [-10, 10] } }
|
||||
- { 'type': Resize, 'args': { 'size': [0.5, 3] } }
|
||||
- EastRandomCropData:
|
||||
size: [960, 960]
|
||||
max_tries: 50
|
||||
keep_ratio: true
|
||||
- MakeBorderMap:
|
||||
shrink_ratio: 0.4
|
||||
thresh_min: 0.3
|
||||
thresh_max: 0.7
|
||||
- MakeShrinkMap:
|
||||
shrink_ratio: 0.4
|
||||
min_text_size: 8
|
||||
- NormalizeImage:
|
||||
scale: 1./255.
|
||||
mean: [0.485, 0.456, 0.406]
|
||||
std: [0.229, 0.224, 0.225]
|
||||
order: 'hwc'
|
||||
- ToCHWImage:
|
||||
- KeepKeys:
|
||||
keep_keys: ['image', 'threshold_map', 'threshold_mask', 'shrink_map', 'shrink_mask'] # the order of the dataloader list
|
||||
loader:
|
||||
shuffle: True
|
||||
drop_last: False
|
||||
batch_size_per_card: 8
|
||||
num_workers: 4
|
||||
|
||||
Eval:
|
||||
dataset:
|
||||
name: SimpleDataSet
|
||||
data_dir: ./train_data/icdar2015/text_localization/
|
||||
label_file_list:
|
||||
- ./train_data/icdar2015/text_localization/test_icdar2015_label.txt
|
||||
transforms:
|
||||
- DecodeImage: # load image
|
||||
img_mode: BGR
|
||||
channel_first: False
|
||||
- DetLabelEncode: # Class handling label
|
||||
- DetResizeForTest:
|
||||
# image_shape: [736, 1280]
|
||||
- NormalizeImage:
|
||||
scale: 1./255.
|
||||
mean: [0.485, 0.456, 0.406]
|
||||
std: [0.229, 0.224, 0.225]
|
||||
order: 'hwc'
|
||||
- ToCHWImage:
|
||||
- KeepKeys:
|
||||
keep_keys: ['image', 'shape', 'polys', 'ignore_tags']
|
||||
loader:
|
||||
shuffle: False
|
||||
drop_last: False
|
||||
batch_size_per_card: 1 # must be 1
|
||||
num_workers: 2
|
|
@ -13,7 +13,6 @@ SET(TENSORRT_DIR "" CACHE PATH "Compile demo with TensorRT")
|
|||
|
||||
set(DEMO_NAME "ocr_system")
|
||||
|
||||
|
||||
macro(safe_set_static_flag)
|
||||
foreach(flag_var
|
||||
CMAKE_CXX_FLAGS CMAKE_CXX_FLAGS_DEBUG CMAKE_CXX_FLAGS_RELEASE
|
||||
|
|
|
@ -668,7 +668,7 @@ void DisposeOutPts(OutPt *&pp) {
|
|||
//------------------------------------------------------------------------------
|
||||
|
||||
inline void InitEdge(TEdge *e, TEdge *eNext, TEdge *ePrev, const IntPoint &Pt) {
|
||||
std::memset(e, 0, sizeof(TEdge));
|
||||
std::memset(e, int(0), sizeof(TEdge));
|
||||
e->Next = eNext;
|
||||
e->Prev = ePrev;
|
||||
e->Curr = Pt;
|
||||
|
@ -1895,17 +1895,17 @@ void Clipper::InsertLocalMinimaIntoAEL(const cInt botY) {
|
|||
TEdge *rb = lm->RightBound;
|
||||
|
||||
OutPt *Op1 = 0;
|
||||
if (!lb) {
|
||||
if (!lb || !rb) {
|
||||
// nb: don't insert LB into either AEL or SEL
|
||||
InsertEdgeIntoAEL(rb, 0);
|
||||
SetWindingCount(*rb);
|
||||
if (IsContributing(*rb))
|
||||
Op1 = AddOutPt(rb, rb->Bot);
|
||||
} else if (!rb) {
|
||||
InsertEdgeIntoAEL(lb, 0);
|
||||
SetWindingCount(*lb);
|
||||
if (IsContributing(*lb))
|
||||
Op1 = AddOutPt(lb, lb->Bot);
|
||||
//} else if (!rb) {
|
||||
// InsertEdgeIntoAEL(lb, 0);
|
||||
// SetWindingCount(*lb);
|
||||
// if (IsContributing(*lb))
|
||||
// Op1 = AddOutPt(lb, lb->Bot);
|
||||
InsertScanbeam(lb->Top.Y);
|
||||
} else {
|
||||
InsertEdgeIntoAEL(lb, 0);
|
||||
|
@ -2547,13 +2547,13 @@ void Clipper::ProcessHorizontal(TEdge *horzEdge) {
|
|||
if (dir == dLeftToRight) {
|
||||
maxIt = m_Maxima.begin();
|
||||
while (maxIt != m_Maxima.end() && *maxIt <= horzEdge->Bot.X)
|
||||
maxIt++;
|
||||
++maxIt;
|
||||
if (maxIt != m_Maxima.end() && *maxIt >= eLastHorz->Top.X)
|
||||
maxIt = m_Maxima.end();
|
||||
} else {
|
||||
maxRit = m_Maxima.rbegin();
|
||||
while (maxRit != m_Maxima.rend() && *maxRit > horzEdge->Bot.X)
|
||||
maxRit++;
|
||||
++maxRit;
|
||||
if (maxRit != m_Maxima.rend() && *maxRit <= eLastHorz->Top.X)
|
||||
maxRit = m_Maxima.rend();
|
||||
}
|
||||
|
@ -2576,13 +2576,13 @@ void Clipper::ProcessHorizontal(TEdge *horzEdge) {
|
|||
while (maxIt != m_Maxima.end() && *maxIt < e->Curr.X) {
|
||||
if (horzEdge->OutIdx >= 0 && !IsOpen)
|
||||
AddOutPt(horzEdge, IntPoint(*maxIt, horzEdge->Bot.Y));
|
||||
maxIt++;
|
||||
++maxIt;
|
||||
}
|
||||
} else {
|
||||
while (maxRit != m_Maxima.rend() && *maxRit > e->Curr.X) {
|
||||
if (horzEdge->OutIdx >= 0 && !IsOpen)
|
||||
AddOutPt(horzEdge, IntPoint(*maxRit, horzEdge->Bot.Y));
|
||||
maxRit++;
|
||||
++maxRit;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
|
|
@ -21,10 +21,10 @@ std::vector<std::string> OCRConfig::split(const std::string &str,
|
|||
std::vector<std::string> res;
|
||||
if ("" == str)
|
||||
return res;
|
||||
char *strs = new char[str.length() + 1];
|
||||
char strs[str.length() + 1];
|
||||
std::strcpy(strs, str.c_str());
|
||||
|
||||
char *d = new char[delim.length() + 1];
|
||||
char d[delim.length() + 1];
|
||||
std::strcpy(d, delim.c_str());
|
||||
|
||||
char *p = std::strtok(strs, d);
|
||||
|
|
|
@ -147,12 +147,12 @@ python3 tools/infer/predict_det.py --image_dir="./doc/imgs/00018069.jpg" --det_m
|
|||
|
||||
如果输入图片的分辨率比较大,而且想使用更大的分辨率预测,可以设置det_limit_side_len 为想要的值,比如1216:
|
||||
```
|
||||
python3 tools/infer/predict_det.py --image_dir="./doc/imgs/2.jpg" --det_model_dir="./inference/det_db/" --det_limit_type=max --det_limit_side_len=1216
|
||||
python3 tools/infer/predict_det.py --image_dir="./doc/imgs/1.jpg" --det_model_dir="./inference/det_db/" --det_limit_type=max --det_limit_side_len=1216
|
||||
```
|
||||
|
||||
如果想使用CPU进行预测,执行命令如下
|
||||
```
|
||||
python3 tools/infer/predict_det.py --image_dir="./doc/imgs/2.jpg" --det_model_dir="./inference/det_db/" --use_gpu=False
|
||||
python3 tools/infer/predict_det.py --image_dir="./doc/imgs/1.jpg" --det_model_dir="./inference/det_db/" --use_gpu=False
|
||||
```
|
||||
|
||||
<a name="DB文本检测模型推理"></a>
|
||||
|
|
|
@ -154,12 +154,12 @@ Set as `limit_type='min', det_limit_side_len=960`, it means that the shortest si
|
|||
|
||||
If the resolution of the input picture is relatively large and you want to use a larger resolution prediction, you can set det_limit_side_len to the desired value, such as 1216:
|
||||
```
|
||||
python3 tools/infer/predict_det.py --image_dir="./doc/imgs/22.jpg" --det_model_dir="./inference/det_db/" --det_limit_type=max --det_limit_side_len=1216
|
||||
python3 tools/infer/predict_det.py --image_dir="./doc/imgs/1.jpg" --det_model_dir="./inference/det_db/" --det_limit_type=max --det_limit_side_len=1216
|
||||
```
|
||||
|
||||
If you want to use the CPU for prediction, execute the command as follows
|
||||
```
|
||||
python3 tools/infer/predict_det.py --image_dir="./doc/imgs/22.jpg" --det_model_dir="./inference/det_db/" --use_gpu=False
|
||||
python3 tools/infer/predict_det.py --image_dir="./doc/imgs/1.jpg" --det_model_dir="./inference/det_db/" --use_gpu=False
|
||||
```
|
||||
|
||||
<a name="DB_DETECTION"></a>
|
||||
|
|
|
@ -15,8 +15,6 @@
|
|||
- 2020.6.8 Add [datasets](./datasets_en.md) and keep updating
|
||||
- 2020.6.5 Support exporting `attention` model to `inference_model`
|
||||
- 2020.6.5 Support separate prediction and recognition, output result score
|
||||
- 2020.6.5 Support exporting `attention` model to `inference_model`
|
||||
- 2020.6.5 Support separate prediction and recognition, output result score
|
||||
- 2020.5.30 Provide Lightweight Chinese OCR online experience
|
||||
- 2020.5.30 Model prediction and training support on Windows system
|
||||
- 2020.5.30 Open source general Chinese OCR model
|
||||
|
|
BIN
doc/joinus.PNG
BIN
doc/joinus.PNG
Binary file not shown.
Before Width: | Height: | Size: 188 KiB After Width: | Height: | Size: 189 KiB |
|
@ -46,6 +46,7 @@ class SimpleDataSet(Dataset):
|
|||
self.seed = seed
|
||||
logger.info("Initialize indexs of datasets:%s" % label_file_list)
|
||||
self.data_lines = self.get_image_info_list(label_file_list, ratio_list)
|
||||
self.check_data()
|
||||
self.data_idx_order_list = list(range(len(self.data_lines)))
|
||||
if self.mode == "train" and self.do_shuffle:
|
||||
self.shuffle_data_random()
|
||||
|
@ -102,16 +103,8 @@ class SimpleDataSet(Dataset):
|
|||
|
||||
def __getitem__(self, idx):
|
||||
file_idx = self.data_idx_order_list[idx]
|
||||
data_line = self.data_lines[file_idx]
|
||||
data = self.data_lines[file_idx]
|
||||
try:
|
||||
data_line = data_line.decode('utf-8')
|
||||
substr = data_line.strip("\n").strip("\r").split(self.delimiter)
|
||||
file_name = substr[0]
|
||||
label = substr[1]
|
||||
img_path = os.path.join(self.data_dir, file_name)
|
||||
data = {'img_path': img_path, 'label': label}
|
||||
if not os.path.exists(img_path):
|
||||
raise Exception("{} does not exist!".format(img_path))
|
||||
with open(data['img_path'], 'rb') as f:
|
||||
img = f.read()
|
||||
data['image'] = img
|
||||
|
@ -120,8 +113,8 @@ class SimpleDataSet(Dataset):
|
|||
except:
|
||||
error_meg = traceback.format_exc()
|
||||
self.logger.error(
|
||||
"When parsing line {}, error happened with msg: {}".format(
|
||||
data_line, error_meg))
|
||||
"When parsing file {} and label {}, error happened with msg: {}".format(
|
||||
data['img_path'],data['label'], error_meg))
|
||||
outs = None
|
||||
if outs is None:
|
||||
# during evaluation, we should fix the idx to get same results for many times of evaluation.
|
||||
|
@ -132,3 +125,17 @@ class SimpleDataSet(Dataset):
|
|||
|
||||
def __len__(self):
|
||||
return len(self.data_idx_order_list)
|
||||
|
||||
def check_data(self):
|
||||
new_data_lines = []
|
||||
for data_line in self.data_lines:
|
||||
data_line = data_line.decode('utf-8')
|
||||
substr = data_line.strip("\n").strip("\r").split(self.delimiter)
|
||||
file_name = substr[0]
|
||||
label = substr[1]
|
||||
img_path = os.path.join(self.data_dir, file_name)
|
||||
if os.path.exists(img_path):
|
||||
new_data_lines.append({'img_path': img_path, 'label': label})
|
||||
else:
|
||||
self.logger.info("{} does not exist!".format(img_path))
|
||||
self.data_lines = new_data_lines
|
|
@ -54,6 +54,27 @@ class CELoss(nn.Layer):
|
|||
return loss
|
||||
|
||||
|
||||
class KLJSLoss(object):
|
||||
def __init__(self, mode='kl'):
|
||||
assert mode in ['kl', 'js', 'KL', 'JS'], "mode can only be one of ['kl', 'js', 'KL', 'JS']"
|
||||
self.mode = mode
|
||||
|
||||
def __call__(self, p1, p2, reduction="mean"):
|
||||
|
||||
loss = paddle.multiply(p2, paddle.log( (p2+1e-5)/(p1+1e-5) + 1e-5))
|
||||
|
||||
if self.mode.lower() == "js":
|
||||
loss += paddle.multiply(p1, paddle.log((p1+1e-5)/(p2+1e-5) + 1e-5))
|
||||
loss *= 0.5
|
||||
if reduction == "mean":
|
||||
loss = paddle.mean(loss, axis=[1,2])
|
||||
elif reduction=="none" or reduction is None:
|
||||
return loss
|
||||
else:
|
||||
loss = paddle.sum(loss, axis=[1,2])
|
||||
|
||||
return loss
|
||||
|
||||
class DMLLoss(nn.Layer):
|
||||
"""
|
||||
DMLLoss
|
||||
|
@ -70,16 +91,20 @@ class DMLLoss(nn.Layer):
|
|||
else:
|
||||
self.act = None
|
||||
|
||||
self.jskl_loss = KLJSLoss(mode="js")
|
||||
|
||||
def forward(self, out1, out2):
|
||||
if self.act is not None:
|
||||
out1 = self.act(out1)
|
||||
out2 = self.act(out2)
|
||||
|
||||
log_out1 = paddle.log(out1)
|
||||
log_out2 = paddle.log(out2)
|
||||
loss = (F.kl_div(
|
||||
log_out1, out2, reduction='batchmean') + F.kl_div(
|
||||
log_out2, out1, reduction='batchmean')) / 2.0
|
||||
if len(out1.shape) < 2:
|
||||
log_out1 = paddle.log(out1)
|
||||
log_out2 = paddle.log(out2)
|
||||
loss = (F.kl_div(
|
||||
log_out1, out2, reduction='batchmean') + F.kl_div(
|
||||
log_out2, out1, reduction='batchmean')) / 2.0
|
||||
else:
|
||||
loss = self.jskl_loss(out1, out2)
|
||||
return loss
|
||||
|
||||
|
||||
|
|
|
@ -17,7 +17,7 @@ import paddle.nn as nn
|
|||
|
||||
from .distillation_loss import DistillationCTCLoss
|
||||
from .distillation_loss import DistillationDMLLoss
|
||||
from .distillation_loss import DistillationDistanceLoss
|
||||
from .distillation_loss import DistillationDistanceLoss, DistillationDBLoss, DistillationDilaDBLoss
|
||||
|
||||
|
||||
class CombinedLoss(nn.Layer):
|
||||
|
@ -44,15 +44,16 @@ class CombinedLoss(nn.Layer):
|
|||
|
||||
def forward(self, input, batch, **kargs):
|
||||
loss_dict = {}
|
||||
loss_all = 0.
|
||||
for idx, loss_func in enumerate(self.loss_func):
|
||||
loss = loss_func(input, batch, **kargs)
|
||||
if isinstance(loss, paddle.Tensor):
|
||||
loss = {"loss_{}_{}".format(str(loss), idx): loss}
|
||||
weight = self.loss_weight[idx]
|
||||
loss = {
|
||||
"{}_{}".format(key, idx): loss[key] * weight
|
||||
for key in loss
|
||||
}
|
||||
loss_dict.update(loss)
|
||||
loss_dict["loss"] = paddle.add_n(list(loss_dict.values()))
|
||||
for key in loss.keys():
|
||||
if key == "loss":
|
||||
loss_all += loss[key] * weight
|
||||
else:
|
||||
loss_dict["{}_{}".format(key, idx)] = loss[key]
|
||||
loss_dict["loss"] = loss_all
|
||||
return loss_dict
|
||||
|
|
|
@ -14,23 +14,76 @@
|
|||
|
||||
import paddle
|
||||
import paddle.nn as nn
|
||||
import numpy as np
|
||||
import cv2
|
||||
|
||||
from .rec_ctc_loss import CTCLoss
|
||||
from .basic_loss import DMLLoss
|
||||
from .basic_loss import DistanceLoss
|
||||
from .det_db_loss import DBLoss
|
||||
from .det_basic_loss import BalanceLoss, MaskL1Loss, DiceLoss
|
||||
|
||||
|
||||
def _sum_loss(loss_dict):
|
||||
if "loss" in loss_dict.keys():
|
||||
return loss_dict
|
||||
else:
|
||||
loss_dict["loss"] = 0.
|
||||
for k, value in loss_dict.items():
|
||||
if k == "loss":
|
||||
continue
|
||||
else:
|
||||
loss_dict["loss"] += value
|
||||
return loss_dict
|
||||
|
||||
|
||||
class DistillationDMLLoss(DMLLoss):
|
||||
"""
|
||||
"""
|
||||
|
||||
def __init__(self, model_name_pairs=[], act=None, key=None,
|
||||
name="loss_dml"):
|
||||
def __init__(self,
|
||||
model_name_pairs=[],
|
||||
act=None,
|
||||
key=None,
|
||||
maps_name=None,
|
||||
name="dml"):
|
||||
super().__init__(act=act)
|
||||
assert isinstance(model_name_pairs, list)
|
||||
self.key = key
|
||||
self.model_name_pairs = model_name_pairs
|
||||
self.model_name_pairs = self._check_model_name_pairs(model_name_pairs)
|
||||
self.name = name
|
||||
self.maps_name = self._check_maps_name(maps_name)
|
||||
|
||||
def _check_model_name_pairs(self, model_name_pairs):
|
||||
if not isinstance(model_name_pairs, list):
|
||||
return []
|
||||
elif isinstance(model_name_pairs[0], list) and isinstance(model_name_pairs[0][0], str):
|
||||
return model_name_pairs
|
||||
else:
|
||||
return [model_name_pairs]
|
||||
|
||||
def _check_maps_name(self, maps_name):
|
||||
if maps_name is None:
|
||||
return None
|
||||
elif type(maps_name) == str:
|
||||
return [maps_name]
|
||||
elif type(maps_name) == list:
|
||||
return [maps_name]
|
||||
else:
|
||||
return None
|
||||
|
||||
def _slice_out(self, outs):
|
||||
new_outs = {}
|
||||
for k in self.maps_name:
|
||||
if k == "thrink_maps":
|
||||
new_outs[k] = outs[:, 0, :, :]
|
||||
elif k == "threshold_maps":
|
||||
new_outs[k] = outs[:, 1, :, :]
|
||||
elif k == "binary_maps":
|
||||
new_outs[k] = outs[:, 2, :, :]
|
||||
else:
|
||||
continue
|
||||
return new_outs
|
||||
|
||||
def forward(self, predicts, batch):
|
||||
loss_dict = dict()
|
||||
|
@ -40,13 +93,30 @@ class DistillationDMLLoss(DMLLoss):
|
|||
if self.key is not None:
|
||||
out1 = out1[self.key]
|
||||
out2 = out2[self.key]
|
||||
loss = super().forward(out1, out2)
|
||||
if isinstance(loss, dict):
|
||||
for key in loss:
|
||||
loss_dict["{}_{}_{}_{}".format(key, pair[0], pair[1],
|
||||
idx)] = loss[key]
|
||||
|
||||
if self.maps_name is None:
|
||||
loss = super().forward(out1, out2)
|
||||
if isinstance(loss, dict):
|
||||
for key in loss:
|
||||
loss_dict["{}_{}_{}_{}".format(key, pair[0], pair[1],
|
||||
idx)] = loss[key]
|
||||
else:
|
||||
loss_dict["{}_{}".format(self.name, idx)] = loss
|
||||
else:
|
||||
loss_dict["{}_{}".format(self.name, idx)] = loss
|
||||
outs1 = self._slice_out(out1)
|
||||
outs2 = self._slice_out(out2)
|
||||
for _c, k in enumerate(outs1.keys()):
|
||||
loss = super().forward(outs1[k], outs2[k])
|
||||
if isinstance(loss, dict):
|
||||
for key in loss:
|
||||
loss_dict["{}_{}_{}_{}_{}".format(key, pair[
|
||||
0], pair[1], map_name, idx)] = loss[key]
|
||||
else:
|
||||
loss_dict["{}_{}_{}".format(self.name, self.maps_name[_c],
|
||||
idx)] = loss
|
||||
|
||||
loss_dict = _sum_loss(loss_dict)
|
||||
|
||||
return loss_dict
|
||||
|
||||
|
||||
|
@ -73,6 +143,98 @@ class DistillationCTCLoss(CTCLoss):
|
|||
return loss_dict
|
||||
|
||||
|
||||
class DistillationDBLoss(DBLoss):
|
||||
def __init__(self,
|
||||
model_name_list=[],
|
||||
balance_loss=True,
|
||||
main_loss_type='DiceLoss',
|
||||
alpha=5,
|
||||
beta=10,
|
||||
ohem_ratio=3,
|
||||
eps=1e-6,
|
||||
name="db",
|
||||
**kwargs):
|
||||
super().__init__()
|
||||
self.model_name_list = model_name_list
|
||||
self.name = name
|
||||
self.key = None
|
||||
|
||||
def forward(self, predicts, batch):
|
||||
loss_dict = {}
|
||||
for idx, model_name in enumerate(self.model_name_list):
|
||||
out = predicts[model_name]
|
||||
if self.key is not None:
|
||||
out = out[self.key]
|
||||
loss = super().forward(out, batch)
|
||||
|
||||
if isinstance(loss, dict):
|
||||
for key in loss.keys():
|
||||
if key == "loss":
|
||||
continue
|
||||
name = "{}_{}_{}".format(self.name, model_name, key)
|
||||
loss_dict[name] = loss[key]
|
||||
else:
|
||||
loss_dict["{}_{}".format(self.name, model_name)] = loss
|
||||
|
||||
loss_dict = _sum_loss(loss_dict)
|
||||
return loss_dict
|
||||
|
||||
|
||||
class DistillationDilaDBLoss(DBLoss):
|
||||
def __init__(self,
|
||||
model_name_pairs=[],
|
||||
key=None,
|
||||
balance_loss=True,
|
||||
main_loss_type='DiceLoss',
|
||||
alpha=5,
|
||||
beta=10,
|
||||
ohem_ratio=3,
|
||||
eps=1e-6,
|
||||
name="dila_dbloss"):
|
||||
super().__init__()
|
||||
self.model_name_pairs = model_name_pairs
|
||||
self.name = name
|
||||
self.key = key
|
||||
|
||||
def forward(self, predicts, batch):
|
||||
loss_dict = dict()
|
||||
for idx, pair in enumerate(self.model_name_pairs):
|
||||
stu_outs = predicts[pair[0]]
|
||||
tch_outs = predicts[pair[1]]
|
||||
if self.key is not None:
|
||||
stu_preds = stu_outs[self.key]
|
||||
tch_preds = tch_outs[self.key]
|
||||
|
||||
stu_shrink_maps = stu_preds[:, 0, :, :]
|
||||
stu_binary_maps = stu_preds[:, 2, :, :]
|
||||
|
||||
# dilation to teacher prediction
|
||||
dilation_w = np.array([[1, 1], [1, 1]])
|
||||
th_shrink_maps = tch_preds[:, 0, :, :]
|
||||
th_shrink_maps = th_shrink_maps.numpy() > 0.3 # thresh = 0.3
|
||||
dilate_maps = np.zeros_like(th_shrink_maps).astype(np.float32)
|
||||
for i in range(th_shrink_maps.shape[0]):
|
||||
dilate_maps[i] = cv2.dilate(
|
||||
th_shrink_maps[i, :, :].astype(np.uint8), dilation_w)
|
||||
th_shrink_maps = paddle.to_tensor(dilate_maps)
|
||||
|
||||
label_threshold_map, label_threshold_mask, label_shrink_map, label_shrink_mask = batch[
|
||||
1:]
|
||||
|
||||
# calculate the shrink map loss
|
||||
bce_loss = self.alpha * self.bce_loss(
|
||||
stu_shrink_maps, th_shrink_maps, label_shrink_mask)
|
||||
loss_binary_maps = self.dice_loss(stu_binary_maps, th_shrink_maps,
|
||||
label_shrink_mask)
|
||||
|
||||
# k = f"{self.name}_{pair[0]}_{pair[1]}"
|
||||
k = "{}_{}_{}".format(self.name, pair[0], pair[1])
|
||||
loss_dict[k] = bce_loss + loss_binary_maps
|
||||
|
||||
loss_dict = _sum_loss(loss_dict)
|
||||
return loss_dict
|
||||
|
||||
|
||||
class DistillationDistanceLoss(DistanceLoss):
|
||||
"""
|
||||
"""
|
||||
|
|
|
@ -55,6 +55,7 @@ class DetMetric(object):
|
|||
result = self.evaluator.evaluate_image(gt_info_list, det_info_list)
|
||||
self.results.append(result)
|
||||
|
||||
|
||||
def get_metric(self):
|
||||
"""
|
||||
return metrics {
|
||||
|
|
|
@ -24,8 +24,8 @@ from .cls_metric import ClsMetric
|
|||
class DistillationMetric(object):
|
||||
def __init__(self,
|
||||
key=None,
|
||||
base_metric_name="RecMetric",
|
||||
main_indicator='acc',
|
||||
base_metric_name=None,
|
||||
main_indicator=None,
|
||||
**kwargs):
|
||||
self.main_indicator = main_indicator
|
||||
self.key = key
|
||||
|
@ -42,16 +42,13 @@ class DistillationMetric(object):
|
|||
main_indicator=self.main_indicator, **self.kwargs)
|
||||
self.metrics[key].reset()
|
||||
|
||||
def __call__(self, preds, *args, **kwargs):
|
||||
def __call__(self, preds, batch, **kwargs):
|
||||
assert isinstance(preds, dict)
|
||||
if self.metrics is None:
|
||||
self._init_metrcis(preds)
|
||||
output = dict()
|
||||
for key in preds:
|
||||
metric = self.metrics[key].__call__(preds[key], *args, **kwargs)
|
||||
for sub_key in metric:
|
||||
output["{}_{}".format(key, sub_key)] = metric[sub_key]
|
||||
return output
|
||||
self.metrics[key].__call__(preds[key], batch, **kwargs)
|
||||
|
||||
def get_metric(self):
|
||||
"""
|
||||
|
|
|
@ -79,7 +79,10 @@ class BaseModel(nn.Layer):
|
|||
x = self.neck(x)
|
||||
y["neck_out"] = x
|
||||
x = self.head(x, targets=data)
|
||||
y["head_out"] = x
|
||||
if isinstance(x, dict):
|
||||
y.update(x)
|
||||
else:
|
||||
y["head_out"] = x
|
||||
if self.return_all_feats:
|
||||
return y
|
||||
else:
|
||||
|
|
|
@ -21,7 +21,7 @@ from ppocr.modeling.backbones import build_backbone
|
|||
from ppocr.modeling.necks import build_neck
|
||||
from ppocr.modeling.heads import build_head
|
||||
from .base_model import BaseModel
|
||||
from ppocr.utils.save_load import init_model
|
||||
from ppocr.utils.save_load import init_model, load_pretrained_params
|
||||
|
||||
__all__ = ['DistillationModel']
|
||||
|
||||
|
@ -46,7 +46,7 @@ class DistillationModel(nn.Layer):
|
|||
pretrained = model_config.pop("pretrained")
|
||||
model = BaseModel(model_config)
|
||||
if pretrained is not None:
|
||||
init_model(model, path=pretrained)
|
||||
model = load_pretrained_params(model, pretrained)
|
||||
if freeze_params:
|
||||
for param in model.parameters():
|
||||
param.trainable = False
|
||||
|
|
|
@ -21,7 +21,7 @@ import copy
|
|||
|
||||
__all__ = ['build_post_process']
|
||||
|
||||
from .db_postprocess import DBPostProcess
|
||||
from .db_postprocess import DBPostProcess, DistillationDBPostProcess
|
||||
from .east_postprocess import EASTPostProcess
|
||||
from .sast_postprocess import SASTPostProcess
|
||||
from .rec_postprocess import CTCLabelDecode, AttnLabelDecode, SRNLabelDecode, DistillationCTCLabelDecode, \
|
||||
|
@ -34,7 +34,8 @@ def build_post_process(config, global_config=None):
|
|||
support_dict = [
|
||||
'DBPostProcess', 'EASTPostProcess', 'SASTPostProcess', 'CTCLabelDecode',
|
||||
'AttnLabelDecode', 'ClsPostProcess', 'SRNLabelDecode', 'PGPostProcess',
|
||||
'DistillationCTCLabelDecode', 'TableLabelDecode'
|
||||
'DistillationCTCLabelDecode', 'TableLabelDecode',
|
||||
'DistillationDBPostProcess'
|
||||
]
|
||||
|
||||
config = copy.deepcopy(config)
|
||||
|
|
|
@ -187,3 +187,29 @@ class DBPostProcess(object):
|
|||
|
||||
boxes_batch.append({'points': boxes})
|
||||
return boxes_batch
|
||||
|
||||
|
||||
class DistillationDBPostProcess(object):
|
||||
def __init__(self, model_name=["student"],
|
||||
key=None,
|
||||
thresh=0.3,
|
||||
box_thresh=0.6,
|
||||
max_candidates=1000,
|
||||
unclip_ratio=1.5,
|
||||
use_dilation=False,
|
||||
score_mode="fast",
|
||||
**kwargs):
|
||||
self.model_name = model_name
|
||||
self.key = key
|
||||
self.post_process = DBPostProcess(thresh=thresh,
|
||||
box_thresh=box_thresh,
|
||||
max_candidates=max_candidates,
|
||||
unclip_ratio=unclip_ratio,
|
||||
use_dilation=use_dilation,
|
||||
score_mode=score_mode)
|
||||
|
||||
def __call__(self, predicts, shape_list):
|
||||
results = {}
|
||||
for k in self.model_name:
|
||||
results[k] = self.post_process(predicts[k], shape_list=shape_list)
|
||||
return results
|
||||
|
|
|
@ -116,6 +116,27 @@ def load_dygraph_params(config, model, logger, optimizer):
|
|||
logger.info(f"loaded pretrained_model successful from {pm}")
|
||||
return {}
|
||||
|
||||
def load_pretrained_params(model, path):
|
||||
if path is None:
|
||||
return False
|
||||
if not os.path.exists(path) and not os.path.exists(path + ".pdparams"):
|
||||
print(f"The pretrained_model {path} does not exists!")
|
||||
return False
|
||||
|
||||
path = path if path.endswith('.pdparams') else path + '.pdparams'
|
||||
params = paddle.load(path)
|
||||
state_dict = model.state_dict()
|
||||
new_state_dict = {}
|
||||
for k1, k2 in zip(state_dict.keys(), params.keys()):
|
||||
if list(state_dict[k1].shape) == list(params[k2].shape):
|
||||
new_state_dict[k1] = params[k2]
|
||||
else:
|
||||
print(
|
||||
f"The shape of model params {k1} {state_dict[k1].shape} not matched with loaded params {k2} {params[k2].shape} !"
|
||||
)
|
||||
model.set_state_dict(new_state_dict)
|
||||
print(f"load pretrain successful from {path}")
|
||||
return model
|
||||
|
||||
def save_model(model,
|
||||
optimizer,
|
||||
|
|
|
@ -0,0 +1,35 @@
|
|||
model_name:ocr_rec
|
||||
python:python
|
||||
gpu_list:0|0,1
|
||||
Global.auto_cast:null
|
||||
Global.epoch_num:10
|
||||
Global.save_model_dir:./output/
|
||||
Train.loader.batch_size_per_card:
|
||||
Global.use_gpu:
|
||||
Global.pretrained_model:null
|
||||
|
||||
trainer:norm|pact
|
||||
norm_train:tools/train.py -c configs/rec/rec_mv3_none_bilstm_ctc.yml
|
||||
quant_train:deploy/slim/quantization/quant.py -c configs/rec/rec_mv3_none_bilstm_ctc.yml
|
||||
fpgm_train:null
|
||||
distill_train:null
|
||||
|
||||
eval:tools/eval.py -c configs/rec/rec_mv3_none_bilstm_ctc.yml -o
|
||||
|
||||
Global.save_inference_dir:./output/
|
||||
Global.pretrained_model:
|
||||
norm_export:tools/export_model.py -c configs/rec/rec_mv3_none_bilstm_ctc.yml -o
|
||||
quant_export:deploy/slim/quantization/export_model.py -c configs/rec/rec_mv3_none_bilstm_ctc.yml -o
|
||||
fpgm_export:null
|
||||
distill_export:null
|
||||
|
||||
inference:tools/infer/predict_rec.py
|
||||
--use_gpu:True|False
|
||||
--enable_mkldnn:True|False
|
||||
--cpu_threads:1|6
|
||||
--rec_batch_num:1
|
||||
--use_tensorrt:True|False
|
||||
--precision:fp32|fp16|int8
|
||||
--rec_model_dir:./inference/ch_ppocr_mobile_v2.0_rec_infer/
|
||||
--image_dir:./inference/rec_inference
|
||||
--save_log_path:./test/output/
|
|
@ -29,19 +29,21 @@ train_model_list=$(func_parser_value "${lines[0]}")
|
|||
|
||||
trainer_list=$(func_parser_value "${lines[10]}")
|
||||
|
||||
|
||||
# MODE be one of ['lite_train_infer' 'whole_infer' 'whole_train_infer']
|
||||
MODE=$2
|
||||
# prepare pretrained weights and dataset
|
||||
wget -nc -P ./pretrain_models/ https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/MobileNetV3_large_x0_5_pretrained.pdparams
|
||||
wget -nc -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_mv3_db_v2.0_train.tar
|
||||
cd pretrain_models && tar xf det_mv3_db_v2.0_train.tar && cd ../
|
||||
|
||||
if [ ${train_model_list[*]} = "ocr_det" ]; then
|
||||
wget -nc -P ./pretrain_models/ https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/MobileNetV3_large_x0_5_pretrained.pdparams
|
||||
wget -nc -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_mv3_db_v2.0_train.tar
|
||||
cd pretrain_models && tar xf det_mv3_db_v2.0_train.tar && cd ../
|
||||
fi
|
||||
if [ ${MODE} = "lite_train_infer" ];then
|
||||
# pretrain lite train data
|
||||
rm -rf ./train_data/icdar2015
|
||||
wget -nc -P ./train_data/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/icdar2015_lite.tar
|
||||
cd ./train_data/ && tar xf icdar2015_lite.tar
|
||||
wget -nc -P ./train_data/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/ic15_data.tar # todo change to bcebos
|
||||
|
||||
cd ./train_data/ && tar xf icdar2015_lite.tar && tar xf ic15_data.tar
|
||||
ln -s ./icdar2015_lite ./icdar2015
|
||||
cd ../
|
||||
epoch=10
|
||||
|
@ -49,13 +51,15 @@ if [ ${MODE} = "lite_train_infer" ];then
|
|||
elif [ ${MODE} = "whole_train_infer" ];then
|
||||
rm -rf ./train_data/icdar2015
|
||||
wget -nc -P ./train_data/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/icdar2015.tar
|
||||
cd ./train_data/ && tar xf icdar2015.tar && cd ../
|
||||
wget -nc -P ./train_data/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/ic15_data.tar
|
||||
cd ./train_data/ && tar xf icdar2015.tar && tar xf ic15_data.tar && cd ../
|
||||
epoch=500
|
||||
eval_batch_step=200
|
||||
elif [ ${MODE} = "whole_infer" ];then
|
||||
rm -rf ./train_data/icdar2015
|
||||
wget -nc -P ./train_data/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/icdar2015_infer.tar
|
||||
cd ./train_data/ && tar xf icdar2015_infer.tar
|
||||
wget -nc -P ./train_data/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/ic15_data.tar
|
||||
cd ./train_data/ && tar xf icdar2015_infer.tar && tar xf ic15_data.tar
|
||||
ln -s ./icdar2015_infer ./icdar2015
|
||||
cd ../
|
||||
epoch=10
|
||||
|
@ -88,9 +92,11 @@ for train_model in ${train_model_list[*]}; do
|
|||
elif [ ${train_model} = "ocr_rec" ];then
|
||||
model_name="ocr_rec"
|
||||
yml_file="configs/rec/rec_mv3_none_bilstm_ctc.yml"
|
||||
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/ch_rec_data_200.tar
|
||||
cd ./inference && tar xf ch_rec_data_200.tar && cd ../
|
||||
img_dir="./inference/ch_rec_data_200/"
|
||||
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/rec_inference.tar
|
||||
cd ./inference && tar xf rec_inference.tar && cd ../
|
||||
img_dir="./inference/rec_inference/"
|
||||
data_dir=./inference/rec_inference
|
||||
data_label_file=[./inference/rec_inference/rec_gt_test.txt]
|
||||
fi
|
||||
|
||||
# eval
|
||||
|
|
|
@ -27,7 +27,7 @@ from ppocr.data import build_dataloader
|
|||
from ppocr.modeling.architectures import build_model
|
||||
from ppocr.postprocess import build_post_process
|
||||
from ppocr.metrics import build_metric
|
||||
from ppocr.utils.save_load import init_model
|
||||
from ppocr.utils.save_load import init_model, load_pretrained_params
|
||||
from ppocr.utils.utility import print_dict
|
||||
import tools.program as program
|
||||
|
||||
|
@ -55,7 +55,10 @@ def main():
|
|||
|
||||
model = build_model(config['Architecture'])
|
||||
use_srn = config['Architecture']['algorithm'] == "SRN"
|
||||
model_type = config['Architecture']['model_type']
|
||||
if "model_type" in config['Architecture'].keys():
|
||||
model_type = config['Architecture']['model_type']
|
||||
else:
|
||||
model_type = None
|
||||
|
||||
best_model_dict = init_model(config, model)
|
||||
if len(best_model_dict):
|
||||
|
@ -68,7 +71,7 @@ def main():
|
|||
|
||||
# start eval
|
||||
metric = program.eval(model, valid_dataloader, post_process_class,
|
||||
eval_class, model_type, use_srn)
|
||||
eval_class, model_type, use_srn)
|
||||
logger.info('metric eval ***************')
|
||||
for k, v in metric.items():
|
||||
logger.info('{}:{}'.format(k, v))
|
||||
|
|
|
@ -112,7 +112,6 @@ class TextClassifier(object):
|
|||
if '180' in label and score > self.cls_thresh:
|
||||
img_list[indices[beg_img_no + rno]] = cv2.rotate(
|
||||
img_list[indices[beg_img_no + rno]], 1)
|
||||
elapse = time.time() - starttime
|
||||
return img_list, cls_res, elapse
|
||||
|
||||
|
||||
|
@ -146,7 +145,6 @@ def main(args):
|
|||
cls_res[ino]))
|
||||
logger.info(
|
||||
"The predict time about text angle classify module is as follows: ")
|
||||
text_classifier.cls_times.info(average=False)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
@ -64,6 +64,24 @@ class TextRecognizer(object):
|
|||
self.postprocess_op = build_post_process(postprocess_params)
|
||||
self.predictor, self.input_tensor, self.output_tensors, self.config = \
|
||||
utility.create_predictor(args, 'rec', logger)
|
||||
self.benchmark = args.benchmark
|
||||
if args.benchmark:
|
||||
import auto_log
|
||||
pid = os.getpid()
|
||||
self.autolog = auto_log.AutoLogger(
|
||||
model_name="rec",
|
||||
model_precision=args.precision,
|
||||
batch_size=args.rec_batch_num,
|
||||
data_shape="dynamic",
|
||||
save_path=args.save_log_path,
|
||||
inference_config=self.config,
|
||||
pids=pid,
|
||||
process_name=None,
|
||||
gpu_ids=0 if args.use_gpu else None,
|
||||
time_keys=[
|
||||
'preprocess_time', 'inference_time', 'postprocess_time'
|
||||
],
|
||||
warmup=10)
|
||||
|
||||
def resize_norm_img(self, img, max_wh_ratio):
|
||||
imgC, imgH, imgW = self.rec_image_shape
|
||||
|
@ -168,6 +186,8 @@ class TextRecognizer(object):
|
|||
rec_res = [['', 0.0]] * img_num
|
||||
batch_num = self.rec_batch_num
|
||||
st = time.time()
|
||||
if self.benchmark:
|
||||
self.autolog.times.start()
|
||||
for beg_img_no in range(0, img_num, batch_num):
|
||||
end_img_no = min(img_num, beg_img_no + batch_num)
|
||||
norm_img_batch = []
|
||||
|
@ -196,6 +216,8 @@ class TextRecognizer(object):
|
|||
norm_img_batch.append(norm_img[0])
|
||||
norm_img_batch = np.concatenate(norm_img_batch)
|
||||
norm_img_batch = norm_img_batch.copy()
|
||||
if self.benchmark:
|
||||
self.autolog.times.stamp()
|
||||
|
||||
if self.rec_algorithm == "SRN":
|
||||
encoder_word_pos_list = np.concatenate(encoder_word_pos_list)
|
||||
|
@ -222,6 +244,8 @@ class TextRecognizer(object):
|
|||
for output_tensor in self.output_tensors:
|
||||
output = output_tensor.copy_to_cpu()
|
||||
outputs.append(output)
|
||||
if self.benchmark:
|
||||
self.autolog.times.stamp()
|
||||
preds = {"predict": outputs[2]}
|
||||
else:
|
||||
self.input_tensor.copy_from_cpu(norm_img_batch)
|
||||
|
@ -231,11 +255,14 @@ class TextRecognizer(object):
|
|||
for output_tensor in self.output_tensors:
|
||||
output = output_tensor.copy_to_cpu()
|
||||
outputs.append(output)
|
||||
if self.benchmark:
|
||||
self.autolog.times.stamp()
|
||||
preds = outputs[0]
|
||||
rec_result = self.postprocess_op(preds)
|
||||
for rno in range(len(rec_result)):
|
||||
rec_res[indices[beg_img_no + rno]] = rec_result[rno]
|
||||
|
||||
if self.benchmark:
|
||||
self.autolog.times.end(stamp=True)
|
||||
return rec_res, time.time() - st
|
||||
|
||||
|
||||
|
@ -251,9 +278,6 @@ def main(args):
|
|||
for i in range(10):
|
||||
res = text_recognizer([img])
|
||||
|
||||
cpu_mem, gpu_mem, gpu_util = 0, 0, 0
|
||||
count = 0
|
||||
|
||||
for image_file in image_file_list:
|
||||
img, flag = check_and_read_gif(image_file)
|
||||
if not flag:
|
||||
|
@ -273,6 +297,8 @@ def main(args):
|
|||
for ino in range(len(img_list)):
|
||||
logger.info("Predicts of {}:{}".format(valid_image_file_list[ino],
|
||||
rec_res[ino]))
|
||||
if args.benchmark:
|
||||
text_recognizer.autolog.report()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
@ -24,9 +24,6 @@ from paddle import inference
|
|||
import time
|
||||
from ppocr.utils.logging import get_logger
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
|
||||
def str2bool(v):
|
||||
return v.lower() in ("true", "t", "1")
|
||||
|
||||
|
|
|
@ -186,7 +186,10 @@ def train(config,
|
|||
model.train()
|
||||
|
||||
use_srn = config['Architecture']['algorithm'] == "SRN"
|
||||
model_type = config['Architecture']['model_type']
|
||||
try:
|
||||
model_type = config['Architecture']['model_type']
|
||||
except:
|
||||
model_type = None
|
||||
|
||||
if 'start_epoch' in best_model_dict:
|
||||
start_epoch = best_model_dict['start_epoch']
|
||||
|
|
|
@ -98,7 +98,6 @@ def main(config, device, logger, vdl_writer):
|
|||
eval_class = build_metric(config['Metric'])
|
||||
# load pretrain model
|
||||
pre_best_model_dict = load_dygraph_params(config, model, logger, optimizer)
|
||||
|
||||
logger.info('train dataloader has {} iters'.format(len(train_dataloader)))
|
||||
if valid_dataloader is not None:
|
||||
logger.info('valid dataloader has {} iters'.format(
|
||||
|
|
Loading…
Reference in New Issue