Merge remote-tracking branch 'origin/dygraph' into dygraph
# Conflicts: # PPOCRLabel/libs/resources.py 改为大写
This commit is contained in:
commit
f6d2bc9be5
|
@ -401,6 +401,7 @@ class MainWindow(QMainWindow, WindowMixin):
|
||||||
help = action(getStr('tutorial'), self.showTutorialDialog, None, 'help', getStr('tutorialDetail'))
|
help = action(getStr('tutorial'), self.showTutorialDialog, None, 'help', getStr('tutorialDetail'))
|
||||||
showInfo = action(getStr('info'), self.showInfoDialog, None, 'help', getStr('info'))
|
showInfo = action(getStr('info'), self.showInfoDialog, None, 'help', getStr('info'))
|
||||||
showSteps = action(getStr('steps'), self.showStepsDialog, None, 'help', getStr('steps'))
|
showSteps = action(getStr('steps'), self.showStepsDialog, None, 'help', getStr('steps'))
|
||||||
|
showKeys = action(getStr('keys'), self.showKeysDialog, None, 'help', getStr('keys'))
|
||||||
|
|
||||||
zoom = QWidgetAction(self)
|
zoom = QWidgetAction(self)
|
||||||
zoom.setDefaultWidget(self.zoomWidget)
|
zoom.setDefaultWidget(self.zoomWidget)
|
||||||
|
@ -568,7 +569,7 @@ class MainWindow(QMainWindow, WindowMixin):
|
||||||
addActions(self.menus.file,
|
addActions(self.menus.file,
|
||||||
(opendir, open_dataset_dir, None, saveLabel, saveRec, self.autoSaveOption, None, resetAll, deleteImg, quit))
|
(opendir, open_dataset_dir, None, saveLabel, saveRec, self.autoSaveOption, None, resetAll, deleteImg, quit))
|
||||||
|
|
||||||
addActions(self.menus.help, (showSteps, showInfo))
|
addActions(self.menus.help, (showKeys,showSteps, showInfo))
|
||||||
addActions(self.menus.view, (
|
addActions(self.menus.view, (
|
||||||
self.displayLabelOption, self.labelDialogOption,
|
self.displayLabelOption, self.labelDialogOption,
|
||||||
None,
|
None,
|
||||||
|
@ -763,6 +764,10 @@ class MainWindow(QMainWindow, WindowMixin):
|
||||||
msg = stepsInfo(self.lang)
|
msg = stepsInfo(self.lang)
|
||||||
QMessageBox.information(self, u'Information', msg)
|
QMessageBox.information(self, u'Information', msg)
|
||||||
|
|
||||||
|
def showKeysDialog(self):
    """Pop up an information box with the shortcut-key table.

    The text is produced by ``keysInfo`` for the window's current
    ``self.lang`` and shown through ``QMessageBox.information``.
    """
    QMessageBox.information(self, u'Information', keysInfo(self.lang))
|
||||||
|
|
||||||
def createShape(self):
|
def createShape(self):
|
||||||
assert self.beginner()
|
assert self.beginner()
|
||||||
self.canvas.setEditing(False)
|
self.canvas.setEditing(False)
|
||||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -174,6 +174,7 @@ def stepsInfo(lang='en'):
|
||||||
"10. 标注结果:关闭应用程序或切换文件路径后,手动保存过的标签将会被存放在所打开图片文件夹下的" \
|
"10. 标注结果:关闭应用程序或切换文件路径后,手动保存过的标签将会被存放在所打开图片文件夹下的" \
|
||||||
"*Label.txt*中。在菜单栏点击 “PaddleOCR” - 保存识别结果后,会将此类图片的识别训练数据保存在*crop_img*文件夹下," \
|
"*Label.txt*中。在菜单栏点击 “PaddleOCR” - 保存识别结果后,会将此类图片的识别训练数据保存在*crop_img*文件夹下," \
|
||||||
"识别标签保存在*rec_gt.txt*中。\n"
|
"识别标签保存在*rec_gt.txt*中。\n"
|
||||||
|
|
||||||
else:
|
else:
|
||||||
msg = "1. Build and launch using the instructions above.\n" \
|
msg = "1. Build and launch using the instructions above.\n" \
|
||||||
"2. Click 'Open Dir' in Menu/File to select the folder of the picture.\n"\
|
"2. Click 'Open Dir' in Menu/File to select the folder of the picture.\n"\
|
||||||
|
@ -188,4 +189,56 @@ def stepsInfo(lang='en'):
|
||||||
"9. Click 'Delete Image' and the image will be deleted to the recycle bin.\n"\
|
"9. Click 'Delete Image' and the image will be deleted to the recycle bin.\n"\
|
||||||
"10. Labeling result: After closing the application or switching the file path, the manually saved label will be stored in *Label.txt* under the opened picture folder.\n"\
|
"10. Labeling result: After closing the application or switching the file path, the manually saved label will be stored in *Label.txt* under the opened picture folder.\n"\
|
||||||
" Click PaddleOCR-Save Recognition Results in the menu bar, the recognition training data of such pictures will be saved in the *crop_img* folder, and the recognition label will be saved in *rec_gt.txt*.\n"
|
" Click PaddleOCR-Save Recognition Results in the menu bar, the recognition training data of such pictures will be saved in the *crop_img* folder, and the recognition label will be saved in *rec_gt.txt*.\n"
|
||||||
|
|
||||||
|
return msg
|
||||||
|
|
||||||
|
def keysInfo(lang='en'):
    """Return the shortcut-key help text shown in the "Shortcut Keys" dialog.

    Args:
        lang: Language code. ``'ch'`` selects the Chinese table; any other
            value falls back to the English table.

    Returns:
        A tab-aligned, multi-line string with one shortcut per row, framed by
        two horizontal rules and followed by a note for Mac users.
    """
    if lang == 'ch':
        msg = (
            "快捷键\t\t\t说明\n"
            "———————————————————————\n"
            "Ctrl + shift + R\t\t对当前图片的所有标记重新识别\n"
            "W\t\t\t新建矩形框\n"
            "Q\t\t\t新建四点框\n"
            "Ctrl + E\t\t编辑所选框标签\n"
            "Ctrl + R\t\t重新识别所选标记\n"
            "Ctrl + C\t\t复制并粘贴选中的标记框\n"
            "Ctrl + 鼠标左键\t\t多选标记框\n"
            "Backspace\t\t删除所选框\n"
            "Ctrl + V\t\t确认本张图片标记\n"
            "Ctrl + Shift + d\t删除本张图片\n"
            "D\t\t\t下一张图片\n"
            "A\t\t\t上一张图片\n"
            # Ctrl++ zooms in and Ctrl-- zooms out, matching the English
            # table below; the two descriptions used to be swapped here.
            "Ctrl++\t\t\t放大\n"
            "Ctrl--\t\t\t缩小\n"
            "↑→↓←\t\t\t移动标记框\n"
            "———————————————————————\n"
            "注:Mac用户Command键替换上述Ctrl键"
        )
    else:
        msg = (
            "Shortcut Keys\t\tDescription\n"
            "———————————————————————\n"
            "Ctrl + shift + R\t\tRe-recognize all the labels\n"
            "\t\t\tof the current image\n"
            "\n"
            "W\t\t\tCreate a rect box\n"
            "Q\t\t\tCreate a four-points box\n"
            "Ctrl + E\t\tEdit label of the selected box\n"
            "Ctrl + R\t\tRe-recognize the selected box\n"
            "Ctrl + C\t\tCopy and paste the selected\n"
            "\t\t\tbox\n"
            "\n"
            "Ctrl + Left Mouse\tMulti select the label\n"
            "Button\t\t\tbox\n"
            "\n"
            "Backspace\t\tDelete the selected box\n"
            "Ctrl + V\t\tCheck image\n"
            "Ctrl + Shift + d\tDelete image\n"
            "D\t\t\tNext image\n"
            "A\t\t\tPrevious image\n"
            "Ctrl++\t\t\tZoom in\n"
            "Ctrl--\t\t\tZoom out\n"
            # The trailing newline keeps the closing rule on its own line
            # instead of being appended to this row.
            "↑→↓←\t\t\tMove selected box\n"
            "———————————————————————\n"
            "Notice:For Mac users, use the 'Command' key instead of the 'Ctrl' key"
        )

    return msg
|
|
@ -90,6 +90,7 @@ saveRec=保存识别结果
|
||||||
tempLabel=待识别
|
tempLabel=待识别
|
||||||
nullLabel=无法识别
|
nullLabel=无法识别
|
||||||
steps=操作步骤
|
steps=操作步骤
|
||||||
|
keys=快捷键
|
||||||
choseModelLg=选择模型语言
|
choseModelLg=选择模型语言
|
||||||
cancel=取消
|
cancel=取消
|
||||||
ok=确认
|
ok=确认
|
||||||
|
|
|
@ -3,7 +3,7 @@ openFileDetail=Open image or label file
|
||||||
quit=Quit
|
quit=Quit
|
||||||
quitApp=Quit application
|
quitApp=Quit application
|
||||||
openDir=Open Dir
|
openDir=Open Dir
|
||||||
openDatasetDir=open DatasetDir
|
openDatasetDir=Open DatasetDir
|
||||||
copyPrevBounding=Copy previous Bounding Boxes in the current image
|
copyPrevBounding=Copy previous Bounding Boxes in the current image
|
||||||
changeSavedAnnotationDir=Change default saved Annotation dir
|
changeSavedAnnotationDir=Change default saved Annotation dir
|
||||||
openAnnotation=Open Annotation
|
openAnnotation=Open Annotation
|
||||||
|
@ -90,6 +90,7 @@ saveRec=Save Recognition Result
|
||||||
tempLabel=TEMPORARY
|
tempLabel=TEMPORARY
|
||||||
nullLabel=NULL
|
nullLabel=NULL
|
||||||
steps=Steps
|
steps=Steps
|
||||||
|
keys=Shortcut Keys
|
||||||
choseModelLg=Choose Model Language
|
choseModelLg=Choose Model Language
|
||||||
cancel=Cancel
|
cancel=Cancel
|
||||||
ok=OK
|
ok=OK
|
||||||
|
|
|
@ -0,0 +1,202 @@
|
||||||
|
Global:
|
||||||
|
use_gpu: true
|
||||||
|
epoch_num: 1200
|
||||||
|
log_smooth_window: 20
|
||||||
|
print_batch_step: 2
|
||||||
|
save_model_dir: ./output/ch_db_mv3/
|
||||||
|
save_epoch_step: 1200
|
||||||
|
# evaluation is run every 2000 iterations after the 3000th iteration
|
||||||
|
eval_batch_step: [3000, 2000]
|
||||||
|
cal_metric_during_train: False
|
||||||
|
pretrained_model: ./pretrain_models/MobileNetV3_large_x0_5_pretrained
|
||||||
|
checkpoints:
|
||||||
|
save_inference_dir:
|
||||||
|
use_visualdl: False
|
||||||
|
infer_img: doc/imgs_en/img_10.jpg
|
||||||
|
save_res_path: ./output/det_db/predicts_db.txt
|
||||||
|
|
||||||
|
Architecture:
|
||||||
|
name: DistillationModel
|
||||||
|
algorithm: Distillation
|
||||||
|
Models:
|
||||||
|
Student:
|
||||||
|
pretrained: ./pretrain_models/MobileNetV3_large_x0_5_pretrained
|
||||||
|
freeze_params: false
|
||||||
|
return_all_feats: false
|
||||||
|
model_type: det
|
||||||
|
algorithm: DB
|
||||||
|
Backbone:
|
||||||
|
name: MobileNetV3
|
||||||
|
scale: 0.5
|
||||||
|
model_name: large
|
||||||
|
disable_se: True
|
||||||
|
Neck:
|
||||||
|
name: DBFPN
|
||||||
|
out_channels: 96
|
||||||
|
Head:
|
||||||
|
name: DBHead
|
||||||
|
k: 50
|
||||||
|
Student2:
|
||||||
|
pretrained: ./pretrain_models/MobileNetV3_large_x0_5_pretrained
|
||||||
|
freeze_params: false
|
||||||
|
return_all_feats: false
|
||||||
|
model_type: det
|
||||||
|
algorithm: DB
|
||||||
|
Transform:
|
||||||
|
Backbone:
|
||||||
|
name: MobileNetV3
|
||||||
|
scale: 0.5
|
||||||
|
model_name: large
|
||||||
|
disable_se: True
|
||||||
|
Neck:
|
||||||
|
name: DBFPN
|
||||||
|
out_channels: 96
|
||||||
|
Head:
|
||||||
|
name: DBHead
|
||||||
|
k: 50
|
||||||
|
Teacher:
|
||||||
|
pretrained: ./pretrain_models/ch_ppocr_server_v2.0_det_train/best_accuracy
|
||||||
|
freeze_params: true
|
||||||
|
return_all_feats: false
|
||||||
|
model_type: det
|
||||||
|
algorithm: DB
|
||||||
|
Transform:
|
||||||
|
Backbone:
|
||||||
|
name: ResNet
|
||||||
|
layers: 18
|
||||||
|
Neck:
|
||||||
|
name: DBFPN
|
||||||
|
out_channels: 256
|
||||||
|
Head:
|
||||||
|
name: DBHead
|
||||||
|
k: 50
|
||||||
|
|
||||||
|
Loss:
|
||||||
|
name: CombinedLoss
|
||||||
|
loss_config_list:
|
||||||
|
- DistillationDilaDBLoss:
|
||||||
|
weight: 1.0
|
||||||
|
model_name_pairs:
|
||||||
|
- ["Student", "Teacher"]
|
||||||
|
- ["Student2", "Teacher"]
|
||||||
|
key: maps
|
||||||
|
balance_loss: true
|
||||||
|
main_loss_type: DiceLoss
|
||||||
|
alpha: 5
|
||||||
|
beta: 10
|
||||||
|
ohem_ratio: 3
|
||||||
|
- DistillationDMLLoss:
|
||||||
|
model_name_pairs:
|
||||||
|
- ["Student", "Student2"]
|
||||||
|
maps_name: "thrink_maps"
|
||||||
|
weight: 1.0
|
||||||
|
# act: None
|
||||||
|
model_name_pairs: ["Student", "Student2"]
|
||||||
|
key: maps
|
||||||
|
- DistillationDBLoss:
|
||||||
|
weight: 1.0
|
||||||
|
model_name_list: ["Student", "Student2"]
|
||||||
|
# key: maps
|
||||||
|
# name: DBLoss
|
||||||
|
balance_loss: true
|
||||||
|
main_loss_type: DiceLoss
|
||||||
|
alpha: 5
|
||||||
|
beta: 10
|
||||||
|
ohem_ratio: 3
|
||||||
|
|
||||||
|
|
||||||
|
Optimizer:
|
||||||
|
name: Adam
|
||||||
|
beta1: 0.9
|
||||||
|
beta2: 0.999
|
||||||
|
lr:
|
||||||
|
name: Cosine
|
||||||
|
learning_rate: 0.001
|
||||||
|
warmup_epoch: 2
|
||||||
|
regularizer:
|
||||||
|
name: 'L2'
|
||||||
|
factor: 0
|
||||||
|
|
||||||
|
PostProcess:
|
||||||
|
name: DistillationDBPostProcess
|
||||||
|
model_name: ["Student", "Student2", "Teacher"]
|
||||||
|
# key: maps
|
||||||
|
thresh: 0.3
|
||||||
|
box_thresh: 0.6
|
||||||
|
max_candidates: 1000
|
||||||
|
unclip_ratio: 1.5
|
||||||
|
|
||||||
|
Metric:
|
||||||
|
name: DistillationMetric
|
||||||
|
base_metric_name: DetMetric
|
||||||
|
main_indicator: hmean
|
||||||
|
key: "Student"
|
||||||
|
|
||||||
|
Train:
|
||||||
|
dataset:
|
||||||
|
name: SimpleDataSet
|
||||||
|
data_dir: ./train_data/icdar2015/text_localization/
|
||||||
|
label_file_list:
|
||||||
|
- ./train_data/icdar2015/text_localization/train_icdar2015_label.txt
|
||||||
|
ratio_list: [1.0]
|
||||||
|
transforms:
|
||||||
|
- DecodeImage: # load image
|
||||||
|
img_mode: BGR
|
||||||
|
channel_first: False
|
||||||
|
- DetLabelEncode: # Class handling label
|
||||||
|
- IaaAugment:
|
||||||
|
augmenter_args:
|
||||||
|
- { 'type': Fliplr, 'args': { 'p': 0.5 } }
|
||||||
|
- { 'type': Affine, 'args': { 'rotate': [-10, 10] } }
|
||||||
|
- { 'type': Resize, 'args': { 'size': [0.5, 3] } }
|
||||||
|
- EastRandomCropData:
|
||||||
|
size: [960, 960]
|
||||||
|
max_tries: 50
|
||||||
|
keep_ratio: true
|
||||||
|
- MakeBorderMap:
|
||||||
|
shrink_ratio: 0.4
|
||||||
|
thresh_min: 0.3
|
||||||
|
thresh_max: 0.7
|
||||||
|
- MakeShrinkMap:
|
||||||
|
shrink_ratio: 0.4
|
||||||
|
min_text_size: 8
|
||||||
|
- NormalizeImage:
|
||||||
|
scale: 1./255.
|
||||||
|
mean: [0.485, 0.456, 0.406]
|
||||||
|
std: [0.229, 0.224, 0.225]
|
||||||
|
order: 'hwc'
|
||||||
|
- ToCHWImage:
|
||||||
|
- KeepKeys:
|
||||||
|
keep_keys: ['image', 'threshold_map', 'threshold_mask', 'shrink_map', 'shrink_mask'] # the order of the dataloader list
|
||||||
|
loader:
|
||||||
|
shuffle: True
|
||||||
|
drop_last: False
|
||||||
|
batch_size_per_card: 8
|
||||||
|
num_workers: 4
|
||||||
|
|
||||||
|
Eval:
|
||||||
|
dataset:
|
||||||
|
name: SimpleDataSet
|
||||||
|
data_dir: ./train_data/icdar2015/text_localization/
|
||||||
|
label_file_list:
|
||||||
|
- ./train_data/icdar2015/text_localization/test_icdar2015_label.txt
|
||||||
|
transforms:
|
||||||
|
- DecodeImage: # load image
|
||||||
|
img_mode: BGR
|
||||||
|
channel_first: False
|
||||||
|
- DetLabelEncode: # Class handling label
|
||||||
|
- DetResizeForTest:
|
||||||
|
# image_shape: [736, 1280]
|
||||||
|
- NormalizeImage:
|
||||||
|
scale: 1./255.
|
||||||
|
mean: [0.485, 0.456, 0.406]
|
||||||
|
std: [0.229, 0.224, 0.225]
|
||||||
|
order: 'hwc'
|
||||||
|
- ToCHWImage:
|
||||||
|
- KeepKeys:
|
||||||
|
keep_keys: ['image', 'shape', 'polys', 'ignore_tags']
|
||||||
|
loader:
|
||||||
|
shuffle: False
|
||||||
|
drop_last: False
|
||||||
|
batch_size_per_card: 1 # must be 1
|
||||||
|
num_workers: 2
|
|
@ -0,0 +1,174 @@
|
||||||
|
Global:
|
||||||
|
use_gpu: true
|
||||||
|
epoch_num: 1200
|
||||||
|
log_smooth_window: 20
|
||||||
|
print_batch_step: 2
|
||||||
|
save_model_dir: ./output/ch_db_mv3/
|
||||||
|
save_epoch_step: 1200
|
||||||
|
# evaluation is run every 2000 iterations after the 3000th iteration
|
||||||
|
eval_batch_step: [3000, 2000]
|
||||||
|
cal_metric_during_train: False
|
||||||
|
pretrained_model: ./pretrain_models/MobileNetV3_large_x0_5_pretrained
|
||||||
|
checkpoints:
|
||||||
|
save_inference_dir:
|
||||||
|
use_visualdl: False
|
||||||
|
infer_img: doc/imgs_en/img_10.jpg
|
||||||
|
save_res_path: ./output/det_db/predicts_db.txt
|
||||||
|
|
||||||
|
Architecture:
|
||||||
|
name: DistillationModel
|
||||||
|
algorithm: Distillation
|
||||||
|
Models:
|
||||||
|
Student:
|
||||||
|
pretrained: ./pretrain_models/MobileNetV3_large_x0_5_pretrained
|
||||||
|
freeze_params: false
|
||||||
|
return_all_feats: false
|
||||||
|
model_type: det
|
||||||
|
algorithm: DB
|
||||||
|
Backbone:
|
||||||
|
name: MobileNetV3
|
||||||
|
scale: 0.5
|
||||||
|
model_name: large
|
||||||
|
disable_se: True
|
||||||
|
Neck:
|
||||||
|
name: DBFPN
|
||||||
|
out_channels: 96
|
||||||
|
Head:
|
||||||
|
name: DBHead
|
||||||
|
k: 50
|
||||||
|
Teacher:
|
||||||
|
pretrained: ./pretrain_models/ch_ppocr_server_v2.0_det_train/best_accuracy
|
||||||
|
freeze_params: true
|
||||||
|
return_all_feats: false
|
||||||
|
model_type: det
|
||||||
|
algorithm: DB
|
||||||
|
Transform:
|
||||||
|
Backbone:
|
||||||
|
name: ResNet
|
||||||
|
layers: 18
|
||||||
|
Neck:
|
||||||
|
name: DBFPN
|
||||||
|
out_channels: 256
|
||||||
|
Head:
|
||||||
|
name: DBHead
|
||||||
|
k: 50
|
||||||
|
|
||||||
|
Loss:
|
||||||
|
name: CombinedLoss
|
||||||
|
loss_config_list:
|
||||||
|
- DistillationDilaDBLoss:
|
||||||
|
weight: 1.0
|
||||||
|
model_name_pairs:
|
||||||
|
- ["Student", "Teacher"]
|
||||||
|
key: maps
|
||||||
|
balance_loss: true
|
||||||
|
main_loss_type: DiceLoss
|
||||||
|
alpha: 5
|
||||||
|
beta: 10
|
||||||
|
ohem_ratio: 3
|
||||||
|
- DistillationDBLoss:
|
||||||
|
weight: 1.0
|
||||||
|
model_name_list: ["Student", "Teacher"]
|
||||||
|
# key: maps
|
||||||
|
name: DBLoss
|
||||||
|
balance_loss: true
|
||||||
|
main_loss_type: DiceLoss
|
||||||
|
alpha: 5
|
||||||
|
beta: 10
|
||||||
|
ohem_ratio: 3
|
||||||
|
|
||||||
|
Optimizer:
|
||||||
|
name: Adam
|
||||||
|
beta1: 0.9
|
||||||
|
beta2: 0.999
|
||||||
|
lr:
|
||||||
|
name: Cosine
|
||||||
|
learning_rate: 0.001
|
||||||
|
warmup_epoch: 2
|
||||||
|
regularizer:
|
||||||
|
name: 'L2'
|
||||||
|
factor: 0
|
||||||
|
|
||||||
|
PostProcess:
|
||||||
|
name: DistillationDBPostProcess
|
||||||
|
model_name: ["Student", "Student2"]
|
||||||
|
key: head_out
|
||||||
|
thresh: 0.3
|
||||||
|
box_thresh: 0.6
|
||||||
|
max_candidates: 1000
|
||||||
|
unclip_ratio: 1.5
|
||||||
|
|
||||||
|
Metric:
|
||||||
|
name: DistillationMetric
|
||||||
|
base_metric_name: DetMetric
|
||||||
|
main_indicator: hmean
|
||||||
|
key: "Student"
|
||||||
|
|
||||||
|
Train:
|
||||||
|
dataset:
|
||||||
|
name: SimpleDataSet
|
||||||
|
data_dir: ./train_data/icdar2015/text_localization/
|
||||||
|
label_file_list:
|
||||||
|
- ./train_data/icdar2015/text_localization/train_icdar2015_label.txt
|
||||||
|
ratio_list: [1.0]
|
||||||
|
transforms:
|
||||||
|
- DecodeImage: # load image
|
||||||
|
img_mode: BGR
|
||||||
|
channel_first: False
|
||||||
|
- DetLabelEncode: # Class handling label
|
||||||
|
- IaaAugment:
|
||||||
|
augmenter_args:
|
||||||
|
- { 'type': Fliplr, 'args': { 'p': 0.5 } }
|
||||||
|
- { 'type': Affine, 'args': { 'rotate': [-10, 10] } }
|
||||||
|
- { 'type': Resize, 'args': { 'size': [0.5, 3] } }
|
||||||
|
- EastRandomCropData:
|
||||||
|
size: [960, 960]
|
||||||
|
max_tries: 50
|
||||||
|
keep_ratio: true
|
||||||
|
- MakeBorderMap:
|
||||||
|
shrink_ratio: 0.4
|
||||||
|
thresh_min: 0.3
|
||||||
|
thresh_max: 0.7
|
||||||
|
- MakeShrinkMap:
|
||||||
|
shrink_ratio: 0.4
|
||||||
|
min_text_size: 8
|
||||||
|
- NormalizeImage:
|
||||||
|
scale: 1./255.
|
||||||
|
mean: [0.485, 0.456, 0.406]
|
||||||
|
std: [0.229, 0.224, 0.225]
|
||||||
|
order: 'hwc'
|
||||||
|
- ToCHWImage:
|
||||||
|
- KeepKeys:
|
||||||
|
keep_keys: ['image', 'threshold_map', 'threshold_mask', 'shrink_map', 'shrink_mask'] # the order of the dataloader list
|
||||||
|
loader:
|
||||||
|
shuffle: True
|
||||||
|
drop_last: False
|
||||||
|
batch_size_per_card: 8
|
||||||
|
num_workers: 4
|
||||||
|
|
||||||
|
Eval:
|
||||||
|
dataset:
|
||||||
|
name: SimpleDataSet
|
||||||
|
data_dir: ./train_data/icdar2015/text_localization/
|
||||||
|
label_file_list:
|
||||||
|
- ./train_data/icdar2015/text_localization/test_icdar2015_label.txt
|
||||||
|
transforms:
|
||||||
|
- DecodeImage: # load image
|
||||||
|
img_mode: BGR
|
||||||
|
channel_first: False
|
||||||
|
- DetLabelEncode: # Class handling label
|
||||||
|
- DetResizeForTest:
|
||||||
|
# image_shape: [736, 1280]
|
||||||
|
- NormalizeImage:
|
||||||
|
scale: 1./255.
|
||||||
|
mean: [0.485, 0.456, 0.406]
|
||||||
|
std: [0.229, 0.224, 0.225]
|
||||||
|
order: 'hwc'
|
||||||
|
- ToCHWImage:
|
||||||
|
- KeepKeys:
|
||||||
|
keep_keys: ['image', 'shape', 'polys', 'ignore_tags']
|
||||||
|
loader:
|
||||||
|
shuffle: False
|
||||||
|
drop_last: False
|
||||||
|
batch_size_per_card: 1 # must be 1
|
||||||
|
num_workers: 2
|
|
@ -0,0 +1,176 @@
|
||||||
|
Global:
|
||||||
|
use_gpu: true
|
||||||
|
epoch_num: 1200
|
||||||
|
log_smooth_window: 20
|
||||||
|
print_batch_step: 2
|
||||||
|
save_model_dir: ./output/ch_db_mv3/
|
||||||
|
save_epoch_step: 1200
|
||||||
|
# evaluation is run every 2000 iterations after the 3000th iteration
|
||||||
|
eval_batch_step: [3000, 2000]
|
||||||
|
cal_metric_during_train: False
|
||||||
|
pretrained_model: ./pretrain_models/MobileNetV3_large_x0_5_pretrained
|
||||||
|
checkpoints:
|
||||||
|
save_inference_dir:
|
||||||
|
use_visualdl: False
|
||||||
|
infer_img: doc/imgs_en/img_10.jpg
|
||||||
|
save_res_path: ./output/det_db/predicts_db.txt
|
||||||
|
|
||||||
|
Architecture:
|
||||||
|
name: DistillationModel
|
||||||
|
algorithm: Distillation
|
||||||
|
Models:
|
||||||
|
Student:
|
||||||
|
pretrained: ./pretrain_models/MobileNetV3_large_x0_5_pretrained
|
||||||
|
freeze_params: false
|
||||||
|
return_all_feats: false
|
||||||
|
model_type: det
|
||||||
|
algorithm: DB
|
||||||
|
Backbone:
|
||||||
|
name: MobileNetV3
|
||||||
|
scale: 0.5
|
||||||
|
model_name: large
|
||||||
|
disable_se: True
|
||||||
|
Neck:
|
||||||
|
name: DBFPN
|
||||||
|
out_channels: 96
|
||||||
|
Head:
|
||||||
|
name: DBHead
|
||||||
|
k: 50
|
||||||
|
Student2:
|
||||||
|
pretrained: ./pretrain_models/MobileNetV3_large_x0_5_pretrained
|
||||||
|
freeze_params: false
|
||||||
|
return_all_feats: false
|
||||||
|
model_type: det
|
||||||
|
algorithm: DB
|
||||||
|
Transform:
|
||||||
|
Backbone:
|
||||||
|
name: MobileNetV3
|
||||||
|
scale: 0.5
|
||||||
|
model_name: large
|
||||||
|
disable_se: True
|
||||||
|
Neck:
|
||||||
|
name: DBFPN
|
||||||
|
out_channels: 96
|
||||||
|
Head:
|
||||||
|
name: DBHead
|
||||||
|
k: 50
|
||||||
|
|
||||||
|
|
||||||
|
Loss:
|
||||||
|
name: CombinedLoss
|
||||||
|
loss_config_list:
|
||||||
|
- DistillationDMLLoss:
|
||||||
|
model_name_pairs:
|
||||||
|
- ["Student", "Student2"]
|
||||||
|
maps_name: "thrink_maps"
|
||||||
|
weight: 1.0
|
||||||
|
act: "softmax"
|
||||||
|
model_name_pairs: ["Student", "Student2"]
|
||||||
|
key: maps
|
||||||
|
- DistillationDBLoss:
|
||||||
|
weight: 1.0
|
||||||
|
model_name_list: ["Student", "Student2"]
|
||||||
|
# key: maps
|
||||||
|
name: DBLoss
|
||||||
|
balance_loss: true
|
||||||
|
main_loss_type: DiceLoss
|
||||||
|
alpha: 5
|
||||||
|
beta: 10
|
||||||
|
ohem_ratio: 3
|
||||||
|
|
||||||
|
|
||||||
|
Optimizer:
|
||||||
|
name: Adam
|
||||||
|
beta1: 0.9
|
||||||
|
beta2: 0.999
|
||||||
|
lr:
|
||||||
|
name: Cosine
|
||||||
|
learning_rate: 0.001
|
||||||
|
warmup_epoch: 2
|
||||||
|
regularizer:
|
||||||
|
name: 'L2'
|
||||||
|
factor: 0
|
||||||
|
|
||||||
|
PostProcess:
|
||||||
|
name: DistillationDBPostProcess
|
||||||
|
model_name: ["Student", "Student2"]
|
||||||
|
key: head_out
|
||||||
|
thresh: 0.3
|
||||||
|
box_thresh: 0.6
|
||||||
|
max_candidates: 1000
|
||||||
|
unclip_ratio: 1.5
|
||||||
|
|
||||||
|
Metric:
|
||||||
|
name: DistillationMetric
|
||||||
|
base_metric_name: DetMetric
|
||||||
|
main_indicator: hmean
|
||||||
|
key: "Student"
|
||||||
|
|
||||||
|
Train:
|
||||||
|
dataset:
|
||||||
|
name: SimpleDataSet
|
||||||
|
data_dir: ./train_data/icdar2015/text_localization/
|
||||||
|
label_file_list:
|
||||||
|
- ./train_data/icdar2015/text_localization/train_icdar2015_label.txt
|
||||||
|
ratio_list: [1.0]
|
||||||
|
transforms:
|
||||||
|
- DecodeImage: # load image
|
||||||
|
img_mode: BGR
|
||||||
|
channel_first: False
|
||||||
|
- DetLabelEncode: # Class handling label
|
||||||
|
- IaaAugment:
|
||||||
|
augmenter_args:
|
||||||
|
- { 'type': Fliplr, 'args': { 'p': 0.5 } }
|
||||||
|
- { 'type': Affine, 'args': { 'rotate': [-10, 10] } }
|
||||||
|
- { 'type': Resize, 'args': { 'size': [0.5, 3] } }
|
||||||
|
- EastRandomCropData:
|
||||||
|
size: [960, 960]
|
||||||
|
max_tries: 50
|
||||||
|
keep_ratio: true
|
||||||
|
- MakeBorderMap:
|
||||||
|
shrink_ratio: 0.4
|
||||||
|
thresh_min: 0.3
|
||||||
|
thresh_max: 0.7
|
||||||
|
- MakeShrinkMap:
|
||||||
|
shrink_ratio: 0.4
|
||||||
|
min_text_size: 8
|
||||||
|
- NormalizeImage:
|
||||||
|
scale: 1./255.
|
||||||
|
mean: [0.485, 0.456, 0.406]
|
||||||
|
std: [0.229, 0.224, 0.225]
|
||||||
|
order: 'hwc'
|
||||||
|
- ToCHWImage:
|
||||||
|
- KeepKeys:
|
||||||
|
keep_keys: ['image', 'threshold_map', 'threshold_mask', 'shrink_map', 'shrink_mask'] # the order of the dataloader list
|
||||||
|
loader:
|
||||||
|
shuffle: True
|
||||||
|
drop_last: False
|
||||||
|
batch_size_per_card: 8
|
||||||
|
num_workers: 4
|
||||||
|
|
||||||
|
Eval:
|
||||||
|
dataset:
|
||||||
|
name: SimpleDataSet
|
||||||
|
data_dir: ./train_data/icdar2015/text_localization/
|
||||||
|
label_file_list:
|
||||||
|
- ./train_data/icdar2015/text_localization/test_icdar2015_label.txt
|
||||||
|
transforms:
|
||||||
|
- DecodeImage: # load image
|
||||||
|
img_mode: BGR
|
||||||
|
channel_first: False
|
||||||
|
- DetLabelEncode: # Class handling label
|
||||||
|
- DetResizeForTest:
|
||||||
|
# image_shape: [736, 1280]
|
||||||
|
- NormalizeImage:
|
||||||
|
scale: 1./255.
|
||||||
|
mean: [0.485, 0.456, 0.406]
|
||||||
|
std: [0.229, 0.224, 0.225]
|
||||||
|
order: 'hwc'
|
||||||
|
- ToCHWImage:
|
||||||
|
- KeepKeys:
|
||||||
|
keep_keys: ['image', 'shape', 'polys', 'ignore_tags']
|
||||||
|
loader:
|
||||||
|
shuffle: False
|
||||||
|
drop_last: False
|
||||||
|
batch_size_per_card: 1 # must be 1
|
||||||
|
num_workers: 2
|
|
@ -13,7 +13,6 @@ SET(TENSORRT_DIR "" CACHE PATH "Compile demo with TensorRT")
|
||||||
|
|
||||||
set(DEMO_NAME "ocr_system")
|
set(DEMO_NAME "ocr_system")
|
||||||
|
|
||||||
|
|
||||||
macro(safe_set_static_flag)
|
macro(safe_set_static_flag)
|
||||||
foreach(flag_var
|
foreach(flag_var
|
||||||
CMAKE_CXX_FLAGS CMAKE_CXX_FLAGS_DEBUG CMAKE_CXX_FLAGS_RELEASE
|
CMAKE_CXX_FLAGS CMAKE_CXX_FLAGS_DEBUG CMAKE_CXX_FLAGS_RELEASE
|
||||||
|
|
|
@ -668,7 +668,7 @@ void DisposeOutPts(OutPt *&pp) {
|
||||||
//------------------------------------------------------------------------------
|
//------------------------------------------------------------------------------
|
||||||
|
|
||||||
inline void InitEdge(TEdge *e, TEdge *eNext, TEdge *ePrev, const IntPoint &Pt) {
|
inline void InitEdge(TEdge *e, TEdge *eNext, TEdge *ePrev, const IntPoint &Pt) {
|
||||||
std::memset(e, 0, sizeof(TEdge));
|
std::memset(e, int(0), sizeof(TEdge));
|
||||||
e->Next = eNext;
|
e->Next = eNext;
|
||||||
e->Prev = ePrev;
|
e->Prev = ePrev;
|
||||||
e->Curr = Pt;
|
e->Curr = Pt;
|
||||||
|
@ -1895,17 +1895,17 @@ void Clipper::InsertLocalMinimaIntoAEL(const cInt botY) {
|
||||||
TEdge *rb = lm->RightBound;
|
TEdge *rb = lm->RightBound;
|
||||||
|
|
||||||
OutPt *Op1 = 0;
|
OutPt *Op1 = 0;
|
||||||
if (!lb) {
|
if (!lb || !rb) {
|
||||||
// nb: don't insert LB into either AEL or SEL
|
// nb: don't insert LB into either AEL or SEL
|
||||||
InsertEdgeIntoAEL(rb, 0);
|
InsertEdgeIntoAEL(rb, 0);
|
||||||
SetWindingCount(*rb);
|
SetWindingCount(*rb);
|
||||||
if (IsContributing(*rb))
|
if (IsContributing(*rb))
|
||||||
Op1 = AddOutPt(rb, rb->Bot);
|
Op1 = AddOutPt(rb, rb->Bot);
|
||||||
} else if (!rb) {
|
//} else if (!rb) {
|
||||||
InsertEdgeIntoAEL(lb, 0);
|
// InsertEdgeIntoAEL(lb, 0);
|
||||||
SetWindingCount(*lb);
|
// SetWindingCount(*lb);
|
||||||
if (IsContributing(*lb))
|
// if (IsContributing(*lb))
|
||||||
Op1 = AddOutPt(lb, lb->Bot);
|
// Op1 = AddOutPt(lb, lb->Bot);
|
||||||
InsertScanbeam(lb->Top.Y);
|
InsertScanbeam(lb->Top.Y);
|
||||||
} else {
|
} else {
|
||||||
InsertEdgeIntoAEL(lb, 0);
|
InsertEdgeIntoAEL(lb, 0);
|
||||||
|
@ -2547,13 +2547,13 @@ void Clipper::ProcessHorizontal(TEdge *horzEdge) {
|
||||||
if (dir == dLeftToRight) {
|
if (dir == dLeftToRight) {
|
||||||
maxIt = m_Maxima.begin();
|
maxIt = m_Maxima.begin();
|
||||||
while (maxIt != m_Maxima.end() && *maxIt <= horzEdge->Bot.X)
|
while (maxIt != m_Maxima.end() && *maxIt <= horzEdge->Bot.X)
|
||||||
maxIt++;
|
++maxIt;
|
||||||
if (maxIt != m_Maxima.end() && *maxIt >= eLastHorz->Top.X)
|
if (maxIt != m_Maxima.end() && *maxIt >= eLastHorz->Top.X)
|
||||||
maxIt = m_Maxima.end();
|
maxIt = m_Maxima.end();
|
||||||
} else {
|
} else {
|
||||||
maxRit = m_Maxima.rbegin();
|
maxRit = m_Maxima.rbegin();
|
||||||
while (maxRit != m_Maxima.rend() && *maxRit > horzEdge->Bot.X)
|
while (maxRit != m_Maxima.rend() && *maxRit > horzEdge->Bot.X)
|
||||||
maxRit++;
|
++maxRit;
|
||||||
if (maxRit != m_Maxima.rend() && *maxRit <= eLastHorz->Top.X)
|
if (maxRit != m_Maxima.rend() && *maxRit <= eLastHorz->Top.X)
|
||||||
maxRit = m_Maxima.rend();
|
maxRit = m_Maxima.rend();
|
||||||
}
|
}
|
||||||
|
@ -2576,13 +2576,13 @@ void Clipper::ProcessHorizontal(TEdge *horzEdge) {
|
||||||
while (maxIt != m_Maxima.end() && *maxIt < e->Curr.X) {
|
while (maxIt != m_Maxima.end() && *maxIt < e->Curr.X) {
|
||||||
if (horzEdge->OutIdx >= 0 && !IsOpen)
|
if (horzEdge->OutIdx >= 0 && !IsOpen)
|
||||||
AddOutPt(horzEdge, IntPoint(*maxIt, horzEdge->Bot.Y));
|
AddOutPt(horzEdge, IntPoint(*maxIt, horzEdge->Bot.Y));
|
||||||
maxIt++;
|
++maxIt;
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
while (maxRit != m_Maxima.rend() && *maxRit > e->Curr.X) {
|
while (maxRit != m_Maxima.rend() && *maxRit > e->Curr.X) {
|
||||||
if (horzEdge->OutIdx >= 0 && !IsOpen)
|
if (horzEdge->OutIdx >= 0 && !IsOpen)
|
||||||
AddOutPt(horzEdge, IntPoint(*maxRit, horzEdge->Bot.Y));
|
AddOutPt(horzEdge, IntPoint(*maxRit, horzEdge->Bot.Y));
|
||||||
maxRit++;
|
++maxRit;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
|
@ -21,10 +21,10 @@ std::vector<std::string> OCRConfig::split(const std::string &str,
|
||||||
std::vector<std::string> res;
|
std::vector<std::string> res;
|
||||||
if ("" == str)
|
if ("" == str)
|
||||||
return res;
|
return res;
|
||||||
char *strs = new char[str.length() + 1];
|
char strs[str.length() + 1];
|
||||||
std::strcpy(strs, str.c_str());
|
std::strcpy(strs, str.c_str());
|
||||||
|
|
||||||
char *d = new char[delim.length() + 1];
|
char d[delim.length() + 1];
|
||||||
std::strcpy(d, delim.c_str());
|
std::strcpy(d, delim.c_str());
|
||||||
|
|
||||||
char *p = std::strtok(strs, d);
|
char *p = std::strtok(strs, d);
|
||||||
|
|
|
@ -147,12 +147,12 @@ python3 tools/infer/predict_det.py --image_dir="./doc/imgs/00018069.jpg" --det_m
|
||||||
|
|
||||||
如果输入图片的分辨率比较大,而且想使用更大的分辨率预测,可以设置det_limit_side_len 为想要的值,比如1216:
|
如果输入图片的分辨率比较大,而且想使用更大的分辨率预测,可以设置det_limit_side_len 为想要的值,比如1216:
|
||||||
```
|
```
|
||||||
python3 tools/infer/predict_det.py --image_dir="./doc/imgs/2.jpg" --det_model_dir="./inference/det_db/" --det_limit_type=max --det_limit_side_len=1216
|
python3 tools/infer/predict_det.py --image_dir="./doc/imgs/1.jpg" --det_model_dir="./inference/det_db/" --det_limit_type=max --det_limit_side_len=1216
|
||||||
```
|
```
|
||||||
|
|
||||||
如果想使用CPU进行预测,执行命令如下
|
如果想使用CPU进行预测,执行命令如下
|
||||||
```
|
```
|
||||||
python3 tools/infer/predict_det.py --image_dir="./doc/imgs/2.jpg" --det_model_dir="./inference/det_db/" --use_gpu=False
|
python3 tools/infer/predict_det.py --image_dir="./doc/imgs/1.jpg" --det_model_dir="./inference/det_db/" --use_gpu=False
|
||||||
```
|
```
|
||||||
|
|
||||||
<a name="DB文本检测模型推理"></a>
|
<a name="DB文本检测模型推理"></a>
|
||||||
|
|
|
@ -154,12 +154,12 @@ Set as `limit_type='min', det_limit_side_len=960`, it means that the shortest si
|
||||||
|
|
||||||
If the resolution of the input picture is relatively large and you want to use a larger resolution prediction, you can set det_limit_side_len to the desired value, such as 1216:
|
If the resolution of the input picture is relatively large and you want to use a larger resolution prediction, you can set det_limit_side_len to the desired value, such as 1216:
|
||||||
```
|
```
|
||||||
python3 tools/infer/predict_det.py --image_dir="./doc/imgs/22.jpg" --det_model_dir="./inference/det_db/" --det_limit_type=max --det_limit_side_len=1216
|
python3 tools/infer/predict_det.py --image_dir="./doc/imgs/1.jpg" --det_model_dir="./inference/det_db/" --det_limit_type=max --det_limit_side_len=1216
|
||||||
```
|
```
|
||||||
|
|
||||||
If you want to use the CPU for prediction, execute the command as follows
|
If you want to use the CPU for prediction, execute the command as follows
|
||||||
```
|
```
|
||||||
python3 tools/infer/predict_det.py --image_dir="./doc/imgs/22.jpg" --det_model_dir="./inference/det_db/" --use_gpu=False
|
python3 tools/infer/predict_det.py --image_dir="./doc/imgs/1.jpg" --det_model_dir="./inference/det_db/" --use_gpu=False
|
||||||
```
|
```
|
||||||
|
|
||||||
<a name="DB_DETECTION"></a>
|
<a name="DB_DETECTION"></a>
|
||||||
|
|
|
@ -15,8 +15,6 @@
|
||||||
- 2020.6.8 Add [datasets](./datasets_en.md) and keep updating
|
- 2020.6.8 Add [datasets](./datasets_en.md) and keep updating
|
||||||
- 2020.6.5 Support exporting `attention` model to `inference_model`
|
- 2020.6.5 Support exporting `attention` model to `inference_model`
|
||||||
- 2020.6.5 Support separate prediction and recognition, output result score
|
- 2020.6.5 Support separate prediction and recognition, output result score
|
||||||
- 2020.6.5 Support exporting `attention` model to `inference_model`
|
|
||||||
- 2020.6.5 Support separate prediction and recognition, output result score
|
|
||||||
- 2020.5.30 Provide Lightweight Chinese OCR online experience
|
- 2020.5.30 Provide Lightweight Chinese OCR online experience
|
||||||
- 2020.5.30 Model prediction and training support on Windows system
|
- 2020.5.30 Model prediction and training support on Windows system
|
||||||
- 2020.5.30 Open source general Chinese OCR model
|
- 2020.5.30 Open source general Chinese OCR model
|
||||||
|
|
BIN
doc/joinus.PNG
BIN
doc/joinus.PNG
Binary file not shown.
Before Width: | Height: | Size: 188 KiB After Width: | Height: | Size: 189 KiB |
|
@ -46,6 +46,7 @@ class SimpleDataSet(Dataset):
|
||||||
self.seed = seed
|
self.seed = seed
|
||||||
logger.info("Initialize indexs of datasets:%s" % label_file_list)
|
logger.info("Initialize indexs of datasets:%s" % label_file_list)
|
||||||
self.data_lines = self.get_image_info_list(label_file_list, ratio_list)
|
self.data_lines = self.get_image_info_list(label_file_list, ratio_list)
|
||||||
|
self.check_data()
|
||||||
self.data_idx_order_list = list(range(len(self.data_lines)))
|
self.data_idx_order_list = list(range(len(self.data_lines)))
|
||||||
if self.mode == "train" and self.do_shuffle:
|
if self.mode == "train" and self.do_shuffle:
|
||||||
self.shuffle_data_random()
|
self.shuffle_data_random()
|
||||||
|
@ -102,16 +103,8 @@ class SimpleDataSet(Dataset):
|
||||||
|
|
||||||
def __getitem__(self, idx):
|
def __getitem__(self, idx):
|
||||||
file_idx = self.data_idx_order_list[idx]
|
file_idx = self.data_idx_order_list[idx]
|
||||||
data_line = self.data_lines[file_idx]
|
data = self.data_lines[file_idx]
|
||||||
try:
|
try:
|
||||||
data_line = data_line.decode('utf-8')
|
|
||||||
substr = data_line.strip("\n").strip("\r").split(self.delimiter)
|
|
||||||
file_name = substr[0]
|
|
||||||
label = substr[1]
|
|
||||||
img_path = os.path.join(self.data_dir, file_name)
|
|
||||||
data = {'img_path': img_path, 'label': label}
|
|
||||||
if not os.path.exists(img_path):
|
|
||||||
raise Exception("{} does not exist!".format(img_path))
|
|
||||||
with open(data['img_path'], 'rb') as f:
|
with open(data['img_path'], 'rb') as f:
|
||||||
img = f.read()
|
img = f.read()
|
||||||
data['image'] = img
|
data['image'] = img
|
||||||
|
@ -120,8 +113,8 @@ class SimpleDataSet(Dataset):
|
||||||
except:
|
except:
|
||||||
error_meg = traceback.format_exc()
|
error_meg = traceback.format_exc()
|
||||||
self.logger.error(
|
self.logger.error(
|
||||||
"When parsing line {}, error happened with msg: {}".format(
|
"When parsing file {} and label {}, error happened with msg: {}".format(
|
||||||
data_line, error_meg))
|
data['img_path'],data['label'], error_meg))
|
||||||
outs = None
|
outs = None
|
||||||
if outs is None:
|
if outs is None:
|
||||||
# during evaluation, we should fix the idx to get same results for many times of evaluation.
|
# during evaluation, we should fix the idx to get same results for many times of evaluation.
|
||||||
|
@ -132,3 +125,17 @@ class SimpleDataSet(Dataset):
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
return len(self.data_idx_order_list)
|
return len(self.data_idx_order_list)
|
||||||
|
|
||||||
|
def check_data(self):
    """Parse the raw label lines and keep only the usable samples.

    Every entry of ``self.data_lines`` is decoded as UTF-8, stripped of
    surrounding newline / carriage-return characters and split on
    ``self.delimiter``. The first field is joined onto ``self.data_dir``
    to form the image path and the second field is the label.

    ``self.data_lines`` is replaced by a list of
    ``{'img_path': ..., 'label': ...}`` dicts, one per kept line. Lines
    with fewer than two fields and lines whose image file does not exist
    are logged and skipped.
    """
    new_data_lines = []
    for data_line in self.data_lines:
        data_line = data_line.decode('utf-8')
        substr = data_line.strip("\n").strip("\r").split(self.delimiter)
        if len(substr) < 2:
            # A line without the delimiter has no label; skip it instead of
            # letting substr[1] raise IndexError for the whole dataset.
            self.logger.info(
                "Skip malformed label line: {!r}".format(data_line))
            continue
        file_name = substr[0]
        label = substr[1]
        img_path = os.path.join(self.data_dir, file_name)
        if os.path.exists(img_path):
            new_data_lines.append({'img_path': img_path, 'label': label})
        else:
            self.logger.info("{} does not exist!".format(img_path))
    self.data_lines = new_data_lines
|
|
@ -54,6 +54,27 @@ class CELoss(nn.Layer):
|
||||||
return loss
|
return loss
|
||||||
|
|
||||||
|
|
||||||
|
class KLJSLoss(object):
    """Element-wise KL-style or symmetrised (JS-style) divergence.

    ``mode`` selects the variant; it must be one of ``'kl'``, ``'js'``,
    ``'KL'`` or ``'JS'`` and is compared case-insensitively when called.
    """

    def __init__(self, mode='kl'):
        allowed = ('kl', 'js', 'KL', 'JS')
        assert mode in allowed, "mode can only be one of ['kl', 'js', 'KL', 'JS']"
        self.mode = mode

    def __call__(self, p1, p2, reduction="mean"):
        """Return the divergence between ``p1`` and ``p2``.

        The base term is ``p2 * log((p2 + 1e-5) / (p1 + 1e-5) + 1e-5)``.
        In ``js`` mode the mirrored term is added and the sum is halved.

        ``reduction`` is ``"mean"`` (mean over axes 1 and 2), ``"none"`` or
        ``None`` (unreduced tensor); any other value sums over axes 1 and 2.
        """
        eps = 1e-5
        loss = paddle.multiply(p2, paddle.log((p2 + eps) / (p1 + eps) + eps))

        if self.mode.lower() == "js":
            loss += paddle.multiply(
                p1, paddle.log((p1 + eps) / (p2 + eps) + eps))
            loss *= 0.5

        if reduction == "none" or reduction is None:
            return loss
        if reduction == "mean":
            return paddle.mean(loss, axis=[1, 2])
        return paddle.sum(loss, axis=[1, 2])
|
||||||
|
|
||||||
class DMLLoss(nn.Layer):
|
class DMLLoss(nn.Layer):
|
||||||
"""
|
"""
|
||||||
DMLLoss
|
DMLLoss
|
||||||
|
@ -70,16 +91,20 @@ class DMLLoss(nn.Layer):
|
||||||
else:
|
else:
|
||||||
self.act = None
|
self.act = None
|
||||||
|
|
||||||
|
self.jskl_loss = KLJSLoss(mode="js")
|
||||||
|
|
||||||
def forward(self, out1, out2):
|
def forward(self, out1, out2):
|
||||||
if self.act is not None:
|
if self.act is not None:
|
||||||
out1 = self.act(out1)
|
out1 = self.act(out1)
|
||||||
out2 = self.act(out2)
|
out2 = self.act(out2)
|
||||||
|
if len(out1.shape) < 2:
|
||||||
log_out1 = paddle.log(out1)
|
log_out1 = paddle.log(out1)
|
||||||
log_out2 = paddle.log(out2)
|
log_out2 = paddle.log(out2)
|
||||||
loss = (F.kl_div(
|
loss = (F.kl_div(
|
||||||
log_out1, out2, reduction='batchmean') + F.kl_div(
|
log_out1, out2, reduction='batchmean') + F.kl_div(
|
||||||
log_out2, out1, reduction='batchmean')) / 2.0
|
log_out2, out1, reduction='batchmean')) / 2.0
|
||||||
|
else:
|
||||||
|
loss = self.jskl_loss(out1, out2)
|
||||||
return loss
|
return loss
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -17,7 +17,7 @@ import paddle.nn as nn
|
||||||
|
|
||||||
from .distillation_loss import DistillationCTCLoss
|
from .distillation_loss import DistillationCTCLoss
|
||||||
from .distillation_loss import DistillationDMLLoss
|
from .distillation_loss import DistillationDMLLoss
|
||||||
from .distillation_loss import DistillationDistanceLoss
|
from .distillation_loss import DistillationDistanceLoss, DistillationDBLoss, DistillationDilaDBLoss
|
||||||
|
|
||||||
|
|
||||||
class CombinedLoss(nn.Layer):
|
class CombinedLoss(nn.Layer):
|
||||||
|
@ -44,15 +44,16 @@ class CombinedLoss(nn.Layer):
|
||||||
|
|
||||||
def forward(self, input, batch, **kargs):
|
def forward(self, input, batch, **kargs):
|
||||||
loss_dict = {}
|
loss_dict = {}
|
||||||
|
loss_all = 0.
|
||||||
for idx, loss_func in enumerate(self.loss_func):
|
for idx, loss_func in enumerate(self.loss_func):
|
||||||
loss = loss_func(input, batch, **kargs)
|
loss = loss_func(input, batch, **kargs)
|
||||||
if isinstance(loss, paddle.Tensor):
|
if isinstance(loss, paddle.Tensor):
|
||||||
loss = {"loss_{}_{}".format(str(loss), idx): loss}
|
loss = {"loss_{}_{}".format(str(loss), idx): loss}
|
||||||
weight = self.loss_weight[idx]
|
weight = self.loss_weight[idx]
|
||||||
loss = {
|
for key in loss.keys():
|
||||||
"{}_{}".format(key, idx): loss[key] * weight
|
if key == "loss":
|
||||||
for key in loss
|
loss_all += loss[key] * weight
|
||||||
}
|
else:
|
||||||
loss_dict.update(loss)
|
loss_dict["{}_{}".format(key, idx)] = loss[key]
|
||||||
loss_dict["loss"] = paddle.add_n(list(loss_dict.values()))
|
loss_dict["loss"] = loss_all
|
||||||
return loss_dict
|
return loss_dict
|
||||||
|
|
|
@ -14,23 +14,76 @@
|
||||||
|
|
||||||
import paddle
|
import paddle
|
||||||
import paddle.nn as nn
|
import paddle.nn as nn
|
||||||
|
import numpy as np
|
||||||
|
import cv2
|
||||||
|
|
||||||
from .rec_ctc_loss import CTCLoss
|
from .rec_ctc_loss import CTCLoss
|
||||||
from .basic_loss import DMLLoss
|
from .basic_loss import DMLLoss
|
||||||
from .basic_loss import DistanceLoss
|
from .basic_loss import DistanceLoss
|
||||||
|
from .det_db_loss import DBLoss
|
||||||
|
from .det_basic_loss import BalanceLoss, MaskL1Loss, DiceLoss
|
||||||
|
|
||||||
|
|
||||||
|
def _sum_loss(loss_dict):
|
||||||
|
if "loss" in loss_dict.keys():
|
||||||
|
return loss_dict
|
||||||
|
else:
|
||||||
|
loss_dict["loss"] = 0.
|
||||||
|
for k, value in loss_dict.items():
|
||||||
|
if k == "loss":
|
||||||
|
continue
|
||||||
|
else:
|
||||||
|
loss_dict["loss"] += value
|
||||||
|
return loss_dict
|
||||||
|
|
||||||
|
|
||||||
class DistillationDMLLoss(DMLLoss):
|
class DistillationDMLLoss(DMLLoss):
|
||||||
"""
|
"""
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, model_name_pairs=[], act=None, key=None,
|
def __init__(self,
|
||||||
name="loss_dml"):
|
model_name_pairs=[],
|
||||||
|
act=None,
|
||||||
|
key=None,
|
||||||
|
maps_name=None,
|
||||||
|
name="dml"):
|
||||||
super().__init__(act=act)
|
super().__init__(act=act)
|
||||||
assert isinstance(model_name_pairs, list)
|
assert isinstance(model_name_pairs, list)
|
||||||
self.key = key
|
self.key = key
|
||||||
self.model_name_pairs = model_name_pairs
|
self.model_name_pairs = self._check_model_name_pairs(model_name_pairs)
|
||||||
self.name = name
|
self.name = name
|
||||||
|
self.maps_name = self._check_maps_name(maps_name)
|
||||||
|
|
||||||
|
def _check_model_name_pairs(self, model_name_pairs):
|
||||||
|
if not isinstance(model_name_pairs, list):
|
||||||
|
return []
|
||||||
|
elif isinstance(model_name_pairs[0], list) and isinstance(model_name_pairs[0][0], str):
|
||||||
|
return model_name_pairs
|
||||||
|
else:
|
||||||
|
return [model_name_pairs]
|
||||||
|
|
||||||
|
def _check_maps_name(self, maps_name):
|
||||||
|
if maps_name is None:
|
||||||
|
return None
|
||||||
|
elif type(maps_name) == str:
|
||||||
|
return [maps_name]
|
||||||
|
elif type(maps_name) == list:
|
||||||
|
return [maps_name]
|
||||||
|
else:
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _slice_out(self, outs):
|
||||||
|
new_outs = {}
|
||||||
|
for k in self.maps_name:
|
||||||
|
if k == "thrink_maps":
|
||||||
|
new_outs[k] = outs[:, 0, :, :]
|
||||||
|
elif k == "threshold_maps":
|
||||||
|
new_outs[k] = outs[:, 1, :, :]
|
||||||
|
elif k == "binary_maps":
|
||||||
|
new_outs[k] = outs[:, 2, :, :]
|
||||||
|
else:
|
||||||
|
continue
|
||||||
|
return new_outs
|
||||||
|
|
||||||
def forward(self, predicts, batch):
|
def forward(self, predicts, batch):
|
||||||
loss_dict = dict()
|
loss_dict = dict()
|
||||||
|
@ -40,6 +93,8 @@ class DistillationDMLLoss(DMLLoss):
|
||||||
if self.key is not None:
|
if self.key is not None:
|
||||||
out1 = out1[self.key]
|
out1 = out1[self.key]
|
||||||
out2 = out2[self.key]
|
out2 = out2[self.key]
|
||||||
|
|
||||||
|
if self.maps_name is None:
|
||||||
loss = super().forward(out1, out2)
|
loss = super().forward(out1, out2)
|
||||||
if isinstance(loss, dict):
|
if isinstance(loss, dict):
|
||||||
for key in loss:
|
for key in loss:
|
||||||
|
@ -47,6 +102,21 @@ class DistillationDMLLoss(DMLLoss):
|
||||||
idx)] = loss[key]
|
idx)] = loss[key]
|
||||||
else:
|
else:
|
||||||
loss_dict["{}_{}".format(self.name, idx)] = loss
|
loss_dict["{}_{}".format(self.name, idx)] = loss
|
||||||
|
else:
|
||||||
|
outs1 = self._slice_out(out1)
|
||||||
|
outs2 = self._slice_out(out2)
|
||||||
|
for _c, k in enumerate(outs1.keys()):
|
||||||
|
loss = super().forward(outs1[k], outs2[k])
|
||||||
|
if isinstance(loss, dict):
|
||||||
|
for key in loss:
|
||||||
|
loss_dict["{}_{}_{}_{}_{}".format(key, pair[
|
||||||
|
0], pair[1], map_name, idx)] = loss[key]
|
||||||
|
else:
|
||||||
|
loss_dict["{}_{}_{}".format(self.name, self.maps_name[_c],
|
||||||
|
idx)] = loss
|
||||||
|
|
||||||
|
loss_dict = _sum_loss(loss_dict)
|
||||||
|
|
||||||
return loss_dict
|
return loss_dict
|
||||||
|
|
||||||
|
|
||||||
|
@ -73,6 +143,98 @@ class DistillationCTCLoss(CTCLoss):
|
||||||
return loss_dict
|
return loss_dict
|
||||||
|
|
||||||
|
|
||||||
|
class DistillationDBLoss(DBLoss):
    """Apply the inherited DB loss to the output of each named sub-model.

    ``forward`` evaluates ``DBLoss.forward`` once per entry of
    ``model_name_list`` and collects the results under keys prefixed with
    ``name`` and the model name.
    """

    def __init__(self,
                 model_name_list=[],
                 balance_loss=True,
                 main_loss_type='DiceLoss',
                 alpha=5,
                 beta=10,
                 ohem_ratio=3,
                 eps=1e-6,
                 name="db",
                 **kwargs):
        # NOTE(review): balance_loss, main_loss_type, alpha, beta, ohem_ratio
        # and eps are accepted here but not forwarded to DBLoss.__init__, so
        # the base-class defaults are always used. Confirm this is intended.
        super().__init__()
        self.model_name_list = model_name_list
        self.name = name
        self.key = None

    def forward(self, predicts, batch):
        """Return a dict of per-model loss terms plus their total as "loss".

        For a dict result every entry except ``"loss"`` is stored as
        ``"<name>_<model>_<term>"``; any other result is stored whole as
        ``"<name>_<model>"``. ``_sum_loss`` then adds the ``"loss"`` total.
        """
        collected = {}
        for model_name in self.model_name_list:
            out = predicts[model_name]
            if self.key is not None:
                out = out[self.key]
            loss = super().forward(out, batch)

            if not isinstance(loss, dict):
                collected["{}_{}".format(self.name, model_name)] = loss
                continue
            for term, value in loss.items():
                if term != "loss":
                    collected["{}_{}_{}".format(self.name, model_name,
                                                term)] = value

        return _sum_loss(collected)
|
||||||
|
|
||||||
|
|
||||||
|
class DistillationDilaDBLoss(DBLoss):
    """DB distillation loss against a dilated, binarised teacher map.

    For every ``(student, teacher)`` name pair, the teacher's channel-0 map
    is thresholded at 0.3 and dilated with a 2x2 kernel. The student's
    channel-0 and channel-2 maps are then compared with it using the
    inherited ``bce_loss`` (scaled by ``alpha``) and ``dice_loss``.
    """

    def __init__(self,
                 model_name_pairs=[],
                 key=None,
                 balance_loss=True,
                 main_loss_type='DiceLoss',
                 alpha=5,
                 beta=10,
                 ohem_ratio=3,
                 eps=1e-6,
                 name="dila_dbloss"):
        # NOTE(review): balance_loss, main_loss_type, alpha, beta, ohem_ratio
        # and eps are accepted here but not forwarded to DBLoss.__init__, so
        # the base-class defaults are always used. Confirm this is intended.
        super().__init__()
        self.model_name_pairs = model_name_pairs
        self.name = name
        self.key = key

    def forward(self, predicts, batch):
        """Return per-pair losses plus their total under ``"loss"``.

        Args:
            predicts: mapping from model name to that model's output; when
                ``self.key`` is set, each output is indexed with it first.
            batch: sequence whose items 1..4 are unpacked as threshold map,
                threshold mask, shrink map and shrink mask; only the shrink
                mask is used here.

        Returns:
            Dict with one ``"<name>_<student>_<teacher>"`` entry per pair
            and the ``"loss"`` total added by ``_sum_loss``.
        """
        loss_dict = dict()
        # Structuring element for the dilation; identical for every pair.
        dilation_w = np.array([[1, 1], [1, 1]])

        for pair in self.model_name_pairs:
            # Fall back to the raw outputs when no key is configured, so the
            # default key=None no longer hits an unbound local variable.
            stu_preds = predicts[pair[0]]
            tch_preds = predicts[pair[1]]
            if self.key is not None:
                stu_preds = stu_preds[self.key]
                tch_preds = tch_preds[self.key]

            stu_shrink_maps = stu_preds[:, 0, :, :]
            stu_binary_maps = stu_preds[:, 2, :, :]

            # dilation to teacher prediction
            th_shrink_maps = tch_preds[:, 0, :, :]
            th_shrink_maps = th_shrink_maps.numpy() > 0.3  # thresh = 0.3
            dilate_maps = np.zeros_like(th_shrink_maps).astype(np.float32)
            for i in range(th_shrink_maps.shape[0]):
                dilate_maps[i] = cv2.dilate(
                    th_shrink_maps[i, :, :].astype(np.uint8), dilation_w)
            th_shrink_maps = paddle.to_tensor(dilate_maps)

            # Unpacking four names keeps the original check on batch length.
            _, _, _, label_shrink_mask = batch[1:]

            # calculate the shrink map loss
            bce_loss = self.alpha * self.bce_loss(
                stu_shrink_maps, th_shrink_maps, label_shrink_mask)
            loss_binary_maps = self.dice_loss(stu_binary_maps, th_shrink_maps,
                                              label_shrink_mask)

            k = "{}_{}_{}".format(self.name, pair[0], pair[1])
            loss_dict[k] = bce_loss + loss_binary_maps

        loss_dict = _sum_loss(loss_dict)
        return loss_dict
|
||||||
|
|
||||||
|
|
||||||
class DistillationDistanceLoss(DistanceLoss):
|
class DistillationDistanceLoss(DistanceLoss):
|
||||||
"""
|
"""
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -55,6 +55,7 @@ class DetMetric(object):
|
||||||
result = self.evaluator.evaluate_image(gt_info_list, det_info_list)
|
result = self.evaluator.evaluate_image(gt_info_list, det_info_list)
|
||||||
self.results.append(result)
|
self.results.append(result)
|
||||||
|
|
||||||
|
|
||||||
def get_metric(self):
|
def get_metric(self):
|
||||||
"""
|
"""
|
||||||
return metrics {
|
return metrics {
|
||||||
|
|
|
@ -24,8 +24,8 @@ from .cls_metric import ClsMetric
|
||||||
class DistillationMetric(object):
|
class DistillationMetric(object):
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
key=None,
|
key=None,
|
||||||
base_metric_name="RecMetric",
|
base_metric_name=None,
|
||||||
main_indicator='acc',
|
main_indicator=None,
|
||||||
**kwargs):
|
**kwargs):
|
||||||
self.main_indicator = main_indicator
|
self.main_indicator = main_indicator
|
||||||
self.key = key
|
self.key = key
|
||||||
|
@ -42,16 +42,13 @@ class DistillationMetric(object):
|
||||||
main_indicator=self.main_indicator, **self.kwargs)
|
main_indicator=self.main_indicator, **self.kwargs)
|
||||||
self.metrics[key].reset()
|
self.metrics[key].reset()
|
||||||
|
|
||||||
def __call__(self, preds, *args, **kwargs):
|
def __call__(self, preds, batch, **kwargs):
|
||||||
assert isinstance(preds, dict)
|
assert isinstance(preds, dict)
|
||||||
if self.metrics is None:
|
if self.metrics is None:
|
||||||
self._init_metrcis(preds)
|
self._init_metrcis(preds)
|
||||||
output = dict()
|
output = dict()
|
||||||
for key in preds:
|
for key in preds:
|
||||||
metric = self.metrics[key].__call__(preds[key], *args, **kwargs)
|
self.metrics[key].__call__(preds[key], batch, **kwargs)
|
||||||
for sub_key in metric:
|
|
||||||
output["{}_{}".format(key, sub_key)] = metric[sub_key]
|
|
||||||
return output
|
|
||||||
|
|
||||||
def get_metric(self):
|
def get_metric(self):
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -79,6 +79,9 @@ class BaseModel(nn.Layer):
|
||||||
x = self.neck(x)
|
x = self.neck(x)
|
||||||
y["neck_out"] = x
|
y["neck_out"] = x
|
||||||
x = self.head(x, targets=data)
|
x = self.head(x, targets=data)
|
||||||
|
if isinstance(x, dict):
|
||||||
|
y.update(x)
|
||||||
|
else:
|
||||||
y["head_out"] = x
|
y["head_out"] = x
|
||||||
if self.return_all_feats:
|
if self.return_all_feats:
|
||||||
return y
|
return y
|
||||||
|
|
|
@ -21,7 +21,7 @@ from ppocr.modeling.backbones import build_backbone
|
||||||
from ppocr.modeling.necks import build_neck
|
from ppocr.modeling.necks import build_neck
|
||||||
from ppocr.modeling.heads import build_head
|
from ppocr.modeling.heads import build_head
|
||||||
from .base_model import BaseModel
|
from .base_model import BaseModel
|
||||||
from ppocr.utils.save_load import init_model
|
from ppocr.utils.save_load import init_model, load_pretrained_params
|
||||||
|
|
||||||
__all__ = ['DistillationModel']
|
__all__ = ['DistillationModel']
|
||||||
|
|
||||||
|
@ -46,7 +46,7 @@ class DistillationModel(nn.Layer):
|
||||||
pretrained = model_config.pop("pretrained")
|
pretrained = model_config.pop("pretrained")
|
||||||
model = BaseModel(model_config)
|
model = BaseModel(model_config)
|
||||||
if pretrained is not None:
|
if pretrained is not None:
|
||||||
init_model(model, path=pretrained)
|
model = load_pretrained_params(model, pretrained)
|
||||||
if freeze_params:
|
if freeze_params:
|
||||||
for param in model.parameters():
|
for param in model.parameters():
|
||||||
param.trainable = False
|
param.trainable = False
|
||||||
|
|
|
@ -21,7 +21,7 @@ import copy
|
||||||
|
|
||||||
__all__ = ['build_post_process']
|
__all__ = ['build_post_process']
|
||||||
|
|
||||||
from .db_postprocess import DBPostProcess
|
from .db_postprocess import DBPostProcess, DistillationDBPostProcess
|
||||||
from .east_postprocess import EASTPostProcess
|
from .east_postprocess import EASTPostProcess
|
||||||
from .sast_postprocess import SASTPostProcess
|
from .sast_postprocess import SASTPostProcess
|
||||||
from .rec_postprocess import CTCLabelDecode, AttnLabelDecode, SRNLabelDecode, DistillationCTCLabelDecode, \
|
from .rec_postprocess import CTCLabelDecode, AttnLabelDecode, SRNLabelDecode, DistillationCTCLabelDecode, \
|
||||||
|
@ -34,7 +34,8 @@ def build_post_process(config, global_config=None):
|
||||||
support_dict = [
|
support_dict = [
|
||||||
'DBPostProcess', 'EASTPostProcess', 'SASTPostProcess', 'CTCLabelDecode',
|
'DBPostProcess', 'EASTPostProcess', 'SASTPostProcess', 'CTCLabelDecode',
|
||||||
'AttnLabelDecode', 'ClsPostProcess', 'SRNLabelDecode', 'PGPostProcess',
|
'AttnLabelDecode', 'ClsPostProcess', 'SRNLabelDecode', 'PGPostProcess',
|
||||||
'DistillationCTCLabelDecode', 'TableLabelDecode'
|
'DistillationCTCLabelDecode', 'TableLabelDecode',
|
||||||
|
'DistillationDBPostProcess'
|
||||||
]
|
]
|
||||||
|
|
||||||
config = copy.deepcopy(config)
|
config = copy.deepcopy(config)
|
||||||
|
|
|
@ -187,3 +187,29 @@ class DBPostProcess(object):
|
||||||
|
|
||||||
boxes_batch.append({'points': boxes})
|
boxes_batch.append({'points': boxes})
|
||||||
return boxes_batch
|
return boxes_batch
|
||||||
|
|
||||||
|
|
||||||
|
class DistillationDBPostProcess(object):
    """Run one shared ``DBPostProcess`` over several sub-model outputs.

    ``model_name`` lists the keys of ``predicts`` to post-process. The
    remaining named options are passed straight to ``DBPostProcess``.
    ``key`` is stored on the instance but not read anywhere in this class.
    """

    def __init__(self, model_name=["student"],
                 key=None,
                 thresh=0.3,
                 box_thresh=0.6,
                 max_candidates=1000,
                 unclip_ratio=1.5,
                 use_dilation=False,
                 score_mode="fast",
                 **kwargs):
        self.model_name = model_name
        self.key = key
        db_options = dict(
            thresh=thresh,
            box_thresh=box_thresh,
            max_candidates=max_candidates,
            unclip_ratio=unclip_ratio,
            use_dilation=use_dilation,
            score_mode=score_mode)
        self.post_process = DBPostProcess(**db_options)

    def __call__(self, predicts, shape_list):
        """Return ``{model name: DBPostProcess result}`` for each listed model."""
        return {
            name: self.post_process(predicts[name], shape_list=shape_list)
            for name in self.model_name
        }
|
||||||
|
|
|
@ -116,6 +116,27 @@ def load_dygraph_params(config, model, logger, optimizer):
|
||||||
logger.info(f"loaded pretrained_model successful from {pm}")
|
logger.info(f"loaded pretrained_model successful from {pm}")
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
|
def load_pretrained_params(model, path):
    """Load shape-compatible pretrained weights into ``model``.

    Args:
        model: object exposing ``state_dict()`` and ``set_state_dict()``.
        path: weights file, with or without the ``.pdparams`` suffix, or
            ``None``.

    Returns:
        ``model`` in every case. Callers rebind their model to the return
        value, so when ``path`` is ``None`` or the weights file is missing
        the model is returned unchanged (after a printed message) rather
        than a boolean.
    """
    if path is None:
        return model

    weights_path = path if path.endswith('.pdparams') else path + '.pdparams'
    # Check the file that is actually loaded, not the un-suffixed path.
    if not os.path.exists(weights_path):
        print(f"The pretrained_model {path} does not exists!")
        return model

    params = paddle.load(weights_path)
    state_dict = model.state_dict()
    new_state_dict = {}
    # NOTE(review): parameters are paired by position, not by name; extra
    # entries on either side are silently dropped by zip. Confirm that the
    # checkpoint and the model list their parameters in the same order.
    for k1, k2 in zip(state_dict.keys(), params.keys()):
        if list(state_dict[k1].shape) == list(params[k2].shape):
            new_state_dict[k1] = params[k2]
        else:
            print(
                f"The shape of model params {k1} {state_dict[k1].shape} not matched with loaded params {k2} {params[k2].shape} !"
            )
    model.set_state_dict(new_state_dict)
    print(f"load pretrain successful from {weights_path}")
    return model
|
||||||
|
|
||||||
def save_model(model,
|
def save_model(model,
|
||||||
optimizer,
|
optimizer,
|
||||||
|
|
|
@ -0,0 +1,35 @@
|
||||||
|
model_name:ocr_rec
|
||||||
|
python:python
|
||||||
|
gpu_list:0|0,1
|
||||||
|
Global.auto_cast:null
|
||||||
|
Global.epoch_num:10
|
||||||
|
Global.save_model_dir:./output/
|
||||||
|
Train.loader.batch_size_per_card:
|
||||||
|
Global.use_gpu:
|
||||||
|
Global.pretrained_model:null
|
||||||
|
|
||||||
|
trainer:norm|pact
|
||||||
|
norm_train:tools/train.py -c configs/rec/rec_mv3_none_bilstm_ctc.yml
|
||||||
|
quant_train:deploy/slim/quantization/quant.py -c configs/rec/rec_mv3_none_bilstm_ctc.yml
|
||||||
|
fpgm_train:null
|
||||||
|
distill_train:null
|
||||||
|
|
||||||
|
eval:tools/eval.py -c configs/rec/rec_mv3_none_bilstm_ctc.yml -o
|
||||||
|
|
||||||
|
Global.save_inference_dir:./output/
|
||||||
|
Global.pretrained_model:
|
||||||
|
norm_export:tools/export_model.py -c configs/rec/rec_mv3_none_bilstm_ctc.yml -o
|
||||||
|
quant_export:deploy/slim/quantization/export_model.py -c configs/rec/rec_mv3_none_bilstm_ctc.yml -o
|
||||||
|
fpgm_export:null
|
||||||
|
distill_export:null
|
||||||
|
|
||||||
|
inference:tools/infer/predict_rec.py
|
||||||
|
--use_gpu:True|False
|
||||||
|
--enable_mkldnn:True|False
|
||||||
|
--cpu_threads:1|6
|
||||||
|
--rec_batch_num:1
|
||||||
|
--use_tensorrt:True|False
|
||||||
|
--precision:fp32|fp16|int8
|
||||||
|
--rec_model_dir:./inference/ch_ppocr_mobile_v2.0_rec_infer/
|
||||||
|
--image_dir:./inference/rec_inference
|
||||||
|
--save_log_path:./test/output/
|
|
@ -29,19 +29,21 @@ train_model_list=$(func_parser_value "${lines[0]}")
|
||||||
|
|
||||||
trainer_list=$(func_parser_value "${lines[10]}")
|
trainer_list=$(func_parser_value "${lines[10]}")
|
||||||
|
|
||||||
|
|
||||||
# MODE be one of ['lite_train_infer' 'whole_infer' 'whole_train_infer']
|
# MODE be one of ['lite_train_infer' 'whole_infer' 'whole_train_infer']
|
||||||
MODE=$2
|
MODE=$2
|
||||||
# prepare pretrained weights and dataset
|
# prepare pretrained weights and dataset
|
||||||
wget -nc -P ./pretrain_models/ https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/MobileNetV3_large_x0_5_pretrained.pdparams
|
if [ ${train_model_list[*]} = "ocr_det" ]; then
|
||||||
wget -nc -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_mv3_db_v2.0_train.tar
|
wget -nc -P ./pretrain_models/ https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/MobileNetV3_large_x0_5_pretrained.pdparams
|
||||||
cd pretrain_models && tar xf det_mv3_db_v2.0_train.tar && cd ../
|
wget -nc -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_mv3_db_v2.0_train.tar
|
||||||
|
cd pretrain_models && tar xf det_mv3_db_v2.0_train.tar && cd ../
|
||||||
|
fi
|
||||||
if [ ${MODE} = "lite_train_infer" ];then
|
if [ ${MODE} = "lite_train_infer" ];then
|
||||||
# pretrain lite train data
|
# pretrain lite train data
|
||||||
rm -rf ./train_data/icdar2015
|
rm -rf ./train_data/icdar2015
|
||||||
wget -nc -P ./train_data/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/icdar2015_lite.tar
|
wget -nc -P ./train_data/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/icdar2015_lite.tar
|
||||||
cd ./train_data/ && tar xf icdar2015_lite.tar
|
wget -nc -P ./train_data/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/ic15_data.tar # todo change to bcebos
|
||||||
|
|
||||||
|
cd ./train_data/ && tar xf icdar2015_lite.tar && tar xf ic15_data.tar
|
||||||
ln -s ./icdar2015_lite ./icdar2015
|
ln -s ./icdar2015_lite ./icdar2015
|
||||||
cd ../
|
cd ../
|
||||||
epoch=10
|
epoch=10
|
||||||
|
@ -49,13 +51,15 @@ if [ ${MODE} = "lite_train_infer" ];then
|
||||||
elif [ ${MODE} = "whole_train_infer" ];then
|
elif [ ${MODE} = "whole_train_infer" ];then
|
||||||
rm -rf ./train_data/icdar2015
|
rm -rf ./train_data/icdar2015
|
||||||
wget -nc -P ./train_data/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/icdar2015.tar
|
wget -nc -P ./train_data/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/icdar2015.tar
|
||||||
cd ./train_data/ && tar xf icdar2015.tar && cd ../
|
wget -nc -P ./train_data/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/ic15_data.tar
|
||||||
|
cd ./train_data/ && tar xf icdar2015.tar && tar xf ic15_data.tar && cd ../
|
||||||
epoch=500
|
epoch=500
|
||||||
eval_batch_step=200
|
eval_batch_step=200
|
||||||
elif [ ${MODE} = "whole_infer" ];then
|
elif [ ${MODE} = "whole_infer" ];then
|
||||||
rm -rf ./train_data/icdar2015
|
rm -rf ./train_data/icdar2015
|
||||||
wget -nc -P ./train_data/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/icdar2015_infer.tar
|
wget -nc -P ./train_data/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/icdar2015_infer.tar
|
||||||
cd ./train_data/ && tar xf icdar2015_infer.tar
|
wget -nc -P ./train_data/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/ic15_data.tar
|
||||||
|
cd ./train_data/ && tar xf icdar2015_infer.tar && tar xf ic15_data.tar
|
||||||
ln -s ./icdar2015_infer ./icdar2015
|
ln -s ./icdar2015_infer ./icdar2015
|
||||||
cd ../
|
cd ../
|
||||||
epoch=10
|
epoch=10
|
||||||
|
@ -88,9 +92,11 @@ for train_model in ${train_model_list[*]}; do
|
||||||
elif [ ${train_model} = "ocr_rec" ];then
|
elif [ ${train_model} = "ocr_rec" ];then
|
||||||
model_name="ocr_rec"
|
model_name="ocr_rec"
|
||||||
yml_file="configs/rec/rec_mv3_none_bilstm_ctc.yml"
|
yml_file="configs/rec/rec_mv3_none_bilstm_ctc.yml"
|
||||||
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/ch_rec_data_200.tar
|
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/rec_inference.tar
|
||||||
cd ./inference && tar xf ch_rec_data_200.tar && cd ../
|
cd ./inference && tar xf rec_inference.tar && cd ../
|
||||||
img_dir="./inference/ch_rec_data_200/"
|
img_dir="./inference/rec_inference/"
|
||||||
|
data_dir=./inference/rec_inference
|
||||||
|
data_label_file=[./inference/rec_inference/rec_gt_test.txt]
|
||||||
fi
|
fi
|
||||||
|
|
||||||
# eval
|
# eval
|
||||||
|
|
|
@ -27,7 +27,7 @@ from ppocr.data import build_dataloader
|
||||||
from ppocr.modeling.architectures import build_model
|
from ppocr.modeling.architectures import build_model
|
||||||
from ppocr.postprocess import build_post_process
|
from ppocr.postprocess import build_post_process
|
||||||
from ppocr.metrics import build_metric
|
from ppocr.metrics import build_metric
|
||||||
from ppocr.utils.save_load import init_model
|
from ppocr.utils.save_load import init_model, load_pretrained_params
|
||||||
from ppocr.utils.utility import print_dict
|
from ppocr.utils.utility import print_dict
|
||||||
import tools.program as program
|
import tools.program as program
|
||||||
|
|
||||||
|
@ -55,7 +55,10 @@ def main():
|
||||||
|
|
||||||
model = build_model(config['Architecture'])
|
model = build_model(config['Architecture'])
|
||||||
use_srn = config['Architecture']['algorithm'] == "SRN"
|
use_srn = config['Architecture']['algorithm'] == "SRN"
|
||||||
|
if "model_type" in config['Architecture'].keys():
|
||||||
model_type = config['Architecture']['model_type']
|
model_type = config['Architecture']['model_type']
|
||||||
|
else:
|
||||||
|
model_type = None
|
||||||
|
|
||||||
best_model_dict = init_model(config, model)
|
best_model_dict = init_model(config, model)
|
||||||
if len(best_model_dict):
|
if len(best_model_dict):
|
||||||
|
|
|
@ -112,7 +112,6 @@ class TextClassifier(object):
|
||||||
if '180' in label and score > self.cls_thresh:
|
if '180' in label and score > self.cls_thresh:
|
||||||
img_list[indices[beg_img_no + rno]] = cv2.rotate(
|
img_list[indices[beg_img_no + rno]] = cv2.rotate(
|
||||||
img_list[indices[beg_img_no + rno]], 1)
|
img_list[indices[beg_img_no + rno]], 1)
|
||||||
elapse = time.time() - starttime
|
|
||||||
return img_list, cls_res, elapse
|
return img_list, cls_res, elapse
|
||||||
|
|
||||||
|
|
||||||
|
@ -146,7 +145,6 @@ def main(args):
|
||||||
cls_res[ino]))
|
cls_res[ino]))
|
||||||
logger.info(
|
logger.info(
|
||||||
"The predict time about text angle classify module is as follows: ")
|
"The predict time about text angle classify module is as follows: ")
|
||||||
text_classifier.cls_times.info(average=False)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
|
@ -64,6 +64,24 @@ class TextRecognizer(object):
|
||||||
self.postprocess_op = build_post_process(postprocess_params)
|
self.postprocess_op = build_post_process(postprocess_params)
|
||||||
self.predictor, self.input_tensor, self.output_tensors, self.config = \
|
self.predictor, self.input_tensor, self.output_tensors, self.config = \
|
||||||
utility.create_predictor(args, 'rec', logger)
|
utility.create_predictor(args, 'rec', logger)
|
||||||
|
self.benchmark = args.benchmark
|
||||||
|
if args.benchmark:
|
||||||
|
import auto_log
|
||||||
|
pid = os.getpid()
|
||||||
|
self.autolog = auto_log.AutoLogger(
|
||||||
|
model_name="rec",
|
||||||
|
model_precision=args.precision,
|
||||||
|
batch_size=args.rec_batch_num,
|
||||||
|
data_shape="dynamic",
|
||||||
|
save_path=args.save_log_path,
|
||||||
|
inference_config=self.config,
|
||||||
|
pids=pid,
|
||||||
|
process_name=None,
|
||||||
|
gpu_ids=0 if args.use_gpu else None,
|
||||||
|
time_keys=[
|
||||||
|
'preprocess_time', 'inference_time', 'postprocess_time'
|
||||||
|
],
|
||||||
|
warmup=10)
|
||||||
|
|
||||||
def resize_norm_img(self, img, max_wh_ratio):
|
def resize_norm_img(self, img, max_wh_ratio):
|
||||||
imgC, imgH, imgW = self.rec_image_shape
|
imgC, imgH, imgW = self.rec_image_shape
|
||||||
|
@ -168,6 +186,8 @@ class TextRecognizer(object):
|
||||||
rec_res = [['', 0.0]] * img_num
|
rec_res = [['', 0.0]] * img_num
|
||||||
batch_num = self.rec_batch_num
|
batch_num = self.rec_batch_num
|
||||||
st = time.time()
|
st = time.time()
|
||||||
|
if self.benchmark:
|
||||||
|
self.autolog.times.start()
|
||||||
for beg_img_no in range(0, img_num, batch_num):
|
for beg_img_no in range(0, img_num, batch_num):
|
||||||
end_img_no = min(img_num, beg_img_no + batch_num)
|
end_img_no = min(img_num, beg_img_no + batch_num)
|
||||||
norm_img_batch = []
|
norm_img_batch = []
|
||||||
|
@ -196,6 +216,8 @@ class TextRecognizer(object):
|
||||||
norm_img_batch.append(norm_img[0])
|
norm_img_batch.append(norm_img[0])
|
||||||
norm_img_batch = np.concatenate(norm_img_batch)
|
norm_img_batch = np.concatenate(norm_img_batch)
|
||||||
norm_img_batch = norm_img_batch.copy()
|
norm_img_batch = norm_img_batch.copy()
|
||||||
|
if self.benchmark:
|
||||||
|
self.autolog.times.stamp()
|
||||||
|
|
||||||
if self.rec_algorithm == "SRN":
|
if self.rec_algorithm == "SRN":
|
||||||
encoder_word_pos_list = np.concatenate(encoder_word_pos_list)
|
encoder_word_pos_list = np.concatenate(encoder_word_pos_list)
|
||||||
|
@ -222,6 +244,8 @@ class TextRecognizer(object):
|
||||||
for output_tensor in self.output_tensors:
|
for output_tensor in self.output_tensors:
|
||||||
output = output_tensor.copy_to_cpu()
|
output = output_tensor.copy_to_cpu()
|
||||||
outputs.append(output)
|
outputs.append(output)
|
||||||
|
if self.benchmark:
|
||||||
|
self.autolog.times.stamp()
|
||||||
preds = {"predict": outputs[2]}
|
preds = {"predict": outputs[2]}
|
||||||
else:
|
else:
|
||||||
self.input_tensor.copy_from_cpu(norm_img_batch)
|
self.input_tensor.copy_from_cpu(norm_img_batch)
|
||||||
|
@ -231,11 +255,14 @@ class TextRecognizer(object):
|
||||||
for output_tensor in self.output_tensors:
|
for output_tensor in self.output_tensors:
|
||||||
output = output_tensor.copy_to_cpu()
|
output = output_tensor.copy_to_cpu()
|
||||||
outputs.append(output)
|
outputs.append(output)
|
||||||
|
if self.benchmark:
|
||||||
|
self.autolog.times.stamp()
|
||||||
preds = outputs[0]
|
preds = outputs[0]
|
||||||
rec_result = self.postprocess_op(preds)
|
rec_result = self.postprocess_op(preds)
|
||||||
for rno in range(len(rec_result)):
|
for rno in range(len(rec_result)):
|
||||||
rec_res[indices[beg_img_no + rno]] = rec_result[rno]
|
rec_res[indices[beg_img_no + rno]] = rec_result[rno]
|
||||||
|
if self.benchmark:
|
||||||
|
self.autolog.times.end(stamp=True)
|
||||||
return rec_res, time.time() - st
|
return rec_res, time.time() - st
|
||||||
|
|
||||||
|
|
||||||
|
@ -251,9 +278,6 @@ def main(args):
|
||||||
for i in range(10):
|
for i in range(10):
|
||||||
res = text_recognizer([img])
|
res = text_recognizer([img])
|
||||||
|
|
||||||
cpu_mem, gpu_mem, gpu_util = 0, 0, 0
|
|
||||||
count = 0
|
|
||||||
|
|
||||||
for image_file in image_file_list:
|
for image_file in image_file_list:
|
||||||
img, flag = check_and_read_gif(image_file)
|
img, flag = check_and_read_gif(image_file)
|
||||||
if not flag:
|
if not flag:
|
||||||
|
@ -273,6 +297,8 @@ def main(args):
|
||||||
for ino in range(len(img_list)):
|
for ino in range(len(img_list)):
|
||||||
logger.info("Predicts of {}:{}".format(valid_image_file_list[ino],
|
logger.info("Predicts of {}:{}".format(valid_image_file_list[ino],
|
||||||
rec_res[ino]))
|
rec_res[ino]))
|
||||||
|
if args.benchmark:
|
||||||
|
text_recognizer.autolog.report()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
|
@ -24,9 +24,6 @@ from paddle import inference
|
||||||
import time
|
import time
|
||||||
from ppocr.utils.logging import get_logger
|
from ppocr.utils.logging import get_logger
|
||||||
|
|
||||||
logger = get_logger()
|
|
||||||
|
|
||||||
|
|
||||||
def str2bool(v):
|
def str2bool(v):
|
||||||
return v.lower() in ("true", "t", "1")
|
return v.lower() in ("true", "t", "1")
|
||||||
|
|
||||||
|
|
|
@ -186,7 +186,10 @@ def train(config,
|
||||||
model.train()
|
model.train()
|
||||||
|
|
||||||
use_srn = config['Architecture']['algorithm'] == "SRN"
|
use_srn = config['Architecture']['algorithm'] == "SRN"
|
||||||
|
try:
|
||||||
model_type = config['Architecture']['model_type']
|
model_type = config['Architecture']['model_type']
|
||||||
|
except:
|
||||||
|
model_type = None
|
||||||
|
|
||||||
if 'start_epoch' in best_model_dict:
|
if 'start_epoch' in best_model_dict:
|
||||||
start_epoch = best_model_dict['start_epoch']
|
start_epoch = best_model_dict['start_epoch']
|
||||||
|
|
|
@ -98,7 +98,6 @@ def main(config, device, logger, vdl_writer):
|
||||||
eval_class = build_metric(config['Metric'])
|
eval_class = build_metric(config['Metric'])
|
||||||
# load pretrain model
|
# load pretrain model
|
||||||
pre_best_model_dict = load_dygraph_params(config, model, logger, optimizer)
|
pre_best_model_dict = load_dygraph_params(config, model, logger, optimizer)
|
||||||
|
|
||||||
logger.info('train dataloader has {} iters'.format(len(train_dataloader)))
|
logger.info('train dataloader has {} iters'.format(len(train_dataloader)))
|
||||||
if valid_dataloader is not None:
|
if valid_dataloader is not None:
|
||||||
logger.info('valid dataloader has {} iters'.format(
|
logger.info('valid dataloader has {} iters'.format(
|
||||||
|
|
Loading…
Reference in New Issue