merge dygraph
This commit is contained in:
commit
a739abab57
|
@ -398,6 +398,7 @@ class MainWindow(QMainWindow, WindowMixin):
|
|||
help = action(getStr('tutorial'), self.showTutorialDialog, None, 'help', getStr('tutorialDetail'))
|
||||
showInfo = action(getStr('info'), self.showInfoDialog, None, 'help', getStr('info'))
|
||||
showSteps = action(getStr('steps'), self.showStepsDialog, None, 'help', getStr('steps'))
|
||||
showKeys = action(getStr('keys'), self.showKeysDialog, None, 'help', getStr('keys'))
|
||||
|
||||
zoom = QWidgetAction(self)
|
||||
zoom.setDefaultWidget(self.zoomWidget)
|
||||
|
@ -565,7 +566,7 @@ class MainWindow(QMainWindow, WindowMixin):
|
|||
addActions(self.menus.file,
|
||||
(opendir, None, saveLabel, saveRec, self.autoSaveOption, None, resetAll, deleteImg, quit))
|
||||
|
||||
addActions(self.menus.help, (showSteps, showInfo))
|
||||
addActions(self.menus.help, (showKeys,showSteps, showInfo))
|
||||
addActions(self.menus.view, (
|
||||
self.displayLabelOption, self.labelDialogOption,
|
||||
None,
|
||||
|
@ -760,6 +761,10 @@ class MainWindow(QMainWindow, WindowMixin):
|
|||
msg = stepsInfo(self.lang)
|
||||
QMessageBox.information(self, u'Information', msg)
|
||||
|
||||
def showKeysDialog(self):
|
||||
msg = keysInfo(self.lang)
|
||||
QMessageBox.information(self, u'Information', msg)
|
||||
|
||||
def createShape(self):
|
||||
assert self.beginner()
|
||||
self.canvas.setEditing(False)
|
||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -174,6 +174,7 @@ def stepsInfo(lang='en'):
|
|||
"10. 标注结果:关闭应用程序或切换文件路径后,手动保存过的标签将会被存放在所打开图片文件夹下的" \
|
||||
"*Label.txt*中。在菜单栏点击 “PaddleOCR” - 保存识别结果后,会将此类图片的识别训练数据保存在*crop_img*文件夹下," \
|
||||
"识别标签保存在*rec_gt.txt*中。\n"
|
||||
|
||||
else:
|
||||
msg = "1. Build and launch using the instructions above.\n" \
|
||||
"2. Click 'Open Dir' in Menu/File to select the folder of the picture.\n"\
|
||||
|
@ -187,5 +188,57 @@ def stepsInfo(lang='en'):
|
|||
"8. Click 'Save', the image status will switch to '√',then the program automatically jump to the next.\n"\
|
||||
"9. Click 'Delete Image' and the image will be deleted to the recycle bin.\n"\
|
||||
"10. Labeling result: After closing the application or switching the file path, the manually saved label will be stored in *Label.txt* under the opened picture folder.\n"\
|
||||
" Click PaddleOCR-Save Recognition Results in the menu bar, the recognition training data of such pictures will be saved in the *crop_img* folder, and the recognition label will be saved in *rec_gt.txt*.\n"
|
||||
" Click PaddleOCR-Save Recognition Results in the menu bar, the recognition training data of such pictures will be saved in the *crop_img* folder, and the recognition label will be saved in *rec_gt.txt*.\n"
|
||||
|
||||
return msg
|
||||
|
||||
def keysInfo(lang='en'):
|
||||
if lang == 'ch':
|
||||
msg = "快捷键\t\t\t说明\n" \
|
||||
"———————————————————————\n"\
|
||||
"Ctrl + shift + R\t\t对当前图片的所有标记重新识别\n" \
|
||||
"W\t\t\t新建矩形框\n" \
|
||||
"Q\t\t\t新建四点框\n" \
|
||||
"Ctrl + E\t\t编辑所选框标签\n" \
|
||||
"Ctrl + R\t\t重新识别所选标记\n" \
|
||||
"Ctrl + C\t\t复制并粘贴选中的标记框\n" \
|
||||
"Ctrl + 鼠标左键\t\t多选标记框\n" \
|
||||
"Backspace\t\t删除所选框\n" \
|
||||
"Ctrl + V\t\t确认本张图片标记\n" \
|
||||
"Ctrl + Shift + d\t删除本张图片\n" \
|
||||
"D\t\t\t下一张图片\n" \
|
||||
"A\t\t\t上一张图片\n" \
|
||||
"Ctrl++\t\t\t缩小\n" \
|
||||
"Ctrl--\t\t\t放大\n" \
|
||||
"↑→↓←\t\t\t移动标记框\n" \
|
||||
"———————————————————————\n" \
|
||||
"注:Mac用户Command键替换上述Ctrl键"
|
||||
|
||||
else:
|
||||
msg = "Shortcut Keys\t\tDescription\n" \
|
||||
"———————————————————————\n" \
|
||||
"Ctrl + shift + R\t\tRe-recognize all the labels\n" \
|
||||
"\t\t\tof the current image\n" \
|
||||
"\n"\
|
||||
"W\t\t\tCreate a rect box\n" \
|
||||
"Q\t\t\tCreate a four-points box\n" \
|
||||
"Ctrl + E\t\tEdit label of the selected box\n" \
|
||||
"Ctrl + R\t\tRe-recognize the selected box\n" \
|
||||
"Ctrl + C\t\tCopy and paste the selected\n" \
|
||||
"\t\t\tbox\n" \
|
||||
"\n"\
|
||||
"Ctrl + Left Mouse\tMulti select the label\n" \
|
||||
"Button\t\t\tbox\n" \
|
||||
"\n"\
|
||||
"Backspace\t\tDelete the selected box\n" \
|
||||
"Ctrl + V\t\tCheck image\n" \
|
||||
"Ctrl + Shift + d\tDelete image\n" \
|
||||
"D\t\t\tNext image\n" \
|
||||
"A\t\t\tPrevious image\n" \
|
||||
"Ctrl++\t\t\tZoom in\n" \
|
||||
"Ctrl--\t\t\tZoom out\n" \
|
||||
"↑→↓←\t\t\tMove selected box" \
|
||||
"———————————————————————\n" \
|
||||
"Notice:For Mac users, use the 'Command' key instead of the 'Ctrl' key"
|
||||
|
||||
return msg
|
|
@ -89,6 +89,7 @@ saveRec=保存识别结果
|
|||
tempLabel=待识别
|
||||
nullLabel=无法识别
|
||||
steps=操作步骤
|
||||
keys=快捷键
|
||||
choseModelLg=选择模型语言
|
||||
cancel=取消
|
||||
ok=确认
|
||||
|
|
|
@ -89,6 +89,7 @@ saveRec=Save Recognition Result
|
|||
tempLabel=TEMPORARY
|
||||
nullLabel=NULL
|
||||
steps=Steps
|
||||
keys=Shortcut Keys
|
||||
choseModelLg=Choose Model Language
|
||||
cancel=Cancel
|
||||
ok=OK
|
||||
|
|
|
@ -0,0 +1,202 @@
|
|||
Global:
|
||||
use_gpu: true
|
||||
epoch_num: 1200
|
||||
log_smooth_window: 20
|
||||
print_batch_step: 2
|
||||
save_model_dir: ./output/ch_db_mv3/
|
||||
save_epoch_step: 1200
|
||||
# evaluation is run every 5000 iterations after the 4000th iteration
|
||||
eval_batch_step: [3000, 2000]
|
||||
cal_metric_during_train: False
|
||||
pretrained_model: ./pretrain_models/MobileNetV3_large_x0_5_pretrained
|
||||
checkpoints:
|
||||
save_inference_dir:
|
||||
use_visualdl: False
|
||||
infer_img: doc/imgs_en/img_10.jpg
|
||||
save_res_path: ./output/det_db/predicts_db.txt
|
||||
|
||||
Architecture:
|
||||
name: DistillationModel
|
||||
algorithm: Distillation
|
||||
Models:
|
||||
Student:
|
||||
pretrained: ./pretrain_models/MobileNetV3_large_x0_5_pretrained
|
||||
freeze_params: false
|
||||
return_all_feats: false
|
||||
model_type: det
|
||||
algorithm: DB
|
||||
Backbone:
|
||||
name: MobileNetV3
|
||||
scale: 0.5
|
||||
model_name: large
|
||||
disable_se: True
|
||||
Neck:
|
||||
name: DBFPN
|
||||
out_channels: 96
|
||||
Head:
|
||||
name: DBHead
|
||||
k: 50
|
||||
Student2:
|
||||
pretrained: ./pretrain_models/MobileNetV3_large_x0_5_pretrained
|
||||
freeze_params: false
|
||||
return_all_feats: false
|
||||
model_type: det
|
||||
algorithm: DB
|
||||
Transform:
|
||||
Backbone:
|
||||
name: MobileNetV3
|
||||
scale: 0.5
|
||||
model_name: large
|
||||
disable_se: True
|
||||
Neck:
|
||||
name: DBFPN
|
||||
out_channels: 96
|
||||
Head:
|
||||
name: DBHead
|
||||
k: 50
|
||||
Teacher:
|
||||
pretrained: ./pretrain_models/ch_ppocr_server_v2.0_det_train/best_accuracy
|
||||
freeze_params: true
|
||||
return_all_feats: false
|
||||
model_type: det
|
||||
algorithm: DB
|
||||
Transform:
|
||||
Backbone:
|
||||
name: ResNet
|
||||
layers: 18
|
||||
Neck:
|
||||
name: DBFPN
|
||||
out_channels: 256
|
||||
Head:
|
||||
name: DBHead
|
||||
k: 50
|
||||
|
||||
Loss:
|
||||
name: CombinedLoss
|
||||
loss_config_list:
|
||||
- DistillationDilaDBLoss:
|
||||
weight: 1.0
|
||||
model_name_pairs:
|
||||
- ["Student", "Teacher"]
|
||||
- ["Student2", "Teacher"]
|
||||
key: maps
|
||||
balance_loss: true
|
||||
main_loss_type: DiceLoss
|
||||
alpha: 5
|
||||
beta: 10
|
||||
ohem_ratio: 3
|
||||
- DistillationDMLLoss:
|
||||
model_name_pairs:
|
||||
- ["Student", "Student2"]
|
||||
maps_name: "thrink_maps"
|
||||
weight: 1.0
|
||||
# act: None
|
||||
model_name_pairs: ["Student", "Student2"]
|
||||
key: maps
|
||||
- DistillationDBLoss:
|
||||
weight: 1.0
|
||||
model_name_list: ["Student", "Student2"]
|
||||
# key: maps
|
||||
# name: DBLoss
|
||||
balance_loss: true
|
||||
main_loss_type: DiceLoss
|
||||
alpha: 5
|
||||
beta: 10
|
||||
ohem_ratio: 3
|
||||
|
||||
|
||||
Optimizer:
|
||||
name: Adam
|
||||
beta1: 0.9
|
||||
beta2: 0.999
|
||||
lr:
|
||||
name: Cosine
|
||||
learning_rate: 0.001
|
||||
warmup_epoch: 2
|
||||
regularizer:
|
||||
name: 'L2'
|
||||
factor: 0
|
||||
|
||||
PostProcess:
|
||||
name: DistillationDBPostProcess
|
||||
model_name: ["Student", "Student2", "Teacher"]
|
||||
# key: maps
|
||||
thresh: 0.3
|
||||
box_thresh: 0.6
|
||||
max_candidates: 1000
|
||||
unclip_ratio: 1.5
|
||||
|
||||
Metric:
|
||||
name: DistillationMetric
|
||||
base_metric_name: DetMetric
|
||||
main_indicator: hmean
|
||||
key: "Student"
|
||||
|
||||
Train:
|
||||
dataset:
|
||||
name: SimpleDataSet
|
||||
data_dir: ./train_data/icdar2015/text_localization/
|
||||
label_file_list:
|
||||
- ./train_data/icdar2015/text_localization/train_icdar2015_label.txt
|
||||
ratio_list: [1.0]
|
||||
transforms:
|
||||
- DecodeImage: # load image
|
||||
img_mode: BGR
|
||||
channel_first: False
|
||||
- DetLabelEncode: # Class handling label
|
||||
- IaaAugment:
|
||||
augmenter_args:
|
||||
- { 'type': Fliplr, 'args': { 'p': 0.5 } }
|
||||
- { 'type': Affine, 'args': { 'rotate': [-10, 10] } }
|
||||
- { 'type': Resize, 'args': { 'size': [0.5, 3] } }
|
||||
- EastRandomCropData:
|
||||
size: [960, 960]
|
||||
max_tries: 50
|
||||
keep_ratio: true
|
||||
- MakeBorderMap:
|
||||
shrink_ratio: 0.4
|
||||
thresh_min: 0.3
|
||||
thresh_max: 0.7
|
||||
- MakeShrinkMap:
|
||||
shrink_ratio: 0.4
|
||||
min_text_size: 8
|
||||
- NormalizeImage:
|
||||
scale: 1./255.
|
||||
mean: [0.485, 0.456, 0.406]
|
||||
std: [0.229, 0.224, 0.225]
|
||||
order: 'hwc'
|
||||
- ToCHWImage:
|
||||
- KeepKeys:
|
||||
keep_keys: ['image', 'threshold_map', 'threshold_mask', 'shrink_map', 'shrink_mask'] # the order of the dataloader list
|
||||
loader:
|
||||
shuffle: True
|
||||
drop_last: False
|
||||
batch_size_per_card: 8
|
||||
num_workers: 4
|
||||
|
||||
Eval:
|
||||
dataset:
|
||||
name: SimpleDataSet
|
||||
data_dir: ./train_data/icdar2015/text_localization/
|
||||
label_file_list:
|
||||
- ./train_data/icdar2015/text_localization/test_icdar2015_label.txt
|
||||
transforms:
|
||||
- DecodeImage: # load image
|
||||
img_mode: BGR
|
||||
channel_first: False
|
||||
- DetLabelEncode: # Class handling label
|
||||
- DetResizeForTest:
|
||||
# image_shape: [736, 1280]
|
||||
- NormalizeImage:
|
||||
scale: 1./255.
|
||||
mean: [0.485, 0.456, 0.406]
|
||||
std: [0.229, 0.224, 0.225]
|
||||
order: 'hwc'
|
||||
- ToCHWImage:
|
||||
- KeepKeys:
|
||||
keep_keys: ['image', 'shape', 'polys', 'ignore_tags']
|
||||
loader:
|
||||
shuffle: False
|
||||
drop_last: False
|
||||
batch_size_per_card: 1 # must be 1
|
||||
num_workers: 2
|
|
@ -0,0 +1,174 @@
|
|||
Global:
|
||||
use_gpu: true
|
||||
epoch_num: 1200
|
||||
log_smooth_window: 20
|
||||
print_batch_step: 2
|
||||
save_model_dir: ./output/ch_db_mv3/
|
||||
save_epoch_step: 1200
|
||||
# evaluation is run every 5000 iterations after the 4000th iteration
|
||||
eval_batch_step: [3000, 2000]
|
||||
cal_metric_during_train: False
|
||||
pretrained_model: ./pretrain_models/MobileNetV3_large_x0_5_pretrained
|
||||
checkpoints:
|
||||
save_inference_dir:
|
||||
use_visualdl: False
|
||||
infer_img: doc/imgs_en/img_10.jpg
|
||||
save_res_path: ./output/det_db/predicts_db.txt
|
||||
|
||||
Architecture:
|
||||
name: DistillationModel
|
||||
algorithm: Distillation
|
||||
Models:
|
||||
Student:
|
||||
pretrained: ./pretrain_models/MobileNetV3_large_x0_5_pretrained
|
||||
freeze_params: false
|
||||
return_all_feats: false
|
||||
model_type: det
|
||||
algorithm: DB
|
||||
Backbone:
|
||||
name: MobileNetV3
|
||||
scale: 0.5
|
||||
model_name: large
|
||||
disable_se: True
|
||||
Neck:
|
||||
name: DBFPN
|
||||
out_channels: 96
|
||||
Head:
|
||||
name: DBHead
|
||||
k: 50
|
||||
Teacher:
|
||||
pretrained: ./pretrain_models/ch_ppocr_server_v2.0_det_train/best_accuracy
|
||||
freeze_params: true
|
||||
return_all_feats: false
|
||||
model_type: det
|
||||
algorithm: DB
|
||||
Transform:
|
||||
Backbone:
|
||||
name: ResNet
|
||||
layers: 18
|
||||
Neck:
|
||||
name: DBFPN
|
||||
out_channels: 256
|
||||
Head:
|
||||
name: DBHead
|
||||
k: 50
|
||||
|
||||
Loss:
|
||||
name: CombinedLoss
|
||||
loss_config_list:
|
||||
- DistillationDilaDBLoss:
|
||||
weight: 1.0
|
||||
model_name_pairs:
|
||||
- ["Student", "Teacher"]
|
||||
key: maps
|
||||
balance_loss: true
|
||||
main_loss_type: DiceLoss
|
||||
alpha: 5
|
||||
beta: 10
|
||||
ohem_ratio: 3
|
||||
- DistillationDBLoss:
|
||||
weight: 1.0
|
||||
model_name_list: ["Student", "Teacher"]
|
||||
# key: maps
|
||||
name: DBLoss
|
||||
balance_loss: true
|
||||
main_loss_type: DiceLoss
|
||||
alpha: 5
|
||||
beta: 10
|
||||
ohem_ratio: 3
|
||||
|
||||
Optimizer:
|
||||
name: Adam
|
||||
beta1: 0.9
|
||||
beta2: 0.999
|
||||
lr:
|
||||
name: Cosine
|
||||
learning_rate: 0.001
|
||||
warmup_epoch: 2
|
||||
regularizer:
|
||||
name: 'L2'
|
||||
factor: 0
|
||||
|
||||
PostProcess:
|
||||
name: DistillationDBPostProcess
|
||||
model_name: ["Student", "Student2"]
|
||||
key: head_out
|
||||
thresh: 0.3
|
||||
box_thresh: 0.6
|
||||
max_candidates: 1000
|
||||
unclip_ratio: 1.5
|
||||
|
||||
Metric:
|
||||
name: DistillationMetric
|
||||
base_metric_name: DetMetric
|
||||
main_indicator: hmean
|
||||
key: "Student"
|
||||
|
||||
Train:
|
||||
dataset:
|
||||
name: SimpleDataSet
|
||||
data_dir: ./train_data/icdar2015/text_localization/
|
||||
label_file_list:
|
||||
- ./train_data/icdar2015/text_localization/train_icdar2015_label.txt
|
||||
ratio_list: [1.0]
|
||||
transforms:
|
||||
- DecodeImage: # load image
|
||||
img_mode: BGR
|
||||
channel_first: False
|
||||
- DetLabelEncode: # Class handling label
|
||||
- IaaAugment:
|
||||
augmenter_args:
|
||||
- { 'type': Fliplr, 'args': { 'p': 0.5 } }
|
||||
- { 'type': Affine, 'args': { 'rotate': [-10, 10] } }
|
||||
- { 'type': Resize, 'args': { 'size': [0.5, 3] } }
|
||||
- EastRandomCropData:
|
||||
size: [960, 960]
|
||||
max_tries: 50
|
||||
keep_ratio: true
|
||||
- MakeBorderMap:
|
||||
shrink_ratio: 0.4
|
||||
thresh_min: 0.3
|
||||
thresh_max: 0.7
|
||||
- MakeShrinkMap:
|
||||
shrink_ratio: 0.4
|
||||
min_text_size: 8
|
||||
- NormalizeImage:
|
||||
scale: 1./255.
|
||||
mean: [0.485, 0.456, 0.406]
|
||||
std: [0.229, 0.224, 0.225]
|
||||
order: 'hwc'
|
||||
- ToCHWImage:
|
||||
- KeepKeys:
|
||||
keep_keys: ['image', 'threshold_map', 'threshold_mask', 'shrink_map', 'shrink_mask'] # the order of the dataloader list
|
||||
loader:
|
||||
shuffle: True
|
||||
drop_last: False
|
||||
batch_size_per_card: 8
|
||||
num_workers: 4
|
||||
|
||||
Eval:
|
||||
dataset:
|
||||
name: SimpleDataSet
|
||||
data_dir: ./train_data/icdar2015/text_localization/
|
||||
label_file_list:
|
||||
- ./train_data/icdar2015/text_localization/test_icdar2015_label.txt
|
||||
transforms:
|
||||
- DecodeImage: # load image
|
||||
img_mode: BGR
|
||||
channel_first: False
|
||||
- DetLabelEncode: # Class handling label
|
||||
- DetResizeForTest:
|
||||
# image_shape: [736, 1280]
|
||||
- NormalizeImage:
|
||||
scale: 1./255.
|
||||
mean: [0.485, 0.456, 0.406]
|
||||
std: [0.229, 0.224, 0.225]
|
||||
order: 'hwc'
|
||||
- ToCHWImage:
|
||||
- KeepKeys:
|
||||
keep_keys: ['image', 'shape', 'polys', 'ignore_tags']
|
||||
loader:
|
||||
shuffle: False
|
||||
drop_last: False
|
||||
batch_size_per_card: 1 # must be 1
|
||||
num_workers: 2
|
|
@ -0,0 +1,176 @@
|
|||
Global:
|
||||
use_gpu: true
|
||||
epoch_num: 1200
|
||||
log_smooth_window: 20
|
||||
print_batch_step: 2
|
||||
save_model_dir: ./output/ch_db_mv3/
|
||||
save_epoch_step: 1200
|
||||
# evaluation is run every 5000 iterations after the 4000th iteration
|
||||
eval_batch_step: [3000, 2000]
|
||||
cal_metric_during_train: False
|
||||
pretrained_model: ./pretrain_models/MobileNetV3_large_x0_5_pretrained
|
||||
checkpoints:
|
||||
save_inference_dir:
|
||||
use_visualdl: False
|
||||
infer_img: doc/imgs_en/img_10.jpg
|
||||
save_res_path: ./output/det_db/predicts_db.txt
|
||||
|
||||
Architecture:
|
||||
name: DistillationModel
|
||||
algorithm: Distillation
|
||||
Models:
|
||||
Student:
|
||||
pretrained: ./pretrain_models/MobileNetV3_large_x0_5_pretrained
|
||||
freeze_params: false
|
||||
return_all_feats: false
|
||||
model_type: det
|
||||
algorithm: DB
|
||||
Backbone:
|
||||
name: MobileNetV3
|
||||
scale: 0.5
|
||||
model_name: large
|
||||
disable_se: True
|
||||
Neck:
|
||||
name: DBFPN
|
||||
out_channels: 96
|
||||
Head:
|
||||
name: DBHead
|
||||
k: 50
|
||||
Student2:
|
||||
pretrained: ./pretrain_models/MobileNetV3_large_x0_5_pretrained
|
||||
freeze_params: false
|
||||
return_all_feats: false
|
||||
model_type: det
|
||||
algorithm: DB
|
||||
Transform:
|
||||
Backbone:
|
||||
name: MobileNetV3
|
||||
scale: 0.5
|
||||
model_name: large
|
||||
disable_se: True
|
||||
Neck:
|
||||
name: DBFPN
|
||||
out_channels: 96
|
||||
Head:
|
||||
name: DBHead
|
||||
k: 50
|
||||
|
||||
|
||||
Loss:
|
||||
name: CombinedLoss
|
||||
loss_config_list:
|
||||
- DistillationDMLLoss:
|
||||
model_name_pairs:
|
||||
- ["Student", "Student2"]
|
||||
maps_name: "thrink_maps"
|
||||
weight: 1.0
|
||||
act: "softmax"
|
||||
model_name_pairs: ["Student", "Student2"]
|
||||
key: maps
|
||||
- DistillationDBLoss:
|
||||
weight: 1.0
|
||||
model_name_list: ["Student", "Student2"]
|
||||
# key: maps
|
||||
name: DBLoss
|
||||
balance_loss: true
|
||||
main_loss_type: DiceLoss
|
||||
alpha: 5
|
||||
beta: 10
|
||||
ohem_ratio: 3
|
||||
|
||||
|
||||
Optimizer:
|
||||
name: Adam
|
||||
beta1: 0.9
|
||||
beta2: 0.999
|
||||
lr:
|
||||
name: Cosine
|
||||
learning_rate: 0.001
|
||||
warmup_epoch: 2
|
||||
regularizer:
|
||||
name: 'L2'
|
||||
factor: 0
|
||||
|
||||
PostProcess:
|
||||
name: DistillationDBPostProcess
|
||||
model_name: ["Student", "Student2"]
|
||||
key: head_out
|
||||
thresh: 0.3
|
||||
box_thresh: 0.6
|
||||
max_candidates: 1000
|
||||
unclip_ratio: 1.5
|
||||
|
||||
Metric:
|
||||
name: DistillationMetric
|
||||
base_metric_name: DetMetric
|
||||
main_indicator: hmean
|
||||
key: "Student"
|
||||
|
||||
Train:
|
||||
dataset:
|
||||
name: SimpleDataSet
|
||||
data_dir: ./train_data/icdar2015/text_localization/
|
||||
label_file_list:
|
||||
- ./train_data/icdar2015/text_localization/train_icdar2015_label.txt
|
||||
ratio_list: [1.0]
|
||||
transforms:
|
||||
- DecodeImage: # load image
|
||||
img_mode: BGR
|
||||
channel_first: False
|
||||
- DetLabelEncode: # Class handling label
|
||||
- IaaAugment:
|
||||
augmenter_args:
|
||||
- { 'type': Fliplr, 'args': { 'p': 0.5 } }
|
||||
- { 'type': Affine, 'args': { 'rotate': [-10, 10] } }
|
||||
- { 'type': Resize, 'args': { 'size': [0.5, 3] } }
|
||||
- EastRandomCropData:
|
||||
size: [960, 960]
|
||||
max_tries: 50
|
||||
keep_ratio: true
|
||||
- MakeBorderMap:
|
||||
shrink_ratio: 0.4
|
||||
thresh_min: 0.3
|
||||
thresh_max: 0.7
|
||||
- MakeShrinkMap:
|
||||
shrink_ratio: 0.4
|
||||
min_text_size: 8
|
||||
- NormalizeImage:
|
||||
scale: 1./255.
|
||||
mean: [0.485, 0.456, 0.406]
|
||||
std: [0.229, 0.224, 0.225]
|
||||
order: 'hwc'
|
||||
- ToCHWImage:
|
||||
- KeepKeys:
|
||||
keep_keys: ['image', 'threshold_map', 'threshold_mask', 'shrink_map', 'shrink_mask'] # the order of the dataloader list
|
||||
loader:
|
||||
shuffle: True
|
||||
drop_last: False
|
||||
batch_size_per_card: 8
|
||||
num_workers: 4
|
||||
|
||||
Eval:
|
||||
dataset:
|
||||
name: SimpleDataSet
|
||||
data_dir: ./train_data/icdar2015/text_localization/
|
||||
label_file_list:
|
||||
- ./train_data/icdar2015/text_localization/test_icdar2015_label.txt
|
||||
transforms:
|
||||
- DecodeImage: # load image
|
||||
img_mode: BGR
|
||||
channel_first: False
|
||||
- DetLabelEncode: # Class handling label
|
||||
- DetResizeForTest:
|
||||
# image_shape: [736, 1280]
|
||||
- NormalizeImage:
|
||||
scale: 1./255.
|
||||
mean: [0.485, 0.456, 0.406]
|
||||
std: [0.229, 0.224, 0.225]
|
||||
order: 'hwc'
|
||||
- ToCHWImage:
|
||||
- KeepKeys:
|
||||
keep_keys: ['image', 'shape', 'polys', 'ignore_tags']
|
||||
loader:
|
||||
shuffle: False
|
||||
drop_last: False
|
||||
batch_size_per_card: 1 # must be 1
|
||||
num_workers: 2
|
|
@ -13,7 +13,6 @@ SET(TENSORRT_DIR "" CACHE PATH "Compile demo with TensorRT")
|
|||
|
||||
set(DEMO_NAME "ocr_system")
|
||||
|
||||
|
||||
macro(safe_set_static_flag)
|
||||
foreach(flag_var
|
||||
CMAKE_CXX_FLAGS CMAKE_CXX_FLAGS_DEBUG CMAKE_CXX_FLAGS_RELEASE
|
||||
|
|
|
@ -93,3 +93,5 @@ cd D:\projects\PaddleOCR\deploy\cpp_infer\out\build\x64-Release
|
|||
|
||||
### 注意
|
||||
* 在Windows下的终端中执行文件exe时,可能会发生乱码的现象,此时需要在终端中输入`CHCP 65001`,将终端的编码方式由GBK编码(默认)改为UTF-8编码,更加具体的解释可以参考这篇博客:[https://blog.csdn.net/qq_35038153/article/details/78430359](https://blog.csdn.net/qq_35038153/article/details/78430359)。
|
||||
|
||||
* 编译时,如果报错`错误:C1083 无法打开包括文件:"dirent.h":No such file or directory`,可以参考该[文档](https://blog.csdn.net/Dora_blank/article/details/117740837#41_C1083_direnthNo_such_file_or_directory_54),新建`dirent.h`文件,并添加到`VC++`的包含目录中。
|
||||
|
|
|
@ -18,6 +18,7 @@ PaddleOCR模型部署。
|
|||
* 首先需要从opencv官网上下载在Linux环境下源码编译的包,以opencv3.4.7为例,下载命令如下。
|
||||
|
||||
```
|
||||
cd deploy/cpp_infer
|
||||
wget https://github.com/opencv/opencv/archive/3.4.7.tar.gz
|
||||
tar -xf 3.4.7.tar.gz
|
||||
```
|
||||
|
@ -184,7 +185,7 @@ cmake .. \
|
|||
make -j
|
||||
```
|
||||
|
||||
`OPENCV_DIR`为opencv编译安装的地址;`LIB_DIR`为下载(`paddle_inference`文件夹)或者编译生成的Paddle预测库地址(`build/paddle_inference_install_dir`文件夹);`CUDA_LIB_DIR`为cuda库文件地址,在docker中为`/usr/local/cuda/lib64`;`CUDNN_LIB_DIR`为cudnn库文件地址,在docker中为`/usr/lib/x86_64-linux-gnu/`。
|
||||
`OPENCV_DIR`为opencv编译安装的地址;`LIB_DIR`为下载(`paddle_inference`文件夹)或者编译生成的Paddle预测库地址(`build/paddle_inference_install_dir`文件夹);`CUDA_LIB_DIR`为cuda库文件地址,在docker中为`/usr/local/cuda/lib64`;`CUDNN_LIB_DIR`为cudnn库文件地址,在docker中为`/usr/lib/x86_64-linux-gnu/`。**注意**:以上路径都写绝对路径,不要写相对路径。
|
||||
|
||||
|
||||
* 编译完成之后,会在`build`文件夹下生成一个名为`ocr_system`的可执行文件。
|
||||
|
|
|
@ -18,6 +18,7 @@ PaddleOCR model deployment.
|
|||
* First of all, you need to download the source code compiled package in the Linux environment from the opencv official website. Taking opencv3.4.7 as an example, the download command is as follows.
|
||||
|
||||
```
|
||||
cd deploy/cpp_infer
|
||||
wget https://github.com/opencv/opencv/archive/3.4.7.tar.gz
|
||||
tar -xf 3.4.7.tar.gz
|
||||
```
|
||||
|
|
|
@ -668,7 +668,7 @@ void DisposeOutPts(OutPt *&pp) {
|
|||
//------------------------------------------------------------------------------
|
||||
|
||||
inline void InitEdge(TEdge *e, TEdge *eNext, TEdge *ePrev, const IntPoint &Pt) {
|
||||
std::memset(e, 0, sizeof(TEdge));
|
||||
std::memset(e, int(0), sizeof(TEdge));
|
||||
e->Next = eNext;
|
||||
e->Prev = ePrev;
|
||||
e->Curr = Pt;
|
||||
|
@ -1895,17 +1895,17 @@ void Clipper::InsertLocalMinimaIntoAEL(const cInt botY) {
|
|||
TEdge *rb = lm->RightBound;
|
||||
|
||||
OutPt *Op1 = 0;
|
||||
if (!lb) {
|
||||
if (!lb || !rb) {
|
||||
// nb: don't insert LB into either AEL or SEL
|
||||
InsertEdgeIntoAEL(rb, 0);
|
||||
SetWindingCount(*rb);
|
||||
if (IsContributing(*rb))
|
||||
Op1 = AddOutPt(rb, rb->Bot);
|
||||
} else if (!rb) {
|
||||
InsertEdgeIntoAEL(lb, 0);
|
||||
SetWindingCount(*lb);
|
||||
if (IsContributing(*lb))
|
||||
Op1 = AddOutPt(lb, lb->Bot);
|
||||
//} else if (!rb) {
|
||||
// InsertEdgeIntoAEL(lb, 0);
|
||||
// SetWindingCount(*lb);
|
||||
// if (IsContributing(*lb))
|
||||
// Op1 = AddOutPt(lb, lb->Bot);
|
||||
InsertScanbeam(lb->Top.Y);
|
||||
} else {
|
||||
InsertEdgeIntoAEL(lb, 0);
|
||||
|
@ -2547,13 +2547,13 @@ void Clipper::ProcessHorizontal(TEdge *horzEdge) {
|
|||
if (dir == dLeftToRight) {
|
||||
maxIt = m_Maxima.begin();
|
||||
while (maxIt != m_Maxima.end() && *maxIt <= horzEdge->Bot.X)
|
||||
maxIt++;
|
||||
++maxIt;
|
||||
if (maxIt != m_Maxima.end() && *maxIt >= eLastHorz->Top.X)
|
||||
maxIt = m_Maxima.end();
|
||||
} else {
|
||||
maxRit = m_Maxima.rbegin();
|
||||
while (maxRit != m_Maxima.rend() && *maxRit > horzEdge->Bot.X)
|
||||
maxRit++;
|
||||
++maxRit;
|
||||
if (maxRit != m_Maxima.rend() && *maxRit <= eLastHorz->Top.X)
|
||||
maxRit = m_Maxima.rend();
|
||||
}
|
||||
|
@ -2576,13 +2576,13 @@ void Clipper::ProcessHorizontal(TEdge *horzEdge) {
|
|||
while (maxIt != m_Maxima.end() && *maxIt < e->Curr.X) {
|
||||
if (horzEdge->OutIdx >= 0 && !IsOpen)
|
||||
AddOutPt(horzEdge, IntPoint(*maxIt, horzEdge->Bot.Y));
|
||||
maxIt++;
|
||||
++maxIt;
|
||||
}
|
||||
} else {
|
||||
while (maxRit != m_Maxima.rend() && *maxRit > e->Curr.X) {
|
||||
if (horzEdge->OutIdx >= 0 && !IsOpen)
|
||||
AddOutPt(horzEdge, IntPoint(*maxRit, horzEdge->Bot.Y));
|
||||
maxRit++;
|
||||
++maxRit;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
|
|
@ -21,10 +21,10 @@ std::vector<std::string> OCRConfig::split(const std::string &str,
|
|||
std::vector<std::string> res;
|
||||
if ("" == str)
|
||||
return res;
|
||||
char *strs = new char[str.length() + 1];
|
||||
char strs[str.length() + 1];
|
||||
std::strcpy(strs, str.c_str());
|
||||
|
||||
char *d = new char[delim.length() + 1];
|
||||
char d[delim.length() + 1];
|
||||
std::strcpy(d, delim.c_str());
|
||||
|
||||
char *p = std::strtok(strs, d);
|
||||
|
@ -61,4 +61,4 @@ void OCRConfig::PrintConfigInfo() {
|
|||
std::cout << "=======End of Paddle OCR inference config======" << std::endl;
|
||||
}
|
||||
|
||||
} // namespace PaddleOCR
|
||||
} // namespace PaddleOCR
|
||||
|
|
|
@ -29,7 +29,8 @@ deploy/hubserving/ocr_system/
|
|||
### 1. 准备环境
|
||||
```shell
|
||||
# 安装paddlehub
|
||||
pip3 install paddlehub==1.8.3 --upgrade -i https://pypi.tuna.tsinghua.edu.cn/simple
|
||||
# paddlehub 需要 python>3.6.2
|
||||
pip3 install paddlehub==2.1.0 --upgrade -i https://pypi.tuna.tsinghua.edu.cn/simple
|
||||
```
|
||||
|
||||
### 2. 下载推理模型
|
||||
|
|
|
@ -30,7 +30,8 @@ The following steps take the 2-stage series service as an example. If only the d
|
|||
### 1. Prepare the environment
|
||||
```shell
|
||||
# Install paddlehub
|
||||
pip3 install paddlehub==1.8.3 --upgrade -i https://pypi.tuna.tsinghua.edu.cn/simple
|
||||
# python>3.6.2 is required bt paddlehub
|
||||
pip3 install paddlehub==2.1.0 --upgrade -i https://pypi.tuna.tsinghua.edu.cn/simple
|
||||
```
|
||||
|
||||
### 2. Download inference model
|
||||
|
|
|
@ -101,7 +101,7 @@ def main():
|
|||
quanter = QAT(config=quant_config)
|
||||
quanter.quantize(model)
|
||||
|
||||
init_model(config, model, logger)
|
||||
init_model(config, model)
|
||||
model.eval()
|
||||
|
||||
# build metric
|
||||
|
@ -113,7 +113,7 @@ def main():
|
|||
use_srn = config['Architecture']['algorithm'] == "SRN"
|
||||
model_type = config['Architecture']['model_type']
|
||||
# start eval
|
||||
metirc = program.eval(model, valid_dataloader, post_process_class,
|
||||
metric = program.eval(model, valid_dataloader, post_process_class,
|
||||
eval_class, model_type, use_srn)
|
||||
|
||||
logger.info('metric eval ***************')
|
||||
|
|
|
@ -18,9 +18,9 @@ PaddleOCR 也提供了数据格式转换脚本,可以将官网 label 转换支
|
|||
|
||||
```
|
||||
# 将官网下载的标签文件转换为 train_icdar2015_label.txt
|
||||
python gen_label.py --mode="det" --root_path="icdar_c4_train_imgs/" \
|
||||
--input_path="ch4_training_localization_transcription_gt" \
|
||||
--output_label="train_icdar2015_label.txt"
|
||||
python gen_label.py --mode="det" --root_path="/path/to/icdar_c4_train_imgs/" \
|
||||
--input_path="/path/to/ch4_training_localization_transcription_gt" \
|
||||
--output_label="/path/to/train_icdar2015_label.txt"
|
||||
```
|
||||
|
||||
解压数据集和下载标注文件后,PaddleOCR/train_data/ 有两个文件夹和两个文件,分别是:
|
||||
|
|
|
@ -147,12 +147,12 @@ python3 tools/infer/predict_det.py --image_dir="./doc/imgs/00018069.jpg" --det_m
|
|||
|
||||
如果输入图片的分辨率比较大,而且想使用更大的分辨率预测,可以设置det_limit_side_len 为想要的值,比如1216:
|
||||
```
|
||||
python3 tools/infer/predict_det.py --image_dir="./doc/imgs/2.jpg" --det_model_dir="./inference/det_db/" --det_limit_type=max --det_limit_side_len=1216
|
||||
python3 tools/infer/predict_det.py --image_dir="./doc/imgs/1.jpg" --det_model_dir="./inference/det_db/" --det_limit_type=max --det_limit_side_len=1216
|
||||
```
|
||||
|
||||
如果想使用CPU进行预测,执行命令如下
|
||||
```
|
||||
python3 tools/infer/predict_det.py --image_dir="./doc/imgs/2.jpg" --det_model_dir="./inference/det_db/" --use_gpu=False
|
||||
python3 tools/infer/predict_det.py --image_dir="./doc/imgs/1.jpg" --det_model_dir="./inference/det_db/" --use_gpu=False
|
||||
```
|
||||
|
||||
<a name="DB文本检测模型推理"></a>
|
||||
|
@ -221,7 +221,7 @@ python3 tools/export_model.py -c configs/det/det_r50_vd_sast_totaltext.yml -o Gl
|
|||
|
||||
```
|
||||
|
||||
**SAST文本检测模型推理,需要设置参数`--det_algorithm="SAST"`,同时,还需要增加参数`--det_sast_polygon=True`,**可以执行如下命令:
|
||||
SAST文本检测模型推理,需要设置参数`--det_algorithm="SAST"`,同时,还需要增加参数`--det_sast_polygon=True`,可以执行如下命令:
|
||||
```
|
||||
python3 tools/infer/predict_det.py --det_algorithm="SAST" --image_dir="./doc/imgs_en/img623.jpg" --det_model_dir="./inference/det_sast_tt/" --det_sast_polygon=True
|
||||
```
|
||||
|
|
|
@ -330,6 +330,8 @@ PaddleOCR目前已支持80种(除中文外)语种识别,`configs/rec/multi
|
|||
|
||||
```
|
||||
|
||||
意大利文由拉丁字母组成,因此执行完命令后会得到名为 rec_latin_lite_train.yml 的配置文件。
|
||||
|
||||
2. 手动修改配置文件
|
||||
|
||||
您也可以手动修改模版中的以下几个字段:
|
||||
|
@ -375,7 +377,9 @@ PaddleOCR目前已支持80种(除中文外)语种识别,`configs/rec/multi
|
|||
|
||||
更多支持语种请参考: [多语言模型](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.1/doc/doc_ch/multi_languages.md#%E8%AF%AD%E7%A7%8D%E7%BC%A9%E5%86%99)
|
||||
|
||||
多语言模型训练方式与中文模型一致,训练数据集均为100w的合成数据,少量的字体可以在 [百度网盘](https://pan.baidu.com/s/1bS_u207Rm7YbY33wOECKDA) 上下载,提取码:frgi。
|
||||
多语言模型训练方式与中文模型一致,训练数据集均为100w的合成数据,少量的字体可以通过下面两种方式下载。
|
||||
* [百度网盘](https://pan.baidu.com/s/1bS_u207Rm7YbY33wOECKDA)。提取码:frgi。
|
||||
* [google drive](https://drive.google.com/file/d/18cSWX7wXSy4G0tbKJ0d9PuIaiwRLHpjA/view)
|
||||
|
||||
如您希望在现有模型效果的基础上调优,请参考下列说明修改配置文件:
|
||||
|
||||
|
|
|
@ -154,12 +154,12 @@ Set as `limit_type='min', det_limit_side_len=960`, it means that the shortest si
|
|||
|
||||
If the resolution of the input picture is relatively large and you want to use a larger resolution prediction, you can set det_limit_side_len to the desired value, such as 1216:
|
||||
```
|
||||
python3 tools/infer/predict_det.py --image_dir="./doc/imgs/22.jpg" --det_model_dir="./inference/det_db/" --det_limit_type=max --det_limit_side_len=1216
|
||||
python3 tools/infer/predict_det.py --image_dir="./doc/imgs/1.jpg" --det_model_dir="./inference/det_db/" --det_limit_type=max --det_limit_side_len=1216
|
||||
```
|
||||
|
||||
If you want to use the CPU for prediction, execute the command as follows
|
||||
```
|
||||
python3 tools/infer/predict_det.py --image_dir="./doc/imgs/22.jpg" --det_model_dir="./inference/det_db/" --use_gpu=False
|
||||
python3 tools/infer/predict_det.py --image_dir="./doc/imgs/1.jpg" --det_model_dir="./inference/det_db/" --use_gpu=False
|
||||
```
|
||||
|
||||
<a name="DB_DETECTION"></a>
|
||||
|
@ -230,7 +230,7 @@ First, convert the model saved in the SAST text detection training process into
|
|||
python3 tools/export_model.py -c configs/det/det_r50_vd_sast_totaltext.yml -o Global.pretrained_model=./det_r50_vd_sast_totaltext_v2.0_train/best_accuracy Global.save_inference_dir=./inference/det_sast_tt
|
||||
```
|
||||
|
||||
**For SAST curved text detection model inference, you need to set the parameter `--det_algorithm="SAST"` and `--det_sast_polygon=True`**, run the following command:
|
||||
For SAST curved text detection model inference, you need to set the parameter `--det_algorithm="SAST"` and `--det_sast_polygon=True`, run the following command:
|
||||
|
||||
```
|
||||
python3 tools/infer/predict_det.py --det_algorithm="SAST" --image_dir="./doc/imgs_en/img623.jpg" --det_model_dir="./inference/det_sast_tt/" --det_sast_polygon=True
|
||||
|
|
|
@ -329,6 +329,7 @@ There are two ways to create the required configuration file::
|
|||
...
|
||||
|
||||
```
|
||||
Italian is made up of Latin letters, so after executing the command, you will get the rec_latin_lite_train.yml.
|
||||
|
||||
2. Manually modify the configuration file
|
||||
|
||||
|
@ -375,7 +376,9 @@ Currently, the multi-language algorithms supported by PaddleOCR are:
|
|||
|
||||
For more supported languages, please refer to : [Multi-language model](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.1/doc/doc_en/multi_languages_en.md#4-support-languages-and-abbreviations)
|
||||
|
||||
The multi-language model training method is the same as the Chinese model. The training data set is 100w synthetic data. A small amount of fonts and test data can be downloaded on [Baidu Netdisk](https://pan.baidu.com/s/1bS_u207Rm7YbY33wOECKDA),Extraction code:frgi.
|
||||
The multi-language model training method is the same as the Chinese model. The training data set is 100w synthetic data. A small amount of fonts and test data can be downloaded using the following two methods.
|
||||
* [Baidu Netdisk](https://pan.baidu.com/s/1bS_u207Rm7YbY33wOECKDA),Extraction code:frgi.
|
||||
* [Google drive](https://drive.google.com/file/d/18cSWX7wXSy4G0tbKJ0d9PuIaiwRLHpjA/view)
|
||||
|
||||
If you want to finetune on the basis of the existing model effect, please refer to the following instructions to modify the configuration file:
|
||||
|
||||
|
|
|
@ -15,8 +15,6 @@
|
|||
- 2020.6.8 Add [datasets](./datasets_en.md) and keep updating
|
||||
- 2020.6.5 Support exporting `attention` model to `inference_model`
|
||||
- 2020.6.5 Support separate prediction and recognition, output result score
|
||||
- 2020.6.5 Support exporting `attention` model to `inference_model`
|
||||
- 2020.6.5 Support separate prediction and recognition, output result score
|
||||
- 2020.5.30 Provide Lightweight Chinese OCR online experience
|
||||
- 2020.5.30 Model prediction and training support on Windows system
|
||||
- 2020.5.30 Open source general Chinese OCR model
|
||||
|
|
BIN
doc/joinus.PNG
BIN
doc/joinus.PNG
Binary file not shown.
Before Width: | Height: | Size: 205 KiB After Width: | Height: | Size: 189 KiB |
|
@ -14,7 +14,6 @@
|
|||
import numpy as np
|
||||
import os
|
||||
import random
|
||||
import traceback
|
||||
from paddle.io import Dataset
|
||||
|
||||
from .imaug import transform, create_operators
|
||||
|
@ -46,7 +45,6 @@ class SimpleDataSet(Dataset):
|
|||
self.seed = seed
|
||||
logger.info("Initialize indexs of datasets:%s" % label_file_list)
|
||||
self.data_lines = self.get_image_info_list(label_file_list, ratio_list)
|
||||
self.check_data()
|
||||
self.data_idx_order_list = list(range(len(self.data_lines)))
|
||||
if self.mode == "train" and self.do_shuffle:
|
||||
self.shuffle_data_random()
|
||||
|
@ -103,18 +101,25 @@ class SimpleDataSet(Dataset):
|
|||
|
||||
def __getitem__(self, idx):
|
||||
file_idx = self.data_idx_order_list[idx]
|
||||
data = self.data_lines[file_idx]
|
||||
data_line = self.data_lines[file_idx]
|
||||
try:
|
||||
data_line = data_line.decode('utf-8')
|
||||
substr = data_line.strip("\n").split(self.delimiter)
|
||||
file_name = substr[0]
|
||||
label = substr[1]
|
||||
img_path = os.path.join(self.data_dir, file_name)
|
||||
data = {'img_path': img_path, 'label': label}
|
||||
if not os.path.exists(img_path):
|
||||
raise Exception("{} does not exist!".format(img_path))
|
||||
with open(data['img_path'], 'rb') as f:
|
||||
img = f.read()
|
||||
data['image'] = img
|
||||
data['ext_data'] = self.get_ext_data()
|
||||
outs = transform(data, self.ops)
|
||||
except:
|
||||
error_meg = traceback.format_exc()
|
||||
except Exception as e:
|
||||
self.logger.error(
|
||||
"When parsing file {} and label {}, error happened with msg: {}".format(
|
||||
data['img_path'],data['label'], error_meg))
|
||||
"When parsing line {}, error happened with msg: {}".format(
|
||||
data_line, e))
|
||||
outs = None
|
||||
if outs is None:
|
||||
# during evaluation, we should fix the idx to get same results for many times of evaluation.
|
||||
|
@ -125,17 +130,3 @@ class SimpleDataSet(Dataset):
|
|||
|
||||
def __len__(self):
|
||||
return len(self.data_idx_order_list)
|
||||
|
||||
def check_data(self):
|
||||
new_data_lines = []
|
||||
for data_line in self.data_lines:
|
||||
data_line = data_line.decode('utf-8')
|
||||
substr = data_line.strip("\n").strip("\r").split(self.delimiter)
|
||||
file_name = substr[0]
|
||||
label = substr[1]
|
||||
img_path = os.path.join(self.data_dir, file_name)
|
||||
if os.path.exists(img_path):
|
||||
new_data_lines.append({'img_path': img_path, 'label': label})
|
||||
else:
|
||||
self.logger.info("{} does not exist!".format(img_path))
|
||||
self.data_lines = new_data_lines
|
|
@ -54,6 +54,27 @@ class CELoss(nn.Layer):
|
|||
return loss
|
||||
|
||||
|
||||
class KLJSLoss(object):
|
||||
def __init__(self, mode='kl'):
|
||||
assert mode in ['kl', 'js', 'KL', 'JS'], "mode can only be one of ['kl', 'js', 'KL', 'JS']"
|
||||
self.mode = mode
|
||||
|
||||
def __call__(self, p1, p2, reduction="mean"):
|
||||
|
||||
loss = paddle.multiply(p2, paddle.log( (p2+1e-5)/(p1+1e-5) + 1e-5))
|
||||
|
||||
if self.mode.lower() == "js":
|
||||
loss += paddle.multiply(p1, paddle.log((p1+1e-5)/(p2+1e-5) + 1e-5))
|
||||
loss *= 0.5
|
||||
if reduction == "mean":
|
||||
loss = paddle.mean(loss, axis=[1,2])
|
||||
elif reduction=="none" or reduction is None:
|
||||
return loss
|
||||
else:
|
||||
loss = paddle.sum(loss, axis=[1,2])
|
||||
|
||||
return loss
|
||||
|
||||
class DMLLoss(nn.Layer):
|
||||
"""
|
||||
DMLLoss
|
||||
|
@ -69,17 +90,21 @@ class DMLLoss(nn.Layer):
|
|||
self.act = nn.Sigmoid()
|
||||
else:
|
||||
self.act = None
|
||||
|
||||
self.jskl_loss = KLJSLoss(mode="js")
|
||||
|
||||
def forward(self, out1, out2):
|
||||
if self.act is not None:
|
||||
out1 = self.act(out1)
|
||||
out2 = self.act(out2)
|
||||
|
||||
log_out1 = paddle.log(out1)
|
||||
log_out2 = paddle.log(out2)
|
||||
loss = (F.kl_div(
|
||||
log_out1, out2, reduction='batchmean') + F.kl_div(
|
||||
log_out2, out1, reduction='batchmean')) / 2.0
|
||||
if len(out1.shape) < 2:
|
||||
log_out1 = paddle.log(out1)
|
||||
log_out2 = paddle.log(out2)
|
||||
loss = (F.kl_div(
|
||||
log_out1, out2, reduction='batchmean') + F.kl_div(
|
||||
log_out2, out1, reduction='batchmean')) / 2.0
|
||||
else:
|
||||
loss = self.jskl_loss(out1, out2)
|
||||
return loss
|
||||
|
||||
|
||||
|
|
|
@ -17,7 +17,7 @@ import paddle.nn as nn
|
|||
|
||||
from .distillation_loss import DistillationCTCLoss
|
||||
from .distillation_loss import DistillationDMLLoss
|
||||
from .distillation_loss import DistillationDistanceLoss
|
||||
from .distillation_loss import DistillationDistanceLoss, DistillationDBLoss, DistillationDilaDBLoss
|
||||
|
||||
|
||||
class CombinedLoss(nn.Layer):
|
||||
|
@ -44,15 +44,16 @@ class CombinedLoss(nn.Layer):
|
|||
|
||||
def forward(self, input, batch, **kargs):
|
||||
loss_dict = {}
|
||||
loss_all = 0.
|
||||
for idx, loss_func in enumerate(self.loss_func):
|
||||
loss = loss_func(input, batch, **kargs)
|
||||
if isinstance(loss, paddle.Tensor):
|
||||
loss = {"loss_{}_{}".format(str(loss), idx): loss}
|
||||
weight = self.loss_weight[idx]
|
||||
loss = {
|
||||
"{}_{}".format(key, idx): loss[key] * weight
|
||||
for key in loss
|
||||
}
|
||||
loss_dict.update(loss)
|
||||
loss_dict["loss"] = paddle.add_n(list(loss_dict.values()))
|
||||
for key in loss.keys():
|
||||
if key == "loss":
|
||||
loss_all += loss[key] * weight
|
||||
else:
|
||||
loss_dict["{}_{}".format(key, idx)] = loss[key]
|
||||
loss_dict["loss"] = loss_all
|
||||
return loss_dict
|
||||
|
|
|
@ -14,23 +14,76 @@
|
|||
|
||||
import paddle
|
||||
import paddle.nn as nn
|
||||
import numpy as np
|
||||
import cv2
|
||||
|
||||
from .rec_ctc_loss import CTCLoss
|
||||
from .basic_loss import DMLLoss
|
||||
from .basic_loss import DistanceLoss
|
||||
from .det_db_loss import DBLoss
|
||||
from .det_basic_loss import BalanceLoss, MaskL1Loss, DiceLoss
|
||||
|
||||
|
||||
def _sum_loss(loss_dict):
|
||||
if "loss" in loss_dict.keys():
|
||||
return loss_dict
|
||||
else:
|
||||
loss_dict["loss"] = 0.
|
||||
for k, value in loss_dict.items():
|
||||
if k == "loss":
|
||||
continue
|
||||
else:
|
||||
loss_dict["loss"] += value
|
||||
return loss_dict
|
||||
|
||||
|
||||
class DistillationDMLLoss(DMLLoss):
|
||||
"""
|
||||
"""
|
||||
|
||||
def __init__(self, model_name_pairs=[], act=None, key=None,
|
||||
name="loss_dml"):
|
||||
def __init__(self,
|
||||
model_name_pairs=[],
|
||||
act=None,
|
||||
key=None,
|
||||
maps_name=None,
|
||||
name="dml"):
|
||||
super().__init__(act=act)
|
||||
assert isinstance(model_name_pairs, list)
|
||||
self.key = key
|
||||
self.model_name_pairs = model_name_pairs
|
||||
self.model_name_pairs = self._check_model_name_pairs(model_name_pairs)
|
||||
self.name = name
|
||||
self.maps_name = self._check_maps_name(maps_name)
|
||||
|
||||
def _check_model_name_pairs(self, model_name_pairs):
|
||||
if not isinstance(model_name_pairs, list):
|
||||
return []
|
||||
elif isinstance(model_name_pairs[0], list) and isinstance(model_name_pairs[0][0], str):
|
||||
return model_name_pairs
|
||||
else:
|
||||
return [model_name_pairs]
|
||||
|
||||
def _check_maps_name(self, maps_name):
|
||||
if maps_name is None:
|
||||
return None
|
||||
elif type(maps_name) == str:
|
||||
return [maps_name]
|
||||
elif type(maps_name) == list:
|
||||
return [maps_name]
|
||||
else:
|
||||
return None
|
||||
|
||||
def _slice_out(self, outs):
|
||||
new_outs = {}
|
||||
for k in self.maps_name:
|
||||
if k == "thrink_maps":
|
||||
new_outs[k] = outs[:, 0, :, :]
|
||||
elif k == "threshold_maps":
|
||||
new_outs[k] = outs[:, 1, :, :]
|
||||
elif k == "binary_maps":
|
||||
new_outs[k] = outs[:, 2, :, :]
|
||||
else:
|
||||
continue
|
||||
return new_outs
|
||||
|
||||
def forward(self, predicts, batch):
|
||||
loss_dict = dict()
|
||||
|
@ -40,13 +93,30 @@ class DistillationDMLLoss(DMLLoss):
|
|||
if self.key is not None:
|
||||
out1 = out1[self.key]
|
||||
out2 = out2[self.key]
|
||||
loss = super().forward(out1, out2)
|
||||
if isinstance(loss, dict):
|
||||
for key in loss:
|
||||
loss_dict["{}_{}_{}_{}".format(key, pair[0], pair[1],
|
||||
idx)] = loss[key]
|
||||
|
||||
if self.maps_name is None:
|
||||
loss = super().forward(out1, out2)
|
||||
if isinstance(loss, dict):
|
||||
for key in loss:
|
||||
loss_dict["{}_{}_{}_{}".format(key, pair[0], pair[1],
|
||||
idx)] = loss[key]
|
||||
else:
|
||||
loss_dict["{}_{}".format(self.name, idx)] = loss
|
||||
else:
|
||||
loss_dict["{}_{}".format(self.name, idx)] = loss
|
||||
outs1 = self._slice_out(out1)
|
||||
outs2 = self._slice_out(out2)
|
||||
for _c, k in enumerate(outs1.keys()):
|
||||
loss = super().forward(outs1[k], outs2[k])
|
||||
if isinstance(loss, dict):
|
||||
for key in loss:
|
||||
loss_dict["{}_{}_{}_{}_{}".format(key, pair[
|
||||
0], pair[1], map_name, idx)] = loss[key]
|
||||
else:
|
||||
loss_dict["{}_{}_{}".format(self.name, self.maps_name[_c],
|
||||
idx)] = loss
|
||||
|
||||
loss_dict = _sum_loss(loss_dict)
|
||||
|
||||
return loss_dict
|
||||
|
||||
|
||||
|
@ -73,6 +143,98 @@ class DistillationCTCLoss(CTCLoss):
|
|||
return loss_dict
|
||||
|
||||
|
||||
class DistillationDBLoss(DBLoss):
|
||||
def __init__(self,
|
||||
model_name_list=[],
|
||||
balance_loss=True,
|
||||
main_loss_type='DiceLoss',
|
||||
alpha=5,
|
||||
beta=10,
|
||||
ohem_ratio=3,
|
||||
eps=1e-6,
|
||||
name="db",
|
||||
**kwargs):
|
||||
super().__init__()
|
||||
self.model_name_list = model_name_list
|
||||
self.name = name
|
||||
self.key = None
|
||||
|
||||
def forward(self, predicts, batch):
|
||||
loss_dict = {}
|
||||
for idx, model_name in enumerate(self.model_name_list):
|
||||
out = predicts[model_name]
|
||||
if self.key is not None:
|
||||
out = out[self.key]
|
||||
loss = super().forward(out, batch)
|
||||
|
||||
if isinstance(loss, dict):
|
||||
for key in loss.keys():
|
||||
if key == "loss":
|
||||
continue
|
||||
name = "{}_{}_{}".format(self.name, model_name, key)
|
||||
loss_dict[name] = loss[key]
|
||||
else:
|
||||
loss_dict["{}_{}".format(self.name, model_name)] = loss
|
||||
|
||||
loss_dict = _sum_loss(loss_dict)
|
||||
return loss_dict
|
||||
|
||||
|
||||
class DistillationDilaDBLoss(DBLoss):
|
||||
def __init__(self,
|
||||
model_name_pairs=[],
|
||||
key=None,
|
||||
balance_loss=True,
|
||||
main_loss_type='DiceLoss',
|
||||
alpha=5,
|
||||
beta=10,
|
||||
ohem_ratio=3,
|
||||
eps=1e-6,
|
||||
name="dila_dbloss"):
|
||||
super().__init__()
|
||||
self.model_name_pairs = model_name_pairs
|
||||
self.name = name
|
||||
self.key = key
|
||||
|
||||
def forward(self, predicts, batch):
|
||||
loss_dict = dict()
|
||||
for idx, pair in enumerate(self.model_name_pairs):
|
||||
stu_outs = predicts[pair[0]]
|
||||
tch_outs = predicts[pair[1]]
|
||||
if self.key is not None:
|
||||
stu_preds = stu_outs[self.key]
|
||||
tch_preds = tch_outs[self.key]
|
||||
|
||||
stu_shrink_maps = stu_preds[:, 0, :, :]
|
||||
stu_binary_maps = stu_preds[:, 2, :, :]
|
||||
|
||||
# dilation to teacher prediction
|
||||
dilation_w = np.array([[1, 1], [1, 1]])
|
||||
th_shrink_maps = tch_preds[:, 0, :, :]
|
||||
th_shrink_maps = th_shrink_maps.numpy() > 0.3 # thresh = 0.3
|
||||
dilate_maps = np.zeros_like(th_shrink_maps).astype(np.float32)
|
||||
for i in range(th_shrink_maps.shape[0]):
|
||||
dilate_maps[i] = cv2.dilate(
|
||||
th_shrink_maps[i, :, :].astype(np.uint8), dilation_w)
|
||||
th_shrink_maps = paddle.to_tensor(dilate_maps)
|
||||
|
||||
label_threshold_map, label_threshold_mask, label_shrink_map, label_shrink_mask = batch[
|
||||
1:]
|
||||
|
||||
# calculate the shrink map loss
|
||||
bce_loss = self.alpha * self.bce_loss(
|
||||
stu_shrink_maps, th_shrink_maps, label_shrink_mask)
|
||||
loss_binary_maps = self.dice_loss(stu_binary_maps, th_shrink_maps,
|
||||
label_shrink_mask)
|
||||
|
||||
# k = f"{self.name}_{pair[0]}_{pair[1]}"
|
||||
k = "{}_{}_{}".format(self.name, pair[0], pair[1])
|
||||
loss_dict[k] = bce_loss + loss_binary_maps
|
||||
|
||||
loss_dict = _sum_loss(loss_dict)
|
||||
return loss_dict
|
||||
|
||||
|
||||
class DistillationDistanceLoss(DistanceLoss):
|
||||
"""
|
||||
"""
|
||||
|
|
|
@ -55,6 +55,7 @@ class DetMetric(object):
|
|||
result = self.evaluator.evaluate_image(gt_info_list, det_info_list)
|
||||
self.results.append(result)
|
||||
|
||||
|
||||
def get_metric(self):
|
||||
"""
|
||||
return metrics {
|
||||
|
|
|
@ -24,8 +24,8 @@ from .cls_metric import ClsMetric
|
|||
class DistillationMetric(object):
|
||||
def __init__(self,
|
||||
key=None,
|
||||
base_metric_name="RecMetric",
|
||||
main_indicator='acc',
|
||||
base_metric_name=None,
|
||||
main_indicator=None,
|
||||
**kwargs):
|
||||
self.main_indicator = main_indicator
|
||||
self.key = key
|
||||
|
@ -42,16 +42,13 @@ class DistillationMetric(object):
|
|||
main_indicator=self.main_indicator, **self.kwargs)
|
||||
self.metrics[key].reset()
|
||||
|
||||
def __call__(self, preds, *args, **kwargs):
|
||||
def __call__(self, preds, batch, **kwargs):
|
||||
assert isinstance(preds, dict)
|
||||
if self.metrics is None:
|
||||
self._init_metrcis(preds)
|
||||
output = dict()
|
||||
for key in preds:
|
||||
metric = self.metrics[key].__call__(preds[key], *args, **kwargs)
|
||||
for sub_key in metric:
|
||||
output["{}_{}".format(key, sub_key)] = metric[sub_key]
|
||||
return output
|
||||
self.metrics[key].__call__(preds[key], batch, **kwargs)
|
||||
|
||||
def get_metric(self):
|
||||
"""
|
||||
|
|
|
@ -79,7 +79,10 @@ class BaseModel(nn.Layer):
|
|||
x = self.neck(x)
|
||||
y["neck_out"] = x
|
||||
x = self.head(x, targets=data)
|
||||
y["head_out"] = x
|
||||
if isinstance(x, dict):
|
||||
y.update(x)
|
||||
else:
|
||||
y["head_out"] = x
|
||||
if self.return_all_feats:
|
||||
return y
|
||||
else:
|
||||
|
|
|
@ -21,7 +21,7 @@ from ppocr.modeling.backbones import build_backbone
|
|||
from ppocr.modeling.necks import build_neck
|
||||
from ppocr.modeling.heads import build_head
|
||||
from .base_model import BaseModel
|
||||
from ppocr.utils.save_load import init_model
|
||||
from ppocr.utils.save_load import init_model, load_pretrained_params
|
||||
|
||||
__all__ = ['DistillationModel']
|
||||
|
||||
|
@ -46,7 +46,7 @@ class DistillationModel(nn.Layer):
|
|||
pretrained = model_config.pop("pretrained")
|
||||
model = BaseModel(model_config)
|
||||
if pretrained is not None:
|
||||
init_model(model, path=pretrained)
|
||||
load_pretrained_params(model, pretrained)
|
||||
if freeze_params:
|
||||
for param in model.parameters():
|
||||
param.trainable = False
|
||||
|
|
|
@ -21,7 +21,7 @@ import copy
|
|||
|
||||
__all__ = ['build_post_process']
|
||||
|
||||
from .db_postprocess import DBPostProcess
|
||||
from .db_postprocess import DBPostProcess, DistillationDBPostProcess
|
||||
from .east_postprocess import EASTPostProcess
|
||||
from .sast_postprocess import SASTPostProcess
|
||||
from .rec_postprocess import CTCLabelDecode, AttnLabelDecode, SRNLabelDecode, DistillationCTCLabelDecode, \
|
||||
|
@ -33,9 +33,10 @@ from .pse_postprocess import PSEPostProcess
|
|||
|
||||
def build_post_process(config, global_config=None):
|
||||
support_dict = [
|
||||
'DBPostProcess', 'EASTPostProcess', 'SASTPostProcess', 'CTCLabelDecode',
|
||||
'DBPostProcess','PSEPostProcess', 'EASTPostProcess', 'SASTPostProcess', 'CTCLabelDecode',
|
||||
'AttnLabelDecode', 'ClsPostProcess', 'SRNLabelDecode', 'PGPostProcess',
|
||||
'DistillationCTCLabelDecode', 'TableLabelDecode', 'PSEPostProcess'
|
||||
'DistillationCTCLabelDecode', 'TableLabelDecode',
|
||||
'DistillationDBPostProcess'
|
||||
]
|
||||
|
||||
config = copy.deepcopy(config)
|
||||
|
|
|
@ -187,3 +187,29 @@ class DBPostProcess(object):
|
|||
|
||||
boxes_batch.append({'points': boxes})
|
||||
return boxes_batch
|
||||
|
||||
|
||||
class DistillationDBPostProcess(object):
|
||||
def __init__(self, model_name=["student"],
|
||||
key=None,
|
||||
thresh=0.3,
|
||||
box_thresh=0.6,
|
||||
max_candidates=1000,
|
||||
unclip_ratio=1.5,
|
||||
use_dilation=False,
|
||||
score_mode="fast",
|
||||
**kwargs):
|
||||
self.model_name = model_name
|
||||
self.key = key
|
||||
self.post_process = DBPostProcess(thresh=thresh,
|
||||
box_thresh=box_thresh,
|
||||
max_candidates=max_candidates,
|
||||
unclip_ratio=unclip_ratio,
|
||||
use_dilation=use_dilation,
|
||||
score_mode=score_mode)
|
||||
|
||||
def __call__(self, predicts, shape_list):
|
||||
results = {}
|
||||
for k in self.model_name:
|
||||
results[k] = self.post_process(predicts[k], shape_list=shape_list)
|
||||
return results
|
||||
|
|
|
@ -1,16 +1,16 @@
|
|||
#copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
|
||||
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
#Licensed under the Apache License, Version 2.0 (the "License");
|
||||
#you may not use this file except in compliance with the License.
|
||||
#You may obtain a copy of the License at
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
#Unless required by applicable law or agreed to in writing, software
|
||||
#distributed under the License is distributed on an "AS IS" BASIS,
|
||||
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
#See the License for the specific language governing permissions and
|
||||
#limitations under the License.
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import os
|
||||
import argparse
|
||||
import json
|
||||
|
@ -31,7 +31,9 @@ def gen_det_label(root_path, input_dir, out_label):
|
|||
for label_file in os.listdir(input_dir):
|
||||
img_path = root_path + label_file[3:-4] + ".jpg"
|
||||
label = []
|
||||
with open(os.path.join(input_dir, label_file), 'r') as f:
|
||||
with open(
|
||||
os.path.join(input_dir, label_file), 'r',
|
||||
encoding='utf-8-sig') as f:
|
||||
for line in f.readlines():
|
||||
tmp = line.strip("\n\r").replace("\xef\xbb\xbf",
|
||||
"").split(',')
|
||||
|
|
|
@ -116,6 +116,27 @@ def load_dygraph_params(config, model, logger, optimizer):
|
|||
logger.info(f"loaded pretrained_model successful from {pm}")
|
||||
return {}
|
||||
|
||||
def load_pretrained_params(model, path):
|
||||
if path is None:
|
||||
return False
|
||||
if not os.path.exists(path) and not os.path.exists(path + ".pdparams"):
|
||||
print(f"The pretrained_model {path} does not exists!")
|
||||
return False
|
||||
|
||||
path = path if path.endswith('.pdparams') else path + '.pdparams'
|
||||
params = paddle.load(path)
|
||||
state_dict = model.state_dict()
|
||||
new_state_dict = {}
|
||||
for k1, k2 in zip(state_dict.keys(), params.keys()):
|
||||
if list(state_dict[k1].shape) == list(params[k2].shape):
|
||||
new_state_dict[k1] = params[k2]
|
||||
else:
|
||||
print(
|
||||
f"The shape of model params {k1} {state_dict[k1].shape} not matched with loaded params {k2} {params[k2].shape} !"
|
||||
)
|
||||
model.set_state_dict(new_state_dict)
|
||||
print(f"load pretrain successful from {path}")
|
||||
return model
|
||||
|
||||
def save_model(model,
|
||||
optimizer,
|
||||
|
|
|
@ -1,13 +1,12 @@
|
|||
model_name:ocr_det
|
||||
python:python3.7
|
||||
gpu_list:0|0,1
|
||||
Global.auto_cast:False
|
||||
Global.auto_cast:null
|
||||
Global.epoch_num:10
|
||||
Global.save_model_dir:./output/
|
||||
Global.save_inference_dir:./output/
|
||||
Train.loader.batch_size_per_card:
|
||||
Global.use_gpu
|
||||
Global.pretrained_model
|
||||
Global.use_gpu:
|
||||
Global.pretrained_model:null
|
||||
|
||||
trainer:norm|pact
|
||||
norm_train:tools/train.py -c configs/det/det_mv3_db.yml -o Global.pretrained_model=./pretrain_models/MobileNetV3_large_x0_5_pretrained
|
||||
|
@ -17,6 +16,8 @@ distill_train:null
|
|||
|
||||
eval:tools/eval.py -c configs/det/det_mv3_db.yml -o
|
||||
|
||||
Global.save_inference_dir:./output/
|
||||
Global.pretrained_model:
|
||||
norm_export:tools/export_model.py -c configs/det/det_mv3_db.yml -o
|
||||
quant_export:deploy/slim/quantization/export_model.py -c configs/det/det_mv3_db.yml -o
|
||||
fpgm_export:deploy/slim/prune/export_prune_model.py
|
||||
|
@ -29,7 +30,6 @@ inference:tools/infer/predict_det.py
|
|||
--rec_batch_num:1
|
||||
--use_tensorrt:True|False
|
||||
--precision:fp32|fp16|int8
|
||||
--det_model_dir
|
||||
--image_dir
|
||||
--save_log_path
|
||||
|
||||
--det_model_dir:./inference/ch_ppocr_mobile_v2.0_det_infer/
|
||||
--image_dir:./inference/ch_det_data_50/all-sum-510/
|
||||
--save_log_path:./test/output/
|
||||
|
|
|
@ -0,0 +1,35 @@
|
|||
model_name:ocr_rec
|
||||
python:python
|
||||
gpu_list:0|0,1
|
||||
Global.auto_cast:null
|
||||
Global.epoch_num:10
|
||||
Global.save_model_dir:./output/
|
||||
Train.loader.batch_size_per_card:
|
||||
Global.use_gpu:
|
||||
Global.pretrained_model:null
|
||||
|
||||
trainer:norm|pact
|
||||
norm_train:tools/train.py -c configs/rec/rec_mv3_none_bilstm_ctc.yml
|
||||
quant_train:deploy/slim/quantization/quant.py -c configs/rec/rec_mv3_none_bilstm_ctc.yml
|
||||
fpgm_train:null
|
||||
distill_train:null
|
||||
|
||||
eval:tools/eval.py -c configs/rec/rec_mv3_none_bilstm_ctc.yml -o
|
||||
|
||||
Global.save_inference_dir:./output/
|
||||
Global.pretrained_model:
|
||||
norm_export:tools/export_model.py -c configs/rec/rec_mv3_none_bilstm_ctc.yml -o
|
||||
quant_export:deploy/slim/quantization/export_model.py -c configs/rec/rec_mv3_none_bilstm_ctc.yml -o
|
||||
fpgm_export:null
|
||||
distill_export:null
|
||||
|
||||
inference:tools/infer/predict_rec.py
|
||||
--use_gpu:True|False
|
||||
--enable_mkldnn:True|False
|
||||
--cpu_threads:1|6
|
||||
--rec_batch_num:1
|
||||
--use_tensorrt:True|False
|
||||
--precision:fp32|fp16|int8
|
||||
--rec_model_dir:./inference/ch_ppocr_mobile_v2.0_rec_infer/
|
||||
--image_dir:./inference/rec_inference
|
||||
--save_log_path:./test/output/
|
|
@ -26,20 +26,24 @@ IFS=$'\n'
|
|||
# The training params
|
||||
model_name=$(func_parser_value "${lines[0]}")
|
||||
train_model_list=$(func_parser_value "${lines[0]}")
|
||||
|
||||
trainer_list=$(func_parser_value "${lines[10]}")
|
||||
|
||||
# MODE be one of ['lite_train_infer' 'whole_infer' 'whole_train_infer']
|
||||
MODE=$2
|
||||
# prepare pretrained weights and dataset
|
||||
wget -nc -P ./pretrain_models/ https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/MobileNetV3_large_x0_5_pretrained.pdparams
|
||||
wget -nc -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_mv3_db_v2.0_train.tar
|
||||
cd pretrain_models && tar xf det_mv3_db_v2.0_train.tar && cd ../
|
||||
|
||||
# prepare pretrained weights and dataset
|
||||
if [ ${train_model_list[*]} = "ocr_det" ]; then
|
||||
wget -nc -P ./pretrain_models/ https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/MobileNetV3_large_x0_5_pretrained.pdparams
|
||||
wget -nc -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_mv3_db_v2.0_train.tar
|
||||
cd pretrain_models && tar xf det_mv3_db_v2.0_train.tar && cd ../
|
||||
fi
|
||||
if [ ${MODE} = "lite_train_infer" ];then
|
||||
# pretrain lite train data
|
||||
rm -rf ./train_data/icdar2015
|
||||
wget -nc -P ./train_data/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/icdar2015_lite.tar
|
||||
cd ./train_data/ && tar xf icdar2015_lite.tar
|
||||
wget -nc -P ./train_data/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/ic15_data.tar # todo change to bcebos
|
||||
|
||||
cd ./train_data/ && tar xf icdar2015_lite.tar && tar xf ic15_data.tar
|
||||
ln -s ./icdar2015_lite ./icdar2015
|
||||
cd ../
|
||||
epoch=10
|
||||
|
@ -47,13 +51,15 @@ if [ ${MODE} = "lite_train_infer" ];then
|
|||
elif [ ${MODE} = "whole_train_infer" ];then
|
||||
rm -rf ./train_data/icdar2015
|
||||
wget -nc -P ./train_data/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/icdar2015.tar
|
||||
cd ./train_data/ && tar xf icdar2015.tar && cd ../
|
||||
wget -nc -P ./train_data/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/ic15_data.tar
|
||||
cd ./train_data/ && tar xf icdar2015.tar && tar xf ic15_data.tar && cd ../
|
||||
epoch=500
|
||||
eval_batch_step=200
|
||||
elif [ ${MODE} = "whole_infer" ];then
|
||||
rm -rf ./train_data/icdar2015
|
||||
wget -nc -P ./train_data/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/icdar2015_infer.tar
|
||||
cd ./train_data/ && tar xf icdar2015_infer.tar
|
||||
wget -nc -P ./train_data/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/ic15_data.tar
|
||||
cd ./train_data/ && tar xf icdar2015_infer.tar && tar xf ic15_data.tar
|
||||
ln -s ./icdar2015_infer ./icdar2015
|
||||
cd ../
|
||||
epoch=10
|
||||
|
@ -62,8 +68,8 @@ else
|
|||
rm -rf ./train_data/icdar2015
|
||||
wget -nc -P ./train_data https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/ch_det_data_50.tar
|
||||
if [ ${model_name} = "ocr_det" ]; then
|
||||
eval_model_name="ch_ppocr_mobile_v2.0_det_train"
|
||||
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_det_train.tar
|
||||
eval_model_name="ch_ppocr_mobile_v2.0_det_infer"
|
||||
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_det_infer.tar
|
||||
cd ./inference && tar xf ${eval_model_name}.tar && cd ../
|
||||
else
|
||||
eval_model_name="ch_ppocr_mobile_v2.0_rec_train"
|
||||
|
@ -86,15 +92,17 @@ for train_model in ${train_model_list[*]}; do
|
|||
elif [ ${train_model} = "ocr_rec" ];then
|
||||
model_name="ocr_rec"
|
||||
yml_file="configs/rec/rec_mv3_none_bilstm_ctc.yml"
|
||||
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/ch_rec_data_200.tar
|
||||
cd ./inference && tar xf ch_rec_data_200.tar && cd ../
|
||||
img_dir="./inference/ch_rec_data_200/"
|
||||
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/rec_inference.tar
|
||||
cd ./inference && tar xf rec_inference.tar && cd ../
|
||||
img_dir="./inference/rec_inference/"
|
||||
data_dir=./inference/rec_inference
|
||||
data_label_file=[./inference/rec_inference/rec_gt_test.txt]
|
||||
fi
|
||||
|
||||
# eval
|
||||
for slim_trainer in ${trainer_list[*]}; do
|
||||
if [ ${slim_trainer} = "norm" ]; then
|
||||
if [ ${model_name} = "ocr_det" ]; then
|
||||
if [ ${model_name} = "det" ]; then
|
||||
eval_model_name="ch_ppocr_mobile_v2.0_det_train"
|
||||
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_det_train.tar
|
||||
cd ./inference && tar xf ${eval_model_name}.tar && cd ../
|
||||
|
@ -104,7 +112,7 @@ for train_model in ${train_model_list[*]}; do
|
|||
cd ./inference && tar xf ${eval_model_name}.tar && cd ../
|
||||
fi
|
||||
elif [ ${slim_trainer} = "pact" ]; then
|
||||
if [ ${model_name} = "ocr_det" ]; then
|
||||
if [ ${model_name} = "det" ]; then
|
||||
eval_model_name="ch_ppocr_mobile_v2.0_det_quant_train"
|
||||
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/slim/ch_ppocr_mobile_v2.0_det_quant_train.tar
|
||||
cd ./inference && tar xf ${eval_model_name}.tar && cd ../
|
||||
|
@ -114,7 +122,7 @@ for train_model in ${train_model_list[*]}; do
|
|||
cd ./inference && tar xf ${eval_model_name}.tar && cd ../
|
||||
fi
|
||||
elif [ ${slim_trainer} = "distill" ]; then
|
||||
if [ ${model_name} = "ocr_det" ]; then
|
||||
if [ ${model_name} = "det" ]; then
|
||||
eval_model_name="ch_ppocr_mobile_v2.0_det_distill_train"
|
||||
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/slim/ch_ppocr_mobile_v2.0_det_distill_train.tar
|
||||
cd ./inference && tar xf ${eval_model_name}.tar && cd ../
|
||||
|
@ -124,7 +132,7 @@ for train_model in ${train_model_list[*]}; do
|
|||
cd ./inference && tar xf ${eval_model_name}.tar && cd ../
|
||||
fi
|
||||
elif [ ${slim_trainer} = "fpgm" ]; then
|
||||
if [ ${model_name} = "ocr_det" ]; then
|
||||
if [ ${model_name} = "det" ]; then
|
||||
eval_model_name="ch_ppocr_mobile_v2.0_det_prune_train"
|
||||
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/slim/ch_ppocr_mobile_v2.0_det_prune_train.tar
|
||||
cd ./inference && tar xf ${eval_model_name}.tar && cd ../
|
||||
|
|
144
test/test.sh
144
test/test.sh
|
@ -41,59 +41,51 @@ gpu_list=$(func_parser_value "${lines[2]}")
|
|||
autocast_list=$(func_parser_value "${lines[3]}")
|
||||
autocast_key=$(func_parser_key "${lines[3]}")
|
||||
epoch_key=$(func_parser_key "${lines[4]}")
|
||||
epoch_num=$(func_parser_value "${lines[4]}")
|
||||
save_model_key=$(func_parser_key "${lines[5]}")
|
||||
save_infer_key=$(func_parser_key "${lines[6]}")
|
||||
train_batch_key=$(func_parser_key "${lines[7]}")
|
||||
train_use_gpu_key=$(func_parser_key "${lines[8]}")
|
||||
pretrain_model_key=$(func_parser_key "${lines[9]}")
|
||||
train_batch_key=$(func_parser_key "${lines[6]}")
|
||||
train_use_gpu_key=$(func_parser_key "${lines[7]}")
|
||||
pretrain_model_key=$(func_parser_key "${lines[8]}")
|
||||
pretrain_model_value=$(func_parser_value "${lines[8]}")
|
||||
|
||||
trainer_list=$(func_parser_value "${lines[10]}")
|
||||
norm_trainer=$(func_parser_value "${lines[11]}")
|
||||
pact_trainer=$(func_parser_value "${lines[12]}")
|
||||
fpgm_trainer=$(func_parser_value "${lines[13]}")
|
||||
distill_trainer=$(func_parser_value "${lines[14]}")
|
||||
trainer_list=$(func_parser_value "${lines[9]}")
|
||||
norm_trainer=$(func_parser_value "${lines[10]}")
|
||||
pact_trainer=$(func_parser_value "${lines[11]}")
|
||||
fpgm_trainer=$(func_parser_value "${lines[12]}")
|
||||
distill_trainer=$(func_parser_value "${lines[13]}")
|
||||
|
||||
eval_py=$(func_parser_value "${lines[15]}")
|
||||
norm_export=$(func_parser_value "${lines[16]}")
|
||||
pact_export=$(func_parser_value "${lines[17]}")
|
||||
fpgm_export=$(func_parser_value "${lines[18]}")
|
||||
distill_export=$(func_parser_value "${lines[19]}")
|
||||
eval_py=$(func_parser_value "${lines[14]}")
|
||||
|
||||
inference_py=$(func_parser_value "${lines[20]}")
|
||||
use_gpu_key=$(func_parser_key "${lines[21]}")
|
||||
use_gpu_list=$(func_parser_value "${lines[21]}")
|
||||
use_mkldnn_key=$(func_parser_key "${lines[22]}")
|
||||
use_mkldnn_list=$(func_parser_value "${lines[22]}")
|
||||
cpu_threads_key=$(func_parser_key "${lines[23]}")
|
||||
cpu_threads_list=$(func_parser_value "${lines[23]}")
|
||||
batch_size_key=$(func_parser_key "${lines[24]}")
|
||||
batch_size_list=$(func_parser_value "${lines[24]}")
|
||||
use_trt_key=$(func_parser_key "${lines[25]}")
|
||||
use_trt_list=$(func_parser_value "${lines[25]}")
|
||||
precision_key=$(func_parser_key "${lines[26]}")
|
||||
precision_list=$(func_parser_value "${lines[26]}")
|
||||
model_dir_key=$(func_parser_key "${lines[27]}")
|
||||
image_dir_key=$(func_parser_key "${lines[28]}")
|
||||
save_log_key=$(func_parser_key "${lines[29]}")
|
||||
save_infer_key=$(func_parser_key "${lines[15]}")
|
||||
export_weight=$(func_parser_key "${lines[16]}")
|
||||
norm_export=$(func_parser_value "${lines[17]}")
|
||||
pact_export=$(func_parser_value "${lines[18]}")
|
||||
fpgm_export=$(func_parser_value "${lines[19]}")
|
||||
distill_export=$(func_parser_value "${lines[20]}")
|
||||
|
||||
inference_py=$(func_parser_value "${lines[21]}")
|
||||
use_gpu_key=$(func_parser_key "${lines[22]}")
|
||||
use_gpu_list=$(func_parser_value "${lines[22]}")
|
||||
use_mkldnn_key=$(func_parser_key "${lines[23]}")
|
||||
use_mkldnn_list=$(func_parser_value "${lines[23]}")
|
||||
cpu_threads_key=$(func_parser_key "${lines[24]}")
|
||||
cpu_threads_list=$(func_parser_value "${lines[24]}")
|
||||
batch_size_key=$(func_parser_key "${lines[25]}")
|
||||
batch_size_list=$(func_parser_value "${lines[25]}")
|
||||
use_trt_key=$(func_parser_key "${lines[26]}")
|
||||
use_trt_list=$(func_parser_value "${lines[26]}")
|
||||
precision_key=$(func_parser_key "${lines[27]}")
|
||||
precision_list=$(func_parser_value "${lines[27]}")
|
||||
infer_model_key=$(func_parser_key "${lines[28]}")
|
||||
infer_model=$(func_parser_value "${lines[28]}")
|
||||
image_dir_key=$(func_parser_key "${lines[29]}")
|
||||
infer_img_dir=$(func_parser_value "${lines[29]}")
|
||||
save_log_key=$(func_parser_key "${lines[30]}")
|
||||
|
||||
LOG_PATH="./test/output"
|
||||
mkdir -p ${LOG_PATH}
|
||||
status_log="${LOG_PATH}/results.log"
|
||||
|
||||
if [ ${MODE} = "lite_train_infer" ]; then
|
||||
export infer_img_dir="./train_data/icdar2015/text_localization/ch4_test_images/"
|
||||
export epoch_num=10
|
||||
elif [ ${MODE} = "whole_infer" ]; then
|
||||
export infer_img_dir="./train_data/icdar2015/text_localization/ch4_test_images/"
|
||||
export epoch_num=10
|
||||
elif [ ${MODE} = "whole_train_infer" ]; then
|
||||
export infer_img_dir="./train_data/icdar2015/text_localization/ch4_test_images/"
|
||||
export epoch_num=300
|
||||
else
|
||||
export infer_img_dir="./inference/ch_det_data_50/all-sum-510"
|
||||
export infer_model_dir="./inference/ch_ppocr_mobile_v2.0_det_train/best_accuracy"
|
||||
fi
|
||||
|
||||
|
||||
function func_inference(){
|
||||
IFS='|'
|
||||
|
@ -109,8 +101,8 @@ function func_inference(){
|
|||
for use_mkldnn in ${use_mkldnn_list[*]}; do
|
||||
for threads in ${cpu_threads_list[*]}; do
|
||||
for batch_size in ${batch_size_list[*]}; do
|
||||
_save_log_path="${_log_path}/infer_cpu_usemkldnn_${use_mkldnn}_threads_${threads}_batchsize_${batch_size}"
|
||||
command="${_python} ${_script} ${use_gpu_key}=${use_gpu} ${use_mkldnn_key}=${use_mkldnn} ${cpu_threads_key}=${threads} ${model_dir_key}=${_model_dir} ${batch_size_key}=${batch_size} ${image_dir_key}=${_img_dir} ${save_log_key}=${_save_log_path} --benchmark=True"
|
||||
_save_log_path="${_log_path}/infer_cpu_usemkldnn_${use_mkldnn}_threads_${threads}_batchsize_${batch_size}.log"
|
||||
command="${_python} ${_script} ${use_gpu_key}=${use_gpu} ${use_mkldnn_key}=${use_mkldnn} ${cpu_threads_key}=${threads} ${infer_model_key}=${_model_dir} ${batch_size_key}=${batch_size} ${image_dir_key}=${_img_dir} ${save_log_key}=${_save_log_path} --benchmark=True"
|
||||
eval $command
|
||||
status_check $? "${command}" "${status_log}"
|
||||
done
|
||||
|
@ -123,8 +115,8 @@ function func_inference(){
|
|||
continue
|
||||
fi
|
||||
for batch_size in ${batch_size_list[*]}; do
|
||||
_save_log_path="${_log_path}/infer_gpu_usetrt_${use_trt}_precision_${precision}_batchsize_${batch_size}"
|
||||
command="${_python} ${_script} ${use_gpu_key}=${use_gpu} ${use_trt_key}=${use_trt} ${precision_key}=${precision} ${model_dir_key}=${_model_dir} ${batch_size_key}=${batch_size} ${image_dir_key}=${_img_dir} ${save_log_key}=${_save_log_path} --benchmark=True"
|
||||
_save_log_path="${_log_path}/infer_gpu_usetrt_${use_trt}_precision_${precision}_batchsize_${batch_size}.log"
|
||||
command="${_python} ${_script} ${use_gpu_key}=${use_gpu} ${use_trt_key}=${use_trt} ${precision_key}=${precision} ${infer_model_key}=${_model_dir} ${batch_size_key}=${batch_size} ${image_dir_key}=${_img_dir} ${save_log_key}=${_save_log_path} --benchmark=True"
|
||||
eval $command
|
||||
status_check $? "${command}" "${status_log}"
|
||||
done
|
||||
|
@ -138,12 +130,13 @@ if [ ${MODE} != "infer" ]; then
|
|||
|
||||
IFS="|"
|
||||
for gpu in ${gpu_list[*]}; do
|
||||
train_use_gpu=True
|
||||
use_gpu=True
|
||||
if [ ${gpu} = "-1" ];then
|
||||
train_use_gpu=False
|
||||
use_gpu=False
|
||||
env=""
|
||||
elif [ ${#gpu} -le 1 ];then
|
||||
env="export CUDA_VISIBLE_DEVICES=${gpu}"
|
||||
eval ${env}
|
||||
elif [ ${#gpu} -le 15 ];then
|
||||
IFS=","
|
||||
array=(${gpu})
|
||||
|
@ -155,6 +148,7 @@ for gpu in ${gpu_list[*]}; do
|
|||
ips=${array[0]}
|
||||
gpu=${array[1]}
|
||||
IFS="|"
|
||||
env=" "
|
||||
fi
|
||||
for autocast in ${autocast_list[*]}; do
|
||||
for trainer in ${trainer_list[*]}; do
|
||||
|
@ -179,13 +173,32 @@ for gpu in ${gpu_list[*]}; do
|
|||
continue
|
||||
fi
|
||||
|
||||
save_log="${LOG_PATH}/${trainer}_gpus_${gpu}_autocast_${autocast}"
|
||||
if [ ${#gpu} -le 2 ];then # epoch_num #TODO
|
||||
cmd="${python} ${run_train} ${train_use_gpu_key}=${train_use_gpu} ${autocast_key}=${autocast} ${epoch_key}=${epoch_num} ${save_model_key}=${save_log} "
|
||||
elif [ ${#gpu} -le 15 ];then
|
||||
cmd="${python} -m paddle.distributed.launch --gpus=${gpu} ${run_train} ${autocast_key}=${autocast} ${epoch_key}=${epoch_num} ${save_model_key}=${save_log}"
|
||||
# not set autocast when autocast is null
|
||||
if [ ${autocast} = "null" ]; then
|
||||
set_autocast=" "
|
||||
else
|
||||
cmd="${python} -m paddle.distributed.launch --ips=${ips} --gpus=${gpu} ${run_train} ${autocast_key}=${autocast} ${epoch_key}=${epoch_num} ${save_model_key}=${save_log}"
|
||||
set_autocast="${autocast_key}=${autocast}"
|
||||
fi
|
||||
# not set epoch when whole_train_infer
|
||||
if [ ${MODE} != "whole_train_infer" ]; then
|
||||
set_epoch="${epoch_key}=${epoch_num}"
|
||||
else
|
||||
set_epoch=" "
|
||||
fi
|
||||
# set pretrain
|
||||
if [ ${pretrain_model_value} != "null" ]; then
|
||||
set_pretrain="${pretrain_model_key}=${pretrain_model_value}"
|
||||
else
|
||||
set_pretrain=" "
|
||||
fi
|
||||
|
||||
save_log="${LOG_PATH}/${trainer}_gpus_${gpu}_autocast_${autocast}"
|
||||
if [ ${#gpu} -le 2 ];then # train with cpu or single gpu
|
||||
cmd="${python} ${run_train} ${train_use_gpu_key}=${use_gpu} ${save_model_key}=${save_log} ${set_epoch} ${set_pretrain} ${set_autocast}"
|
||||
elif [ ${#gpu} -le 15 ];then # train with multi-gpu
|
||||
cmd="${python} -m paddle.distributed.launch --gpus=${gpu} ${run_train} ${save_model_key}=${save_log} ${set_epoch} ${set_pretrain} ${set_autocast}"
|
||||
else # train with multi-machine
|
||||
cmd="${python} -m paddle.distributed.launch --ips=${ips} --gpus=${gpu} ${run_train} ${save_model_key}=${save_log} ${set_pretrain} ${set_epoch} ${set_autocast}"
|
||||
fi
|
||||
# run train
|
||||
eval $cmd
|
||||
|
@ -198,24 +211,27 @@ for gpu in ${gpu_list[*]}; do
|
|||
|
||||
# run export model
|
||||
save_infer_path="${save_log}"
|
||||
export_cmd="${python} ${run_export} ${save_model_key}=${save_log} ${pretrain_model_key}=${save_log}/latest ${save_infer_key}=${save_infer_path}"
|
||||
export_cmd="${python} ${run_export} ${save_model_key}=${save_log} ${export_weight}=${save_log}/latest ${save_infer_key}=${save_infer_path}"
|
||||
eval $export_cmd
|
||||
status_check $? "${export_cmd}" "${status_log}"
|
||||
|
||||
#run inference
|
||||
eval $env
|
||||
save_infer_path="${save_log}"
|
||||
func_inference "${python}" "${inference_py}" "${save_infer_path}" "${LOG_PATH}" "${infer_img_dir}"
|
||||
eval "unset CUDA_VISIBLE_DEVICES"
|
||||
done
|
||||
done
|
||||
done
|
||||
|
||||
else
|
||||
save_infer_path="${LOG_PATH}/${MODE}"
|
||||
run_export=${norm_export}
|
||||
export_cmd="${python} ${run_export} ${save_model_key}=${save_infer_path} ${pretrain_model_key}=${infer_model_dir} ${save_infer_key}=${save_infer_path}"
|
||||
eval $export_cmd
|
||||
status_check $? "${export_cmd}" "${status_log}"
|
||||
|
||||
GPUID=$3
|
||||
if [ ${#GPUID} -le 0 ];then
|
||||
env=" "
|
||||
else
|
||||
env="export CUDA_VISIBLE_DEVICES=${GPUID}"
|
||||
fi
|
||||
echo $env
|
||||
#run inference
|
||||
func_inference "${python}" "${inference_py}" "${save_infer_path}" "${LOG_PATH}" "${infer_img_dir}"
|
||||
func_inference "${python}" "${inference_py}" "${infer_model}" "${LOG_PATH}" "${infer_img_dir}"
|
||||
fi
|
||||
|
|
|
@ -19,7 +19,29 @@
|
|||
|
||||
|
||||
### 2.1 训练
|
||||
TBD
|
||||
#### 数据准备
|
||||
训练数据使用公开数据集[PubTabNet](https://arxiv.org/abs/1911.10683),可以从[官网](https://github.com/ibm-aur-nlp/PubTabNet)下载。PubTabNet数据集包含约50万张表格数据的图像,以及图像对应的html格式的注释。
|
||||
|
||||
#### 启动训练
|
||||
*如果您安装的是cpu版本,请将配置文件中的 `use_gpu` 字段修改为false*
|
||||
```shell
|
||||
# 单机单卡训练
|
||||
python3 tools/train.py -c configs/table/table_mv3.yml
|
||||
# 单机多卡训练,通过 --gpus 参数设置使用的GPU ID
|
||||
python3 -m paddle.distributed.launch --gpus '0,1,2,3' tools/train.py -c configs/table/table_mv3.yml
|
||||
```
|
||||
|
||||
上述指令中,通过-c 选择训练使用configs/table/table_mv3.yml配置文件。有关配置文件的详细解释,请参考[链接](./config.md)。
|
||||
|
||||
#### 断点训练
|
||||
|
||||
如果训练程序中断,如果希望加载训练中断的模型从而恢复训练,可以通过指定Global.checkpoints指定要加载的模型路径:
|
||||
```shell
|
||||
python3 tools/train.py -c configs/table/table_mv3.yml -o Global.checkpoints=./your/trained/model
|
||||
```
|
||||
|
||||
**注意**:`Global.checkpoints`的优先级高于`Global.pretrain_weights`的优先级,即同时指定两个参数时,优先加载`Global.checkpoints`指定的模型,如果`Global.checkpoints`指定的模型路径有误,会加载`Global.pretrain_weights`指定的模型。
|
||||
|
||||
|
||||
### 2.2 评估
|
||||
先cd到PaddleOCR/ppstructure目录下
|
||||
|
|
|
@ -27,7 +27,7 @@ from ppocr.data import build_dataloader
|
|||
from ppocr.modeling.architectures import build_model
|
||||
from ppocr.postprocess import build_post_process
|
||||
from ppocr.metrics import build_metric
|
||||
from ppocr.utils.save_load import init_model
|
||||
from ppocr.utils.save_load import init_model, load_pretrained_params
|
||||
from ppocr.utils.utility import print_dict
|
||||
import tools.program as program
|
||||
|
||||
|
@ -55,7 +55,10 @@ def main():
|
|||
|
||||
model = build_model(config['Architecture'])
|
||||
use_srn = config['Architecture']['algorithm'] == "SRN"
|
||||
model_type = config['Architecture']['model_type']
|
||||
if "model_type" in config['Architecture'].keys():
|
||||
model_type = config['Architecture']['model_type']
|
||||
else:
|
||||
model_type = None
|
||||
|
||||
best_model_dict = init_model(config, model)
|
||||
if len(best_model_dict):
|
||||
|
@ -68,7 +71,7 @@ def main():
|
|||
|
||||
# start eval
|
||||
metric = program.eval(model, valid_dataloader, post_process_class,
|
||||
eval_class, model_type, use_srn)
|
||||
eval_class, model_type, use_srn)
|
||||
logger.info('metric eval ***************')
|
||||
for k, v in metric.items():
|
||||
logger.info('{}:{}'.format(k, v))
|
||||
|
|
|
@ -112,7 +112,6 @@ class TextClassifier(object):
|
|||
if '180' in label and score > self.cls_thresh:
|
||||
img_list[indices[beg_img_no + rno]] = cv2.rotate(
|
||||
img_list[indices[beg_img_no + rno]], 1)
|
||||
elapse = time.time() - starttime
|
||||
return img_list, cls_res, elapse
|
||||
|
||||
|
||||
|
@ -146,7 +145,6 @@ def main(args):
|
|||
cls_res[ino]))
|
||||
logger.info(
|
||||
"The predict time about text angle classify module is as follows: ")
|
||||
text_classifier.cls_times.info(average=False)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
@ -175,7 +175,7 @@ class TextDetector(object):
|
|||
|
||||
st = time.time()
|
||||
|
||||
if args.benchmark:
|
||||
if self.args.benchmark:
|
||||
self.autolog.times.start()
|
||||
|
||||
data = transform(data, self.preprocess_op)
|
||||
|
@ -186,7 +186,7 @@ class TextDetector(object):
|
|||
shape_list = np.expand_dims(shape_list, axis=0)
|
||||
img = img.copy()
|
||||
|
||||
if args.benchmark:
|
||||
if self.args.benchmark:
|
||||
self.autolog.times.stamp()
|
||||
|
||||
self.input_tensor.copy_from_cpu(img)
|
||||
|
@ -195,7 +195,7 @@ class TextDetector(object):
|
|||
for output_tensor in self.output_tensors:
|
||||
output = output_tensor.copy_to_cpu()
|
||||
outputs.append(output)
|
||||
if args.benchmark:
|
||||
if self.args.benchmark:
|
||||
self.autolog.times.stamp()
|
||||
|
||||
preds = {}
|
||||
|
@ -220,7 +220,7 @@ class TextDetector(object):
|
|||
else:
|
||||
dt_boxes = self.filter_tag_det_res(dt_boxes, ori_im.shape)
|
||||
|
||||
if args.benchmark:
|
||||
if self.args.benchmark:
|
||||
self.autolog.times.end(stamp=True)
|
||||
et = time.time()
|
||||
return dt_boxes, et - st
|
||||
|
|
|
@ -64,6 +64,24 @@ class TextRecognizer(object):
|
|||
self.postprocess_op = build_post_process(postprocess_params)
|
||||
self.predictor, self.input_tensor, self.output_tensors, self.config = \
|
||||
utility.create_predictor(args, 'rec', logger)
|
||||
self.benchmark = args.benchmark
|
||||
if args.benchmark:
|
||||
import auto_log
|
||||
pid = os.getpid()
|
||||
self.autolog = auto_log.AutoLogger(
|
||||
model_name="rec",
|
||||
model_precision=args.precision,
|
||||
batch_size=args.rec_batch_num,
|
||||
data_shape="dynamic",
|
||||
save_path=args.save_log_path,
|
||||
inference_config=self.config,
|
||||
pids=pid,
|
||||
process_name=None,
|
||||
gpu_ids=0 if args.use_gpu else None,
|
||||
time_keys=[
|
||||
'preprocess_time', 'inference_time', 'postprocess_time'
|
||||
],
|
||||
warmup=10)
|
||||
|
||||
def resize_norm_img(self, img, max_wh_ratio):
|
||||
imgC, imgH, imgW = self.rec_image_shape
|
||||
|
@ -168,6 +186,8 @@ class TextRecognizer(object):
|
|||
rec_res = [['', 0.0]] * img_num
|
||||
batch_num = self.rec_batch_num
|
||||
st = time.time()
|
||||
if self.benchmark:
|
||||
self.autolog.times.start()
|
||||
for beg_img_no in range(0, img_num, batch_num):
|
||||
end_img_no = min(img_num, beg_img_no + batch_num)
|
||||
norm_img_batch = []
|
||||
|
@ -196,6 +216,8 @@ class TextRecognizer(object):
|
|||
norm_img_batch.append(norm_img[0])
|
||||
norm_img_batch = np.concatenate(norm_img_batch)
|
||||
norm_img_batch = norm_img_batch.copy()
|
||||
if self.benchmark:
|
||||
self.autolog.times.stamp()
|
||||
|
||||
if self.rec_algorithm == "SRN":
|
||||
encoder_word_pos_list = np.concatenate(encoder_word_pos_list)
|
||||
|
@ -222,6 +244,8 @@ class TextRecognizer(object):
|
|||
for output_tensor in self.output_tensors:
|
||||
output = output_tensor.copy_to_cpu()
|
||||
outputs.append(output)
|
||||
if self.benchmark:
|
||||
self.autolog.times.stamp()
|
||||
preds = {"predict": outputs[2]}
|
||||
else:
|
||||
self.input_tensor.copy_from_cpu(norm_img_batch)
|
||||
|
@ -231,11 +255,14 @@ class TextRecognizer(object):
|
|||
for output_tensor in self.output_tensors:
|
||||
output = output_tensor.copy_to_cpu()
|
||||
outputs.append(output)
|
||||
if self.benchmark:
|
||||
self.autolog.times.stamp()
|
||||
preds = outputs[0]
|
||||
rec_result = self.postprocess_op(preds)
|
||||
for rno in range(len(rec_result)):
|
||||
rec_res[indices[beg_img_no + rno]] = rec_result[rno]
|
||||
|
||||
if self.benchmark:
|
||||
self.autolog.times.end(stamp=True)
|
||||
return rec_res, time.time() - st
|
||||
|
||||
|
||||
|
@ -251,9 +278,6 @@ def main(args):
|
|||
for i in range(10):
|
||||
res = text_recognizer([img])
|
||||
|
||||
cpu_mem, gpu_mem, gpu_util = 0, 0, 0
|
||||
count = 0
|
||||
|
||||
for image_file in image_file_list:
|
||||
img, flag = check_and_read_gif(image_file)
|
||||
if not flag:
|
||||
|
@ -273,6 +297,8 @@ def main(args):
|
|||
for ino in range(len(img_list)):
|
||||
logger.info("Predicts of {}:{}".format(valid_image_file_list[ino],
|
||||
rec_res[ino]))
|
||||
if args.benchmark:
|
||||
text_recognizer.autolog.report()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
@ -174,8 +174,6 @@ def main(args):
|
|||
logger.info("The predict total time is {}".format(time.time() - _st))
|
||||
logger.info("\nThe predict total time is {}".format(total_time))
|
||||
|
||||
img_num = text_sys.text_detector.det_times.img_num
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = utility.parse_args()
|
||||
|
|
|
@ -34,7 +34,7 @@ def init_args():
|
|||
parser.add_argument("--use_gpu", type=str2bool, default=True)
|
||||
parser.add_argument("--ir_optim", type=str2bool, default=True)
|
||||
parser.add_argument("--use_tensorrt", type=str2bool, default=False)
|
||||
parser.add_argument("--min_subgraph_size", type=int, default=3)
|
||||
parser.add_argument("--min_subgraph_size", type=int, default=10)
|
||||
parser.add_argument("--precision", type=str, default="fp32")
|
||||
parser.add_argument("--gpu_mem", type=int, default=500)
|
||||
|
||||
|
@ -161,7 +161,7 @@ def create_predictor(args, mode, logger):
|
|||
config.enable_use_gpu(args.gpu_mem, 0)
|
||||
if args.use_tensorrt:
|
||||
config.enable_tensorrt_engine(
|
||||
precision_mode=inference.PrecisionType.Float32,
|
||||
precision_mode=precision,
|
||||
max_batch_size=args.max_batch_size,
|
||||
min_subgraph_size=args.min_subgraph_size)
|
||||
# skip the minmum trt subgraph
|
||||
|
|
|
@ -186,7 +186,10 @@ def train(config,
|
|||
model.train()
|
||||
|
||||
use_srn = config['Architecture']['algorithm'] == "SRN"
|
||||
model_type = config['Architecture']['model_type']
|
||||
try:
|
||||
model_type = config['Architecture']['model_type']
|
||||
except:
|
||||
model_type = None
|
||||
|
||||
if 'start_epoch' in best_model_dict:
|
||||
start_epoch = best_model_dict['start_epoch']
|
||||
|
|
|
@ -98,7 +98,6 @@ def main(config, device, logger, vdl_writer):
|
|||
eval_class = build_metric(config['Metric'])
|
||||
# load pretrain model
|
||||
pre_best_model_dict = load_dygraph_params(config, model, logger, optimizer)
|
||||
|
||||
logger.info('train dataloader has {} iters'.format(len(train_dataloader)))
|
||||
if valid_dataloader is not None:
|
||||
logger.info('valid dataloader has {} iters'.format(
|
||||
|
|
Loading…
Reference in New Issue